diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 393e862b7a..7dd1ded48e 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -16,11 +16,11 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python: ['3.10', '3.12'] - torch: [{base: '1.13.0', vision: '0.14.0'}, {base: '2.5.1', vision: '0.20.1'}] + python: ['3.10', '3.13'] + torch: [{base: '1.13.0', vision: '0.14.0'}, {base: '2.9.1', vision: '0.24.1'}] testmarker: ['-k "not test_models"', '-m base', '-m cfg', '-m torchscript', '-m features', '-m fxforward', '-m fxbackward'] exclude: - - python: '3.12' + - python: '3.13' torch: {base: '1.13.0', vision: '0.14.0'} runs-on: ${{ matrix.os }} diff --git a/tests/test_layers_pool.py b/tests/test_layers_pool.py new file mode 100644 index 0000000000..a282d74340 --- /dev/null +++ b/tests/test_layers_pool.py @@ -0,0 +1,358 @@ +"""Tests for timm pooling layers.""" +import pytest +import torch +import torch.nn as nn + +import importlib +import os + +torch_backend = os.environ.get('TORCH_BACKEND') +if torch_backend is not None: + importlib.import_module(torch_backend) +torch_device = os.environ.get('TORCH_DEVICE', 'cpu') + + +# Adaptive Avg/Max Pooling Tests + +class TestAdaptiveAvgMaxPool: + """Test adaptive_avgmax_pool module.""" + + def test_adaptive_avgmax_pool2d(self): + from timm.layers import adaptive_avgmax_pool2d + x = torch.randn(2, 64, 7, 7, device=torch_device) + out = adaptive_avgmax_pool2d(x, 1) + assert out.shape == (2, 64, 1, 1) + # Should be average of avg and max + expected = 0.5 * (x.mean(dim=(2, 3), keepdim=True) + x.amax(dim=(2, 3), keepdim=True)) + assert torch.allclose(out, expected) + + def test_select_adaptive_pool2d(self): + from timm.layers import select_adaptive_pool2d + x = torch.randn(2, 64, 7, 7, device=torch_device) + + out_avg = select_adaptive_pool2d(x, pool_type='avg', output_size=1) + assert out_avg.shape == (2, 64, 1, 1) + assert torch.allclose(out_avg, x.mean(dim=(2, 3), keepdim=True)) + + out_max = select_adaptive_pool2d(x, pool_type='max', output_size=1) + assert out_max.shape == (2, 64, 1, 1) + assert torch.allclose(out_max, x.amax(dim=(2, 3), keepdim=True)) + + def test_adaptive_avgmax_pool2d_module(self): + from timm.layers import AdaptiveAvgMaxPool2d + x = torch.randn(2, 64, 14, 14, device=torch_device) + pool = AdaptiveAvgMaxPool2d(output_size=1).to(torch_device) + out = pool(x) + assert out.shape == (2, 64, 1, 1) + + def test_select_adaptive_pool2d_module(self): + from timm.layers import SelectAdaptivePool2d + x = torch.randn(2, 64, 7, 7, device=torch_device) + + for pool_type in ['avg', 'max', 'avgmax', 'catavgmax']: + pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=True).to(torch_device) + out = pool(x) + if pool_type == 'catavgmax': + assert out.shape == (2, 128) # concatenated + else: + assert out.shape == (2, 64) + + def test_select_adaptive_pool2d_fast(self): + from timm.layers import SelectAdaptivePool2d + x = torch.randn(2, 64, 7, 7, device=torch_device) + + for pool_type in ['fast', 'fastavg', 'fastmax', 'fastavgmax', 'fastcatavgmax']: + pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=True).to(torch_device) + out = pool(x) + if 'cat' in pool_type: + assert out.shape == (2, 128) + else: + assert out.shape == (2, 64) + + +# Attention Pool Tests + +class TestAttentionPool: + """Test attention-based pooling layers.""" + + def test_attention_pool_latent_basic(self): + from timm.layers import AttentionPoolLatent + x = torch.randn(2, 49, 64, device=torch_device) + pool = 
AttentionPoolLatent(in_features=64, num_heads=4).to(torch_device) + out = pool(x) + assert out.shape == (2, 64) + + def test_attention_pool_latent_multi_latent(self): + from timm.layers import AttentionPoolLatent + x = torch.randn(2, 49, 64, device=torch_device) + pool = AttentionPoolLatent( + in_features=64, + num_heads=4, + latent_len=4, + pool_type='avg', + ).to(torch_device) + out = pool(x) + assert out.shape == (2, 64) + + def test_attention_pool2d_basic(self): + from timm.layers import AttentionPool2d + x = torch.randn(2, 64, 7, 7, device=torch_device) + pool = AttentionPool2d(in_features=64, feat_size=7).to(torch_device) + out = pool(x) + assert out.shape == (2, 64) + + def test_attention_pool2d_different_feat_size(self): + from timm.layers import AttentionPool2d + # Test with different spatial sizes (requires pos_embed interpolation) + pool = AttentionPool2d(in_features=64, feat_size=7).to(torch_device) + for size in [7, 14]: + x = torch.randn(2, 64, size, size, device=torch_device) + out = pool(x) + assert out.shape == (2, 64) + + def test_rot_attention_pool2d_basic(self): + from timm.layers import RotAttentionPool2d + x = torch.randn(2, 64, 7, 7, device=torch_device) + pool = RotAttentionPool2d(in_features=64, ref_feat_size=7).to(torch_device) + out = pool(x) + assert out.shape == (2, 64) + + def test_rot_attention_pool2d_different_sizes(self): + from timm.layers import RotAttentionPool2d + pool = RotAttentionPool2d(in_features=64, ref_feat_size=7).to(torch_device) + for size in [7, 14, 10]: + x = torch.randn(2, 64, size, size, device=torch_device) + out = pool(x) + assert out.shape == (2, 64) + + def test_rot_attention_pool2d_rope_types(self): + from timm.layers import RotAttentionPool2d + x = torch.randn(2, 64, 7, 7, device=torch_device) + for rope_type in ['base', 'cat', 'dinov3']: + pool = RotAttentionPool2d( + in_features=64, + ref_feat_size=7, + rope_type=rope_type, + ).to(torch_device) + out = pool(x) + assert out.shape == (2, 64) + + +# LSE Pool Tests + +class TestLsePool: + """Test LogSumExp pooling layers.""" + + def test_lse_plus_2d_basic(self): + from timm.layers import LsePlus2d + x = torch.randn(2, 64, 7, 7, device=torch_device) + pool = LsePlus2d().to(torch_device) + out = pool(x) + # Default is flatten=True + assert out.shape == (2, 64) + + def test_lse_plus_2d_no_flatten(self): + from timm.layers import LsePlus2d + x = torch.randn(2, 64, 7, 7, device=torch_device) + pool = LsePlus2d(flatten=False).to(torch_device) + out = pool(x) + assert out.shape == (2, 64, 1, 1) + + def test_lse_plus_1d_basic(self): + from timm.layers import LsePlus1d + x = torch.randn(2, 49, 64, device=torch_device) + pool = LsePlus1d().to(torch_device) + out = pool(x) + assert out.shape == (2, 64) + + def test_lse_high_r_approximates_max(self): + from timm.layers import LsePlus2d + x = torch.randn(2, 64, 7, 7, device=torch_device) + pool = LsePlus2d(r=100.0, r_learnable=False).to(torch_device) + out = pool(x) + out_max = x.amax(dim=(2, 3)) + assert torch.allclose(out, out_max, atol=0.1) + + def test_lse_low_r_approximates_avg(self): + from timm.layers import LsePlus2d + x = torch.randn(2, 64, 7, 7, device=torch_device) + pool = LsePlus2d(r=0.01, r_learnable=False).to(torch_device) + out = pool(x) + out_avg = x.mean(dim=(2, 3)) + assert torch.allclose(out, out_avg, atol=0.1) + + def test_lse_learnable_r_gradient(self): + from timm.layers import LsePlus2d + x = torch.randn(2, 64, 7, 7, device=torch_device) + pool = LsePlus2d(r=10.0, r_learnable=True).to(torch_device) + out = pool(x).sum() + 
out.backward() + assert pool.r.grad is not None + assert pool.r.grad.abs() > 0 + + +# SimPool Tests + +class TestSimPool: + """Test SimPool attention-based pooling layers.""" + + def test_simpool_2d_basic(self): + from timm.layers import SimPool2d + x = torch.randn(2, 64, 7, 7, device=torch_device) + pool = SimPool2d(dim=64).to(torch_device) + out = pool(x) + assert out.shape == (2, 64) + + def test_simpool_1d_basic(self): + from timm.layers import SimPool1d + x = torch.randn(2, 49, 64, device=torch_device) + pool = SimPool1d(dim=64).to(torch_device) + out = pool(x) + assert out.shape == (2, 64) + + def test_simpool_multi_head(self): + from timm.layers import SimPool2d + x = torch.randn(2, 64, 7, 7, device=torch_device) + for num_heads in [1, 2, 4, 8]: + pool = SimPool2d(dim=64, num_heads=num_heads).to(torch_device) + out = pool(x) + assert out.shape == (2, 64) + + def test_simpool_with_gamma(self): + from timm.layers import SimPool2d + x = torch.randn(2, 64, 7, 7, device=torch_device) + pool = SimPool2d(dim=64, gamma=2.0).to(torch_device) + out = pool(x) + assert out.shape == (2, 64) + assert not torch.isnan(out).any() + + def test_simpool_qk_norm(self): + from timm.layers import SimPool2d + x = torch.randn(2, 64, 7, 7, device=torch_device) + pool = SimPool2d(dim=64, qk_norm=True).to(torch_device) + out = pool(x) + assert out.shape == (2, 64) + + +# Common Tests (Gradient, JIT, dtype) + +class TestPoolingCommon: + """Common tests across all pooling layers.""" + + @pytest.mark.parametrize('pool_cls,kwargs,input_shape', [ + ('LsePlus2d', {}, (2, 64, 7, 7)), + ('LsePlus1d', {}, (2, 49, 64)), + ('SimPool2d', {'dim': 64}, (2, 64, 7, 7)), + ('SimPool1d', {'dim': 64}, (2, 49, 64)), + ('SelectAdaptivePool2d', {'pool_type': 'avg', 'flatten': True}, (2, 64, 7, 7)), + ('AttentionPoolLatent', {'in_features': 64, 'num_heads': 4}, (2, 49, 64)), + ('AttentionPool2d', {'in_features': 64, 'feat_size': 7}, (2, 64, 7, 7)), + ('RotAttentionPool2d', {'in_features': 64, 'ref_feat_size': 7}, (2, 64, 7, 7)), + ]) + def test_gradient_flow(self, pool_cls, kwargs, input_shape): + import timm.layers as layers + x = torch.randn(*input_shape, device=torch_device, requires_grad=True) + pool = getattr(layers, pool_cls)(**kwargs).to(torch_device) + out = pool(x) + loss = out.sum() + loss.backward() + assert x.grad is not None + assert x.grad.abs().sum() > 0 + + @pytest.mark.parametrize('pool_cls,kwargs,input_shape', [ + ('LsePlus2d', {}, (2, 64, 7, 7)), + ('LsePlus1d', {}, (2, 49, 64)), + ('SimPool2d', {'dim': 64}, (2, 64, 7, 7)), + ('SimPool1d', {'dim': 64}, (2, 49, 64)), + ('AttentionPool2d', {'in_features': 64, 'feat_size': 7}, (2, 64, 7, 7)), + ('RotAttentionPool2d', {'in_features': 64, 'ref_feat_size': 7}, (2, 64, 7, 7)), + ]) + def test_torchscript(self, pool_cls, kwargs, input_shape): + import timm.layers as layers + x = torch.randn(*input_shape, device=torch_device) + pool = getattr(layers, pool_cls)(**kwargs).to(torch_device) + pool.eval() + scripted = torch.jit.script(pool) + out_orig = pool(x) + out_script = scripted(x) + assert torch.allclose(out_orig, out_script, atol=1e-5) + + @pytest.mark.parametrize('pool_cls,kwargs,input_shape', [ + ('LsePlus2d', {}, (2, 64, 7, 7)), + ('LsePlus1d', {}, (2, 49, 64)), + ('SimPool2d', {'dim': 64}, (2, 64, 7, 7)), + ('SimPool1d', {'dim': 64}, (2, 49, 64)), + ('AttentionPool2d', {'in_features': 64, 'feat_size': 7}, (2, 64, 7, 7)), + ('RotAttentionPool2d', {'in_features': 64, 'ref_feat_size': 7}, (2, 64, 7, 7)), + ]) + def test_eval_deterministic(self, pool_cls, kwargs, 
input_shape): + import timm.layers as layers + x = torch.randn(*input_shape, device=torch_device) + pool = getattr(layers, pool_cls)(**kwargs).to(torch_device) + pool.eval() + with torch.no_grad(): + out1 = pool(x) + out2 = pool(x) + assert torch.allclose(out1, out2) + + @pytest.mark.parametrize('pool_cls,kwargs,input_shape', [ + ('LsePlus2d', {}, (2, 64, 7, 7)), + ('SimPool2d', {'dim': 64}, (2, 64, 7, 7)), + ('RotAttentionPool2d', {'in_features': 64, 'ref_feat_size': 7}, (2, 64, 7, 7)), + ]) + def test_different_spatial_sizes(self, pool_cls, kwargs, input_shape): + import timm.layers as layers + B, C, _, _ = input_shape + pool = getattr(layers, pool_cls)(**kwargs).to(torch_device) + for H, W in [(7, 7), (14, 14), (1, 1), (3, 5)]: + x = torch.randn(B, C, H, W, device=torch_device) + out = pool(x) + assert out.shape[0] == B + assert out.shape[-1] == C + + +# BlurPool Tests + +class TestBlurPool: + """Test BlurPool anti-aliasing layer.""" + + def test_blur_pool_2d_basic(self): + from timm.layers import BlurPool2d + x = torch.randn(2, 64, 14, 14, device=torch_device) + pool = BlurPool2d(channels=64).to(torch_device) + out = pool(x) + assert out.shape == (2, 64, 7, 7) + + def test_blur_pool_2d_stride(self): + from timm.layers import BlurPool2d + x = torch.randn(2, 64, 28, 28, device=torch_device) + pool = BlurPool2d(channels=64, stride=4).to(torch_device) + out = pool(x) + assert out.shape == (2, 64, 8, 8) + + +# Pool1d Tests + +class TestPool1d: + """Test 1D pooling utilities.""" + + def test_global_pool_nlc(self): + from timm.layers import global_pool_nlc + x = torch.randn(2, 49, 64, device=torch_device) + + # By default, avg/max excludes first token (num_prefix_tokens=1) + out_avg = global_pool_nlc(x, pool_type='avg') + assert out_avg.shape == (2, 64) + assert torch.allclose(out_avg, x[:, 1:].mean(dim=1)) + + out_max = global_pool_nlc(x, pool_type='max') + assert out_max.shape == (2, 64) + assert torch.allclose(out_max, x[:, 1:].amax(dim=1)) + + out_first = global_pool_nlc(x, pool_type='token') + assert out_first.shape == (2, 64) + assert torch.allclose(out_first, x[:, 0]) + + # Test with reduce_include_prefix=True + out_avg_all = global_pool_nlc(x, pool_type='avg', reduce_include_prefix=True) + assert torch.allclose(out_avg_all, x.mean(dim=1)) diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py index f40a1f77cb..f9148db665 100644 --- a/timm/layers/__init__.py +++ b/timm/layers/__init__.py @@ -18,7 +18,7 @@ from .attention import Attention, AttentionRope, maybe_add_mask from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttentionV2 from .attention_pool import AttentionPoolLatent -from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding +from .attention_pool2d import AttentionPool2d, RotAttentionPool2d from .blur_pool import BlurPool2d, create_aa from .classifier import create_classifier, ClassifierHead, NormMlpClassifierHead, ClNormMlpClassifierHead from .cond_conv2d import CondConv2d, get_condconv_initializer @@ -107,6 +107,7 @@ from .patch_dropout import PatchDropout, PatchDropoutWithIndices, patch_dropout_forward from .patch_embed import PatchEmbed, PatchEmbedWithSize, PatchEmbedInterpolator, resample_patch_embed from .pool1d import global_pool_nlc +from .other_pool import LsePlus2d, LsePlus1d, SimPool2d, SimPool1d from .pool2d_same import AvgPool2dSame, create_pool2d from .pos_embed import resample_abs_pos_embed, resample_abs_pos_embed_nhwc from .pos_embed_rel import ( diff --git a/timm/layers/attention_pool2d.py 
b/timm/layers/attention_pool2d.py index cc26aecdf4..6a813630d3 100644 --- a/timm/layers/attention_pool2d.py +++ b/timm/layers/attention_pool2d.py @@ -15,7 +15,7 @@ from .config import use_fused_attn from .helpers import to_2tuple from .pos_embed import resample_abs_pos_embed -from .pos_embed_sincos import apply_rot_embed, RotaryEmbedding +from .pos_embed_sincos import apply_rot_embed_cat, create_rope_embed from .weight_init import trunc_normal_ @@ -44,6 +44,7 @@ def __init__( pool_type: str = 'token', class_token: bool = False, drop_rate: float = 0., + rope_type: str = 'cat', device=None, dtype=None, ): @@ -65,6 +66,7 @@ def __init__( self.pool_type = pool_type.lower() self.scale = self.head_dim ** -0.5 self.fused_attn = use_fused_attn() + self.rope_type = rope_type if class_token: self.cls_token = nn.Parameter(torch.zeros(1, embed_dim, **dd)) @@ -80,7 +82,16 @@ def __init__( self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias, **dd) self.drop = nn.Dropout(drop_rate) self.proj = nn.Linear(embed_dim, self.out_features, **dd) - self.pos_embed = RotaryEmbedding(self.head_dim, in_pixels=False, ref_feat_shape=ref_feat_size, **dd) + + self.pos_embed = create_rope_embed( + rope_type=rope_type, + dim=embed_dim, + num_heads=num_heads, + in_pixels=False, + ref_feat_shape=ref_feat_size, + rotate_half=False, + **dd, + ) def init_weights(self, zero_init_last: bool = False): if self.qkv is None: @@ -129,9 +140,12 @@ def forward(self, x, pre_logits: bool = False): x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) q, k, v = x.unbind(0) - rse, rce = self.pos_embed.get_embed((H, W)) - q = torch.cat([q[:, :, :1, :], apply_rot_embed(q[:, :, 1:, :], rse, rce)], dim=2).type_as(v) - k = torch.cat([k[:, :, :1, :], apply_rot_embed(k[:, :, 1:, :], rse, rce)], dim=2).type_as(v) + rope = self.pos_embed.get_embed((H, W)) + if isinstance(rope, tuple): + # RotaryEmbedding returns (sin, cos) tuple - concatenate for apply_rot_embed_cat + rope = torch.cat(rope, dim=-1) + q = torch.cat([q[:, :, :1, :], apply_rot_embed_cat(q[:, :, 1:, :], rope)], dim=2).type_as(v) + k = torch.cat([k[:, :, :1, :], apply_rot_embed_cat(k[:, :, 1:, :], rope)], dim=2).type_as(v) if self.fused_attn: x = nn.functional.scaled_dot_product_attention(q, k, v) diff --git a/timm/layers/other_pool.py b/timm/layers/other_pool.py new file mode 100644 index 0000000000..cced920fbf --- /dev/null +++ b/timm/layers/other_pool.py @@ -0,0 +1,286 @@ +""" Non-Local Attention Pooling Layers + +A collection of global pooling layers that go beyond simple avg/max pooling. + +LSEPool - LogSumExp pooling, a smooth approximation between avg and max pooling +SimPool - Attention-based pooling from 'Keep It SimPool' (ICCV 2023) + +Based on implementations from: +* LSE Pooling: custom implementation by Bill Psomas +* SimPool: https://arxiv.org/abs/2309.06891 - 'Keep It SimPool: Who Said Supervised Transformers + Suffer from Attention Deficit?' by Bill Psomas et al. + +Hacked together by / Copyright 2024 Ross Wightman, original code by Bill Psomas +""" +from typing import Optional, Type, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .config import use_fused_attn + + +class LsePlus2d(nn.Module): + """LogSumExp (LSE) Pooling for 2D inputs. + + A smooth approximation to max pooling that provides a learnable interpolation between + average and max pooling. When r is large, LSE approaches max pooling; when r is small, + it approaches average pooling. 
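+
+    A minimal usage sketch (shapes mirror tests/test_layers_pool.py):
+
+        pool = LsePlus2d(r=100.0, r_learnable=False)  # large r ~= max pooling
+        out = pool(torch.randn(2, 64, 7, 7))          # -> (2, 64) with default flatten=True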
+
+    Implements: (1/r) * log((1/n) * sum(exp(r * (x - x_max)))) + x_max
+
+    The x_max subtraction provides numerical stability.
+    """
+
+    def __init__(
+            self,
+            r: float = 10.0,
+            r_learnable: bool = True,
+            flatten: bool = True,
+            device=None,
+            dtype=None,
+    ):
+        """
+        Args:
+            r: Initial value of the pooling parameter. Higher = closer to max pooling.
+            r_learnable: If True, r is a learnable parameter.
+            flatten: If True, flatten spatial dims in output.
+        """
+        super().__init__()
+        if r_learnable:
+            self.r = nn.Parameter(torch.tensor(r, device=device, dtype=dtype))
+        else:
+            self.register_buffer('r', torch.tensor(r, device=device, dtype=dtype))
+        self.flatten = flatten
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x_max = F.adaptive_max_pool2d(x, 1)
+        exp_x = torch.exp(self.r * (x - x_max))
+        sum_exp = exp_x.mean(dim=(2, 3), keepdim=True)
+        out = x_max + (1.0 / self.r) * torch.log(sum_exp)
+        if self.flatten:
+            out = out.flatten(1)
+        return out
+
+
+class LsePlus1d(nn.Module):
+    """LogSumExp (LSE) Pooling for sequence (NLC) inputs.
+
+    A smooth approximation to max pooling that provides a learnable interpolation between
+    average and max pooling. When r is large, LSE approaches max pooling; when r is small,
+    it approaches average pooling.
+    """
+
+    def __init__(
+            self,
+            r: float = 10.0,
+            r_learnable: bool = True,
+            device=None,
+            dtype=None,
+    ):
+        """
+        Args:
+            r: Initial value of the pooling parameter. Higher = closer to max pooling.
+            r_learnable: If True, r is a learnable parameter.
+        """
+        super().__init__()
+        if r_learnable:
+            self.r = nn.Parameter(torch.tensor(r, device=device, dtype=dtype))
+        else:
+            self.register_buffer('r', torch.tensor(r, device=device, dtype=dtype))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # x: (B, N, C)
+        x_max = x.max(dim=1, keepdim=True).values
+        exp_x = torch.exp(self.r * (x - x_max))
+        sum_exp = exp_x.mean(dim=1, keepdim=True)
+        out = x_max + (1.0 / self.r) * torch.log(sum_exp)
+        return out.squeeze(1)  # (B, C)
+
+
+class SimPool2d(nn.Module):
+    """SimPool: Simple Attention-Based Pooling for 2D (NCHW) inputs.
+
+    From 'Keep It SimPool: Who Said Supervised Transformers Suffer from Attention Deficit?'
+    https://arxiv.org/abs/2309.06891
+
+    Uses GAP as query initialization and applies cross-attention between the GAP query
+    and spatial features to produce a weighted pooled representation.
+    """
+    fused_attn: torch.jit.Final[bool]
+
+    def __init__(
+            self,
+            dim: int,
+            num_heads: int = 1,
+            qkv_bias: bool = False,
+            qk_norm: bool = False,
+            gamma: Optional[float] = None,
+            norm_layer: Optional[Type[nn.Module]] = None,
+            device=None,
+            dtype=None,
+    ):
+        """
+        Args:
+            dim: Input feature dimension (number of channels).
+            num_heads: Number of attention heads.
+            qkv_bias: If True, add bias to query and key projections.
+            qk_norm: If True, apply normalization to queries and keys.
+            gamma: If provided, apply power normalization to values with this exponent.
+            norm_layer: Normalization layer for patches and optionally qk_norm.
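+
+        A minimal usage sketch (shapes mirror tests/test_layers_pool.py):
+
+            pool = SimPool2d(dim=64, num_heads=4)
+            out = pool(torch.randn(2, 64, 7, 7))  # NCHW features -> (2, 64)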
+ """ + super().__init__() + dd = {'device': device, 'dtype': dtype} + assert dim % num_heads == 0, 'dim must be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.gamma = gamma + self.fused_attn = use_fused_attn() + + norm_layer = norm_layer or nn.LayerNorm + self.norm = norm_layer(dim, **dd) + self.q = nn.Linear(dim, dim, bias=qkv_bias, **dd) + self.k = nn.Linear(dim, dim, bias=qkv_bias, **dd) + if qk_norm: + self.q_norm = norm_layer(self.head_dim, **dd) + self.k_norm = norm_layer(self.head_dim, **dd) + else: + self.q_norm = nn.Identity() + self.k_norm = nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, C, H, W = x.shape + N = H * W + + # Reshape to (B, N, C) for attention + x = x.flatten(2).transpose(1, 2) # (B, N, C) + + # GAP as query initialization + q = x.mean(dim=1, keepdim=True) # (B, 1, C) + + # Normalize patches for keys and values + x_norm = self.norm(x) + + # Project query and keys + q = self.q(q).reshape(B, 1, self.num_heads, self.head_dim).transpose(1, 2) + k = self.k(x_norm).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2) + v = x_norm.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2) + + q, k = self.q_norm(q), self.k_norm(k) + + if self.gamma is not None: + # Power normalization on values + v_min = v.amin(dim=-2, keepdim=True) + v_shifted = v - v_min + 1e-6 + if self.fused_attn: + attn_out = F.scaled_dot_product_attention(q, k, v_shifted.pow(self.gamma)) + else: + attn = (q * self.scale) @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn_out = attn @ v_shifted.pow(self.gamma) + out = attn_out.pow(1.0 / self.gamma) + else: + if self.fused_attn: + out = F.scaled_dot_product_attention(q, k, v) + else: + attn = (q * self.scale) @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + out = attn @ v + + # (B, num_heads, 1, head_dim) -> (B, C) or (B, C) + out = out.transpose(1, 2).reshape(B, C) + return out + + +class SimPool1d(nn.Module): + """SimPool: Simple Attention-Based Pooling for sequence (NLC) inputs. + + From 'Keep It SimPool: Who Said Supervised Transformers Suffer from Attention Deficit?' + https://arxiv.org/abs/2309.06891 + + Uses GAP as query initialization and applies cross-attention between the GAP query + and sequence tokens to produce a weighted pooled representation. + """ + fused_attn: torch.jit.Final[bool] + + def __init__( + self, + dim: int, + num_heads: int = 1, + qkv_bias: bool = False, + qk_norm: bool = False, + gamma: Optional[float] = None, + norm_layer: Optional[Type[nn.Module]] = None, + device=None, + dtype=None, + ): + """ + Args: + dim: Input feature dimension. + num_heads: Number of attention heads. + qkv_bias: If True, add bias to query and key projections. + qk_norm: If True, apply normalization to queries and keys. + gamma: If provided, apply power normalization to values with this exponent. + norm_layer: Normalization layer for tokens and optionally qk_norm. 
+ """ + super().__init__() + dd = {'device': device, 'dtype': dtype} + assert dim % num_heads == 0, 'dim must be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.gamma = gamma + self.fused_attn = use_fused_attn() + + norm_layer = norm_layer or nn.LayerNorm + self.norm = norm_layer(dim, **dd) + self.q = nn.Linear(dim, dim, bias=qkv_bias, **dd) + self.k = nn.Linear(dim, dim, bias=qkv_bias, **dd) + if qk_norm: + self.q_norm = norm_layer(self.head_dim, **dd) + self.k_norm = norm_layer(self.head_dim, **dd) + else: + self.q_norm = nn.Identity() + self.k_norm = nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, N, C = x.shape + + # GAP as query initialization + q = x.mean(dim=1, keepdim=True) # (B, 1, C) + + # Normalize tokens for keys and values + x_norm = self.norm(x) + + # Project query and keys + q = self.q(q).reshape(B, 1, self.num_heads, self.head_dim).transpose(1, 2) + k = self.k(x_norm).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2) + v = x_norm.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2) + + q, k = self.q_norm(q), self.k_norm(k) + + if self.gamma is not None: + # Power normalization on values + v_min = v.amin(dim=-2, keepdim=True) + v_shifted = v - v_min + 1e-6 + if self.fused_attn: + attn_out = F.scaled_dot_product_attention(q, k, v_shifted.pow(self.gamma)) + else: + attn = (q * self.scale) @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn_out = attn @ v_shifted.pow(self.gamma) + out = attn_out.pow(1.0 / self.gamma) + else: + if self.fused_attn: + out = F.scaled_dot_product_attention(q, k, v) + else: + attn = (q * self.scale) @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + out = attn @ v + + # (B, num_heads, 1, head_dim) -> (B, C) + out = out.transpose(1, 2).reshape(B, C) + return out diff --git a/timm/layers/pos_embed_sincos.py b/timm/layers/pos_embed_sincos.py index 9d314e3a26..dde0d5e1a6 100644 --- a/timm/layers/pos_embed_sincos.py +++ b/timm/layers/pos_embed_sincos.py @@ -1164,13 +1164,19 @@ def create_rope_embed( Rotary embedding module """ if rope_type == 'base': + kwargs.pop('rotate_half', None) # doesn't support return RotaryEmbedding(dim=dim // num_heads, **kwargs) elif rope_type == 'cat': + kwargs.pop('rotate_half', None) # doesn't support return RotaryEmbeddingCat(dim=dim // num_heads, **kwargs) elif rope_type == 'mixed': # Mixed requires depth parameter, generates differing embeddings per layer and head + kwargs.pop('in_pixels', None) # doesn't support + kwargs.pop('ref_feat_shape', None) # doesn't support return RotaryEmbeddingMixed(dim=dim, num_heads=num_heads, **kwargs) elif rope_type == 'dinov3': + kwargs.pop('in_pixels', None) # doesn't support + kwargs.pop('ref_feat_shape', None) # doesn't support return RotaryEmbeddingDinoV3(dim=dim // num_heads, **kwargs) else: raise ValueError(f"Unknown RoPE type: {rope_type}") diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index 0e33aabcfd..75bbde3a2f 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -42,6 +42,7 @@ NormMlpClassifierHead, ConvNormAct, BatchNormAct2d, + DropBlock2d, EvoNorm2dS0a, AttentionPool2d, RotAttentionPool2d, @@ -1339,11 +1340,42 @@ def update_block_kwargs(block_kwargs: Dict[str, Any], block_cfg: ByoBlockCfg, mo block_kwargs.update(override_kwargs(block_cfg.block_kwargs, model_cfg.block_kwargs)) +def drop_blocks( + drop_prob: float = 0., + block_size: int = 3, + num_stages: int = 4, +) -> List[Optional[partial]]: + """Create 
DropBlock layer partials for each stage. + + DropBlock is applied to the last two stages only, following common practice. + The block_size specifies the size for the final stage; the second-to-last + stage uses a larger block size scaled to account for 2x larger feature maps. + + Args: + drop_prob: Drop probability for DropBlock. + block_size: Block size for the final stage. Second-to-last stage + uses `block_size * 2 - 1` to scale with feature map size. + num_stages: Number of stages in the model. + + Returns: + List of DropBlock partial instances or None for each stage. + """ + assert num_stages >= 2 + dbs = [None] * num_stages + if drop_prob: + # Scale block size for second-to-last stage (2x larger feature maps) + dbs[-2] = partial(DropBlock2d, drop_prob=drop_prob, block_size=block_size * 2 - 1, gamma_scale=0.25) + dbs[-1] = partial(DropBlock2d, drop_prob=drop_prob, block_size=block_size, gamma_scale=1.00) + return dbs + + def create_byob_stages( cfg: ByoModelCfg, drop_path_rate: float, output_stride: int, stem_feat: Dict[str, Any], + drop_block_rate: float = 0., + drop_block_size: int = 3, feat_size: Optional[int] = None, layers: Optional[LayerFn] = None, block_kwargs_fn: Optional[Callable] = update_block_kwargs, @@ -1353,8 +1385,10 @@ def create_byob_stages( layers = layers or LayerFn() feature_info = [] block_cfgs = [expand_blocks_cfg(s) for s in cfg.blocks] + num_stages = len(block_cfgs) depths = [sum([bc.d for bc in stage_bcs]) for stage_bcs in block_cfgs] dpr = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True) + dbs = drop_blocks(drop_block_rate, drop_block_size, num_stages) dilation = 1 net_stride = stem_feat['reduction'] prev_chs = stem_feat['num_chs'] @@ -1384,6 +1418,7 @@ def create_byob_stages( group_size=group_size, bottle_ratio=block_cfg.br, downsample=cfg.downsample, + drop_block=dbs[stage_idx], drop_path_rate=dpr[stage_idx][block_idx], layers=layers, device=device, @@ -1437,6 +1472,8 @@ def __init__( output_stride: int = 32, img_size: Optional[Union[int, Tuple[int, int]]] = None, drop_rate: float = 0., + drop_block_rate: float = 0., + drop_block_size: int = 3, drop_path_rate: float = 0., zero_init_last: bool = True, device=None, @@ -1452,6 +1489,8 @@ def __init__( output_stride: Output stride of network, one of (8, 16, 32). img_size: Image size for fixed image size models (i.e. self-attn). drop_rate: Classifier dropout rate. + drop_block_rate: DropBlock drop rate. + drop_block_size: DropBlock block size for final stage (scales up for earlier stages). drop_path_rate: Stochastic depth drop-path rate. zero_init_last: Zero-init last weight of residual path. **kwargs: Extra kwargs overlayed onto cfg. @@ -1490,6 +1529,8 @@ def __init__( drop_path_rate, output_stride, stem_feat[-1], + drop_block_rate=drop_block_rate, + drop_block_size=drop_block_size, layers=stage_layers, feat_size=feat_size, **dd,