def calculate_drop_path_rates(
        drop_path_rate: float,
        depths: Union[int, List[int]],
        stagewise: bool = False,
) -> Union[List[float], List[List[float]]]:
    """Generate drop path rates for stochastic depth.

    This function handles two common patterns for drop path rate scheduling:
    1. Per-block: a single flat list with a linear increase from 0 to
       ``drop_path_rate`` across all blocks.
    2. Stage-wise: the same linear ramp over the flattened blocks, but
       returned grouped per stage (one list of per-block rates per stage).

    Args:
        drop_path_rate: Maximum drop path rate (reached at the final block).
        depths: Either a single int giving the total depth (per-block mode
            only), or a list of ints giving the depth of each stage.
        stagewise: If True, return rates grouped per stage; requires
            ``depths`` to be a list. If False (the default), return one flat
            list of per-block rates regardless of whether ``depths`` is an
            int or a list.

    Returns:
        ``stagewise=False``: flat list of drop rates, one per block.
        ``stagewise=True``: list of lists, the drop rates for each stage.

    Raises:
        ValueError: If ``stagewise=True`` but ``depths`` is a single int,
            since per-stage grouping needs the stage boundaries.
    """
    if isinstance(depths, int):
        # Single total depth - only the flat per-block schedule is possible.
        if stagewise:
            raise ValueError("stagewise=True requires depths to be a list of stage depths")
        return [x.item() for x in torch.linspace(0, drop_path_rate, depths, device='cpu')]

    # List of per-stage depths: build one linear schedule over all blocks.
    total_depth = sum(depths)
    rates = torch.linspace(0, drop_path_rate, total_depth, device='cpu')
    if stagewise:
        # Regroup the flat schedule into per-stage chunks of the given depths.
        return [x.tolist() for x in rates.split(depths)]
    # Per-block pattern across all stages (flat list).
    return [x.item() for x in rates]
torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + dpr = calculate_drop_path_rates(drop_path_rate, depth) # stochastic depth decay rule self.blocks = nn.ModuleList([ Block( dim=embed_dim, diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index e37d25b6ef..e90f2cc9b6 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -39,7 +39,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD from timm.layers import ( ClassifierHead, NormMlpClassifierHead, ConvNormAct, BatchNormAct2d, EvoNorm2dS0a, - AttentionPool2d, RotAttentionPool2d, DropPath, AvgPool2dSame, + AttentionPool2d, RotAttentionPool2d, DropPath, calculate_drop_path_rates, AvgPool2dSame, create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, ) from ._builder import build_model_with_cfg @@ -1212,7 +1212,7 @@ def create_byob_stages( feature_info = [] block_cfgs = [expand_blocks_cfg(s) for s in cfg.blocks] depths = [sum([bc.d for bc in stage_bcs]) for stage_bcs in block_cfgs] - dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + dpr = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True) dilation = 1 net_stride = stem_feat['reduction'] prev_chs = stem_feat['num_chs'] diff --git a/timm/models/coat.py b/timm/models/coat.py index e0b0bcfb84..3fa4f69666 100644 --- a/timm/models/coat.py +++ b/timm/models/coat.py @@ -417,9 +417,7 @@ def __init__( self.crpe3 = ConvRelPosEnc(head_chs=embed_dims[2] // num_heads, num_heads=num_heads, window=crpe_window) self.crpe4 = ConvRelPosEnc(head_chs=embed_dims[3] // num_heads, num_heads=num_heads, window=crpe_window) - # Disable stochastic depth. 
dpr = drop_path_rate - assert dpr == 0.0 skwargs = dict( num_heads=num_heads, qkv_bias=qkv_bias, diff --git a/timm/models/convit.py b/timm/models/convit.py index cbe3b51ece..3dd8adfd23 100644 --- a/timm/models/convit.py +++ b/timm/models/convit.py @@ -27,7 +27,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import DropPath, trunc_normal_, PatchEmbed, Mlp, LayerNorm, HybridEmbed +from timm.layers import DropPath, calculate_drop_path_rates, trunc_normal_, PatchEmbed, Mlp, LayerNorm, HybridEmbed from ._builder import build_model_with_cfg from ._features_fx import register_notrace_module from ._registry import register_model, generate_default_cfgs @@ -292,7 +292,7 @@ def __init__( self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) trunc_normal_(self.pos_embed, std=.02) - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + dpr = calculate_drop_path_rates(drop_path_rate, depth) # stochastic depth decay rule self.blocks = nn.ModuleList([ Block( dim=embed_dim, diff --git a/timm/models/convnext.py b/timm/models/convnext.py index cdc34eba2f..c236571058 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -44,7 +44,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD -from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalResponseNormMlp, \ +from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, calculate_drop_path_rates, Mlp, GlobalResponseNormMlp, \ LayerNorm2d, LayerNorm, RmsNorm2d, RmsNorm, create_conv2d, get_act_layer, get_norm_layer, make_divisible, to_ntuple from timm.layers import SimpleNorm2d, SimpleNorm from timm.layers import NormMlpClassifierHead, ClassifierHead @@ -377,7 +377,7 @@ def __init__( stem_stride = 4 self.stages = nn.Sequential() - dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, 
sum(depths)).split(depths)] + dp_rates = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True) stages = [] prev_chs = dims[0] curr_stride = stem_stride diff --git a/timm/models/crossvit.py b/timm/models/crossvit.py index 0e1de2fadd..c2f32b754a 100644 --- a/timm/models/crossvit.py +++ b/timm/models/crossvit.py @@ -27,7 +27,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import DropPath, to_2tuple, trunc_normal_, _assert +from timm.layers import DropPath, calculate_drop_path_rates, to_2tuple, trunc_normal_, _assert from ._builder import build_model_with_cfg from ._features_fx import register_notrace_function from ._registry import register_model, generate_default_cfgs @@ -346,7 +346,7 @@ def __init__( self.pos_drop = nn.Dropout(p=pos_drop_rate) total_depth = sum([sum(x[-2:]) for x in depth]) - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, total_depth)] # stochastic depth decay rule + dpr = calculate_drop_path_rates(drop_path_rate, total_depth) # stochastic depth decay rule dpr_ptr = 0 self.blocks = nn.ModuleList() for idx, block_cfg in enumerate(depth): diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py index 81d11a0654..5dfebc56a9 100644 --- a/timm/models/cspnet.py +++ b/timm/models/cspnet.py @@ -20,7 +20,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import ClassifierHead, ConvNormAct, DropPath, get_attn, create_act_layer, make_divisible +from timm.layers import ClassifierHead, ConvNormAct, DropPath, calculate_drop_path_rates, get_attn, create_act_layer, make_divisible from ._builder import build_model_with_cfg from ._manipulate import named_apply, MATCH_PREV_GROUP from ._registry import register_model, generate_default_cfgs @@ -569,7 +569,7 @@ def create_csp_stages( cfg_dict = asdict(cfg.stages) num_stages = len(cfg.stages.depth) cfg_dict['block_dpr'] = [None] * num_stages if not drop_path_rate else \ - 
[x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.stages.depth)).split(cfg.stages.depth)] + calculate_drop_path_rates(drop_path_rate, cfg.stages.depth, stagewise=True) stage_args = [dict(zip(cfg_dict.keys(), values)) for values in zip(*cfg_dict.values())] block_kwargs = dict( act_layer=cfg.act_layer, diff --git a/timm/models/davit.py b/timm/models/davit.py index 22b4a1a05f..87d90b062f 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -20,7 +20,7 @@ from torch import Tensor from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import DropPath, to_2tuple, trunc_normal_, Mlp, LayerNorm2d, get_norm_layer, use_fused_attn +from timm.layers import DropPath, calculate_drop_path_rates, to_2tuple, trunc_normal_, Mlp, LayerNorm2d, get_norm_layer, use_fused_attn from timm.layers import NormMlpClassifierHead, ClassifierHead from ._builder import build_model_with_cfg from ._features import feature_take_indices @@ -555,7 +555,7 @@ def __init__( self.stem = Stem(in_chans, embed_dims[0], norm_layer=norm_layer) in_chs = embed_dims[0] - dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + dpr = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True) stages = [] for i in range(num_stages): out_chs = embed_dims[i] diff --git a/timm/models/edgenext.py b/timm/models/edgenext.py index 333f73f0d5..c63d485d77 100644 --- a/timm/models/edgenext.py +++ b/timm/models/edgenext.py @@ -16,7 +16,7 @@ from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, create_conv2d, \ +from timm.layers import trunc_normal_tf_, DropPath, calculate_drop_path_rates, LayerNorm2d, Mlp, create_conv2d, \ NormMlpClassifierHead, ClassifierHead from ._builder import build_model_with_cfg from ._features import feature_take_indices @@ -343,7 +343,7 @@ def __init__( curr_stride = 4 stages = [] - dp_rates = [x.tolist() 
for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + dp_rates = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True) in_chs = dims[0] for i in range(4): stride = 2 if curr_stride == 2 or i > 0 else 1 diff --git a/timm/models/efficientformer.py b/timm/models/efficientformer.py index 669062d365..75e5615e00 100644 --- a/timm/models/efficientformer.py +++ b/timm/models/efficientformer.py @@ -18,7 +18,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import DropPath, trunc_normal_, to_2tuple, Mlp, ndgrid +from timm.layers import DropPath, calculate_drop_path_rates, trunc_normal_, to_2tuple, Mlp, ndgrid from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._manipulate import checkpoint_seq @@ -385,7 +385,7 @@ def __init__( # stochastic depth decay rule self.num_stages = len(depths) last_stage = self.num_stages - 1 - dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + dpr = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True) downsamples = downsamples or (False,) + (True,) * (self.num_stages - 1) stages = [] self.feature_info = [] diff --git a/timm/models/efficientformer_v2.py b/timm/models/efficientformer_v2.py index 39c832c8c2..5bc2112f20 100644 --- a/timm/models/efficientformer_v2.py +++ b/timm/models/efficientformer_v2.py @@ -23,7 +23,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import create_conv2d, create_norm_layer, get_act_layer, get_norm_layer, ConvNormAct -from timm.layers import DropPath, trunc_normal_, to_2tuple, to_ntuple, ndgrid +from timm.layers import DropPath, calculate_drop_path_rates, trunc_normal_, to_2tuple, to_ntuple, ndgrid from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._manipulate import checkpoint_seq @@ -543,7 +543,7 @@ def __init__( stride = 4 num_stages = len(depths) - dpr = 
[x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + dpr = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True) downsamples = downsamples or (False,) + (True,) * (len(depths) - 1) mlp_ratios = to_ntuple(num_stages)(mlp_ratios) stages = [] diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index d1976944da..d6140de9f3 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -1434,6 +1434,12 @@ def _cfg(url='', **kwargs): 'efficientnet_b3_g8_gn.untrained': _cfg( input_size=(3, 288, 288), pool_size=(9, 9), test_input_size=(3, 320, 320), crop_pct=1.0), 'efficientnet_blur_b0.untrained': _cfg(), + 'efficientnet_h_b5.untrained': _cfg( + url='', input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0), + 'efficientnet_x_b3.untrained': _cfg( + url='', input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=0.95), + 'efficientnet_x_b5.untrained': _cfg( + url='', input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0), 'efficientnet_es.ra_in1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_ra-f111e99c.pth', @@ -2708,7 +2714,7 @@ def efficientnet_x_b3(pretrained=False, **kwargs) -> EfficientNet: """ EfficientNet-B3 """ # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2 model = _gen_efficientnet_x( - 'efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + 'efficientnet_x_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) return model @@ -2716,7 +2722,7 @@ def efficientnet_x_b3(pretrained=False, **kwargs) -> EfficientNet: def efficientnet_x_b5(pretrained=False, **kwargs) -> EfficientNet: """ EfficientNet-B5 """ model = _gen_efficientnet_x( - 'efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + 'efficientnet_x_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, 
**kwargs) return model @@ -2724,7 +2730,7 @@ def efficientnet_x_b5(pretrained=False, **kwargs) -> EfficientNet: def efficientnet_h_b5(pretrained=False, **kwargs) -> EfficientNet: """ EfficientNet-B5 """ model = _gen_efficientnet_x( - 'efficientnet_b5', channel_multiplier=1.92, depth_multiplier=2.2, version=2, pretrained=pretrained, **kwargs) + 'efficientnet_h_b5', channel_multiplier=1.92, depth_multiplier=2.2, version=2, pretrained=pretrained, **kwargs) return model diff --git a/timm/models/eva.py b/timm/models/eva.py index 514d590e99..f2da2ba55e 100644 --- a/timm/models/eva.py +++ b/timm/models/eva.py @@ -78,7 +78,7 @@ GluMlp, SwiGLU, LayerNorm, - DropPath, + DropPath, calculate_drop_path_rates, PatchDropoutWithIndices, create_rope_embed, apply_rot_embed_cat, @@ -365,7 +365,7 @@ def __init__( attn_type: str = 'eva', rotate_half: bool = False, swiglu_mlp: bool = False, - swiglu_aligh_to: int = 0, + swiglu_align_to: int = 0, scale_mlp: bool = False, scale_attn_inner: bool = False, num_prefix_tokens: int = 1, @@ -425,6 +425,7 @@ def __init__( hidden_features=hidden_features, norm_layer=norm_layer if scale_mlp else None, drop=proj_drop, + align_to=swiglu_align_to, ) else: # w/o any extra norm, an impl with packed fc1 weights is used, matches existing GluMLP @@ -637,7 +638,7 @@ def __init__( self.norm_pre = norm_layer(embed_dim) if activate_pre_norm else nn.Identity() - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + dpr = calculate_drop_path_rates(drop_path_rate, depth) # stochastic depth decay rule block_fn = EvaBlockPostNorm if use_post_norm else EvaBlock self.blocks = nn.ModuleList([ block_fn( diff --git a/timm/models/fasternet.py b/timm/models/fasternet.py index b9e4aed249..747cdf6b39 100644 --- a/timm/models/fasternet.py +++ b/timm/models/fasternet.py @@ -23,7 +23,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import SelectAdaptivePool2d, 
Linear, DropPath, trunc_normal_, LayerType +from timm.layers import SelectAdaptivePool2d, Linear, DropPath, trunc_normal_, LayerType, calculate_drop_path_rates from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._manipulate import checkpoint_seq @@ -216,7 +216,7 @@ def __init__( norm_layer=norm_layer if patch_norm else nn.Identity, ) # stochastic depth decay rule - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + dpr = calculate_drop_path_rates(drop_path_rate, sum(depths)) # build layers stages_list = [] diff --git a/timm/models/fastvit.py b/timm/models/fastvit.py index 9939095c43..be47d378af 100644 --- a/timm/models/fastvit.py +++ b/timm/models/fastvit.py @@ -14,7 +14,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD from timm.layers import ( - DropPath, + DropPath, calculate_drop_path_rates, trunc_normal_, create_conv2d, ConvNormAct, @@ -1156,7 +1156,7 @@ def __init__( # Build the main stages of the network architecture prev_dim = embed_dims[0] scale = 1 - dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)] + dpr = calculate_drop_path_rates(drop_path_rate, layers, stagewise=True) stages = [] for i in range(len(layers)): downsample = downsamples[i] or prev_dim != embed_dims[i] diff --git a/timm/models/focalnet.py b/timm/models/focalnet.py index 3c2bd75643..775b45dc2f 100644 --- a/timm/models/focalnet.py +++ b/timm/models/focalnet.py @@ -24,7 +24,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import Mlp, DropPath, LayerNorm2d, trunc_normal_, ClassifierHead, NormMlpClassifierHead +from timm.layers import Mlp, DropPath, LayerNorm2d, trunc_normal_, ClassifierHead, NormMlpClassifierHead, calculate_drop_path_rates from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._manipulate import named_apply, checkpoint @@ 
-378,7 +378,7 @@ def __init__( ) in_dim = embed_dim[0] - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + dpr = calculate_drop_path_rates(drop_path_rate, sum(depths)) # stochastic depth decay rule layers = [] for i_layer in range(self.num_layers): out_dim = embed_dim[i_layer] diff --git a/timm/models/gcvit.py b/timm/models/gcvit.py index 367e5dfff5..b7569a63b5 100644 --- a/timm/models/gcvit.py +++ b/timm/models/gcvit.py @@ -27,7 +27,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d, \ +from timm.layers import DropPath, calculate_drop_path_rates, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d, \ get_attn, get_act_layer, get_norm_layer, RelPosBias, _assert from ._builder import build_model_with_cfg from ._features import feature_take_indices @@ -419,7 +419,7 @@ def __init__( norm_layer=norm_layer ) - dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + dpr = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True) stages = [] for i in range(num_stages): last_stage = i == num_stages - 1 diff --git a/timm/models/hgnet.py b/timm/models/hgnet.py index 5c49f9ddca..d6eee3f76b 100644 --- a/timm/models/hgnet.py +++ b/timm/models/hgnet.py @@ -13,7 +13,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import SelectAdaptivePool2d, DropPath, create_conv2d +from timm.layers import SelectAdaptivePool2d, DropPath, calculate_drop_path_rates, create_conv2d from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._registry import register_model, generate_default_cfgs @@ -448,7 +448,7 @@ def __init__( stages = [] self.feature_info = [] block_depths = [c[3] for c in stages_cfg] - dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, 
sum(block_depths)).split(block_depths)] + dpr = calculate_drop_path_rates(drop_path_rate, block_depths, stagewise=True) for i, stage_config in enumerate(stages_cfg): in_chs, mid_chs, out_chs, block_num, downsample, light_block, kernel_size, layer_num = stage_config stages += [HighPerfGpuStage( diff --git a/timm/models/hiera.py b/timm/models/hiera.py index fa9d6d2833..0e9e86441a 100644 --- a/timm/models/hiera.py +++ b/timm/models/hiera.py @@ -31,7 +31,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import DropPath, Mlp, LayerScale, ClNormMlpClassifierHead, use_fused_attn, \ +from timm.layers import DropPath, calculate_drop_path_rates, Mlp, LayerScale, ClNormMlpClassifierHead, use_fused_attn, \ _assert, get_norm_layer, to_2tuple, init_weight_vit, init_weight_jax from ._registry import generate_default_cfgs, register_model @@ -515,7 +515,7 @@ def __init__( # Transformer blocks cur_stage = 0 depth = sum(stages) - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + dpr = calculate_drop_path_rates(drop_path_rate, depth) # stochastic depth decay rule self.blocks = nn.ModuleList() self.feature_info = [] for i in range(depth): diff --git a/timm/models/hieradet_sam2.py b/timm/models/hieradet_sam2.py index cd8bc6c8e9..9e7b4634ef 100644 --- a/timm/models/hieradet_sam2.py +++ b/timm/models/hieradet_sam2.py @@ -8,7 +8,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import PatchEmbed, Mlp, DropPath, ClNormMlpClassifierHead, LayerScale, \ +from timm.layers import PatchEmbed, Mlp, DropPath, calculate_drop_path_rates, ClNormMlpClassifierHead, LayerScale, \ get_norm_layer, get_act_layer, init_weight_jax, init_weight_vit, to_2tuple, use_fused_attn from ._builder import build_model_with_cfg @@ -325,7 +325,7 @@ def __init__( self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 
*self.global_pos_size)) self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])) - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + dpr = calculate_drop_path_rates(drop_path_rate, depth) # stochastic depth decay rule cur_stage = 0 self.blocks = nn.Sequential() self.feature_info = [] diff --git a/timm/models/inception_next.py b/timm/models/inception_next.py index 3159ac03be..326e081732 100644 --- a/timm/models/inception_next.py +++ b/timm/models/inception_next.py @@ -10,7 +10,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import trunc_normal_, DropPath, to_2tuple, get_padding, SelectAdaptivePool2d +from timm.layers import trunc_normal_, DropPath, calculate_drop_path_rates, to_2tuple, get_padding, SelectAdaptivePool2d from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._manipulate import checkpoint_seq @@ -283,7 +283,7 @@ def __init__( norm_layer(dims[0]) ) - dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + dp_rates = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True) prev_chs = dims[0] curr_stride = 4 dilation = 1 diff --git a/timm/models/mambaout.py b/timm/models/mambaout.py index b28827ed71..c33c8d75e8 100644 --- a/timm/models/mambaout.py +++ b/timm/models/mambaout.py @@ -12,7 +12,7 @@ from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale, ClNormMlpClassifierHead, get_act_layer +from timm.layers import trunc_normal_, DropPath, calculate_drop_path_rates, LayerNorm, LayerScale, ClNormMlpClassifierHead, get_act_layer from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._manipulate import checkpoint_seq @@ -338,7 +338,7 @@ def __init__( norm_layer=norm_layer, ) 
prev_dim = dims[0] - dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + dp_rates = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True) cur = 0 curr_stride = 4 self.stages = nn.Sequential() diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py index 85cc3b4238..3be91f3f94 100644 --- a/timm/models/maxxvit.py +++ b/timm/models/maxxvit.py @@ -45,7 +45,7 @@ from torch.jit import Final from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import Mlp, ConvMlp, DropPath, LayerNorm, ClassifierHead, NormMlpClassifierHead +from timm.layers import Mlp, ConvMlp, DropPath, calculate_drop_path_rates, LayerNorm, ClassifierHead, NormMlpClassifierHead from timm.layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d, create_pool2d from timm.layers import trunc_normal_tf_, to_2tuple, extend_tuple, make_divisible, _assert from timm.layers import RelPosMlp, RelPosBias, RelPosBiasTf, use_fused_attn, resize_rel_pos_bias_table @@ -1341,7 +1341,7 @@ def __init__( num_stages = len(cfg.embed_dim) assert len(cfg.depths) == num_stages - dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)] + dpr = calculate_drop_path_rates(drop_path_rate, cfg.depths, stagewise=True) in_chs = self.stem.out_chs stages = [] for i in range(num_stages): diff --git a/timm/models/metaformer.py b/timm/models/metaformer.py index 3364a79563..0a46624976 100644 --- a/timm/models/metaformer.py +++ b/timm/models/metaformer.py @@ -37,7 +37,7 @@ from torch.jit import Final from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import trunc_normal_, DropPath, SelectAdaptivePool2d, GroupNorm1, LayerNorm, LayerNorm2d, Mlp, \ +from timm.layers import trunc_normal_, DropPath, calculate_drop_path_rates, SelectAdaptivePool2d, GroupNorm1, LayerNorm, LayerNorm2d, Mlp, \ use_fused_attn from ._builder import build_model_with_cfg from 
._features import feature_take_indices @@ -524,7 +524,7 @@ def __init__( stages = [] prev_dim = dims[0] - dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + dp_rates = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True) for i in range(self.num_stages): stages += [MetaFormerStage( prev_dim, diff --git a/timm/models/mvitv2.py b/timm/models/mvitv2.py index 01c4550ed8..514e8733f7 100644 --- a/timm/models/mvitv2.py +++ b/timm/models/mvitv2.py @@ -23,7 +23,7 @@ from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import Mlp, DropPath, trunc_normal_tf_, get_norm_layer, to_2tuple +from timm.layers import Mlp, DropPath, calculate_drop_path_rates, trunc_normal_tf_, get_norm_layer, to_2tuple from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._features_fx import register_notrace_function @@ -749,7 +749,7 @@ def __init__( num_stages = len(cfg.embed_dim) feat_size = patch_dims curr_stride = max(cfg.patch_stride) - dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)] + dpr = calculate_drop_path_rates(drop_path_rate, cfg.depths, stagewise=True) self.stages = nn.ModuleList() self.feature_info = [] for i in range(num_stages): diff --git a/timm/models/naflexvit.py b/timm/models/naflexvit.py index f8e27b5044..67f658b3c4 100644 --- a/timm/models/naflexvit.py +++ b/timm/models/naflexvit.py @@ -39,6 +39,7 @@ get_norm_layer, apply_keep_indices_nlc, disable_compiler, + calculate_drop_path_rates, ) from ._builder import build_model_with_cfg from ._features import feature_take_indices @@ -1182,7 +1183,7 @@ def __init__( self.patch_drop = None # Transformer blocks - dpr = [x.item() for x in torch.linspace(0, cfg.drop_path_rate, cfg.depth)] # stochastic depth decay rule + dpr = calculate_drop_path_rates(cfg.drop_path_rate, cfg.depth) # stochastic depth decay rule # Create transformer blocks 
self.blocks = nn.Sequential(*[ block_fn( diff --git a/timm/models/nest.py b/timm/models/nest.py index 9a423a9776..120550001f 100644 --- a/timm/models/nest.py +++ b/timm/models/nest.py @@ -26,7 +26,7 @@ from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_, _assert +from timm.layers import PatchEmbed, Mlp, DropPath, calculate_drop_path_rates, create_classifier, trunc_normal_, _assert from timm.layers import create_conv2d, create_pool2d, to_ntuple, use_fused_attn, LayerNorm from ._builder import build_model_with_cfg from ._features import feature_take_indices @@ -346,7 +346,7 @@ def __init__( # Build up each hierarchical level levels = [] - dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + dp_rates = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True) prev_dim = None curr_stride = 4 for i in range(len(self.num_blocks)): diff --git a/timm/models/nextvit.py b/timm/models/nextvit.py index 402a9d76ea..250161c510 100644 --- a/timm/models/nextvit.py +++ b/timm/models/nextvit.py @@ -13,7 +13,7 @@ from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import DropPath, trunc_normal_, ConvMlp, get_norm_layer, get_act_layer, use_fused_attn +from timm.layers import DropPath, calculate_drop_path_rates, trunc_normal_, ConvMlp, get_norm_layer, get_act_layer, use_fused_attn from timm.layers import ClassifierHead from ._builder import build_model_with_cfg from ._features import feature_take_indices @@ -498,7 +498,7 @@ def __init__( in_chs = out_chs = stem_chs[-1] stages = [] idx = 0 - dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + dpr = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True) for stage_idx in range(len(depths)): stage = NextStage( in_chs=in_chs, diff --git a/timm/models/nfnet.py 
b/timm/models/nfnet.py index 68b8b1b6d6..309ee8c5fe 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -25,7 +25,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame, \ +from timm.layers import ClassifierHead, DropPath, calculate_drop_path_rates, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame, \ get_act_layer, get_act_fn, get_attn, make_divisible from ._builder import build_model_with_cfg from ._features_fx import register_notrace_module @@ -426,7 +426,7 @@ def __init__( ) self.feature_info = [stem_feat] - drop_path_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)] + drop_path_rates = calculate_drop_path_rates(drop_path_rate, cfg.depths, stagewise=True) prev_chs = stem_chs net_stride = stem_stride dilation = 1 diff --git a/timm/models/pit.py b/timm/models/pit.py index b7985cf998..c8b1965a6b 100644 --- a/timm/models/pit.py +++ b/timm/models/pit.py @@ -20,7 +20,7 @@ from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import trunc_normal_, to_2tuple +from timm.layers import trunc_normal_, to_2tuple, calculate_drop_path_rates from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._registry import register_model, generate_default_cfgs @@ -184,7 +184,7 @@ def __init__( transformers = [] # stochastic depth decay rule - dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depth)).split(depth)] + dpr = calculate_drop_path_rates(drop_path_rate, depth, stagewise=True) prev_dim = embed_dim for i in range(len(depth)): pool = None diff --git a/timm/models/pvt_v2.py b/timm/models/pvt_v2.py index 0259a1f64e..ac8c90492a 100644 --- a/timm/models/pvt_v2.py +++ b/timm/models/pvt_v2.py @@ -23,7 +23,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, 
IMAGENET_DEFAULT_STD -from timm.layers import DropPath, to_2tuple, to_ntuple, trunc_normal_, LayerNorm, use_fused_attn +from timm.layers import DropPath, calculate_drop_path_rates, to_2tuple, to_ntuple, trunc_normal_, LayerNorm, use_fused_attn from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._manipulate import checkpoint @@ -313,7 +313,7 @@ def __init__( embed_dim=embed_dims[0], ) - dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + dpr = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True) cur = 0 prev_dim = embed_dims[0] stages = [] diff --git a/timm/models/rdnet.py b/timm/models/rdnet.py index 5764b6ed82..a5961571c0 100644 --- a/timm/models/rdnet.py +++ b/timm/models/rdnet.py @@ -11,7 +11,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import DropPath, NormMlpClassifierHead, ClassifierHead, EffectiveSEModule, \ +from timm.layers import DropPath, calculate_drop_path_rates, NormMlpClassifierHead, ClassifierHead, EffectiveSEModule, \ make_divisible, get_act_layer, get_norm_layer from ._builder import build_model_with_cfg from ._features import feature_take_indices @@ -213,7 +213,7 @@ def __init__( self.num_stages = len(growth_rates) curr_stride = stem_stride num_features = num_init_features - dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(num_blocks_list)).split(num_blocks_list)] + dp_rates = calculate_drop_path_rates(drop_path_rate, num_blocks_list, stagewise=True) dense_stages = [] for i in range(self.num_stages): diff --git a/timm/models/regnet.py b/timm/models/regnet.py index f05bad37ba..e42bf44641 100644 --- a/timm/models/regnet.py +++ b/timm/models/regnet.py @@ -32,7 +32,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import ClassifierHead, AvgPool2dSame, ConvNormAct, SEModule, DropPath, GroupNormAct +from timm.layers 
import ClassifierHead, AvgPool2dSame, ConvNormAct, SEModule, DropPath, GroupNormAct, calculate_drop_path_rates from timm.layers import get_act_layer, get_norm_act_layer, create_conv2d, make_divisible from ._builder import build_model_with_cfg from ._features import feature_take_indices @@ -657,11 +657,7 @@ def _get_stage_args( net_stride *= stride stage_strides.append(stride) stage_dilations.append(dilation) - dpr_tensor = torch.linspace(0, drop_path_rate, sum(stage_depths)) - split_indices = torch.cumsum(torch.tensor(stage_depths[:-1]), dim=0) - stage_dpr = torch.tensor_split(dpr_tensor, split_indices.tolist()) - stage_dpr = [dpr.tolist() for dpr in stage_dpr] - + stage_dpr = calculate_drop_path_rates(drop_path_rate, stage_depths, stagewise=True) # Adjust the compatibility of ws and gws stage_widths, stage_gs = adjust_widths_groups_comp( stage_widths, stage_br, stage_gs, min_ratio=cfg.group_min_ratio) diff --git a/timm/models/resnest.py b/timm/models/resnest.py index 5b1438017e..ed25c6e448 100644 --- a/timm/models/resnest.py +++ b/timm/models/resnest.py @@ -45,9 +45,8 @@ def __init__( ): super(ResNestBottleneck, self).__init__() assert reduce_first == 1 # not supported - assert attn_layer is None # not supported - assert aa_layer is None # TODO not yet supported - assert drop_path is None # TODO not yet supported + assert attn_layer is None, 'attn_layer is not supported' # not supported + assert aa_layer is None, 'aa_layer is not supported' # TODO not yet supported group_width = int(planes * (base_width / 64.)) * cardinality first_dilation = first_dilation or dilation @@ -83,6 +82,7 @@ def __init__( self.bn3 = norm_layer(planes*4) self.act3 = act_layer(inplace=True) self.downsample = downsample + self.drop_path = drop_path def zero_init_last(self): if getattr(self.bn3, 'weight', None) is not None: @@ -109,6 +109,9 @@ def forward(self, x): out = self.conv3(out) out = self.bn3(out) + if self.drop_path is not None: + out = self.drop_path(out) + if self.downsample is not
None: shortcut = self.downsample(x) diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index 0b78e7b44d..7c1424d93c 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -38,7 +38,7 @@ from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.layers import GroupNormAct, BatchNormAct2d, EvoNorm2dS0, FilterResponseNormTlu2d, ClassifierHead, \ - DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d, get_act_layer, get_norm_act_layer, make_divisible + DropPath, calculate_drop_path_rates, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d, get_act_layer, get_norm_act_layer, make_divisible from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._manipulate import checkpoint_seq, named_apply, adapt_input_conv @@ -566,7 +566,7 @@ def __init__( prev_chs = stem_chs curr_stride = 4 dilation = 1 - block_dprs = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)] + block_dprs = calculate_drop_path_rates(drop_path_rate, layers, stagewise=True) if preact: block_fn = PreActBasic if basic else PreActBottleneck else: diff --git a/timm/models/starnet.py b/timm/models/starnet.py index 9ed32a85d7..414b0aa8a4 100644 --- a/timm/models/starnet.py +++ b/timm/models/starnet.py @@ -16,7 +16,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import DropPath, SelectAdaptivePool2d, Linear, LayerType, trunc_normal_ +from timm.layers import DropPath, SelectAdaptivePool2d, Linear, LayerType, trunc_normal_, calculate_drop_path_rates from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._manipulate import checkpoint_seq @@ -103,7 +103,7 @@ def __init__( prev_chs = stem_chs # build stages - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth + dpr = calculate_drop_path_rates(drop_path_rate, sum(depths)) # stochastic depth 
stages = [] cur = 0 for i_layer in range(len(depths)): diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py index e17f16746c..98409975ab 100644 --- a/timm/models/swin_transformer.py +++ b/timm/models/swin_transformer.py @@ -23,7 +23,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import PatchEmbed, Mlp, DropPath, ClassifierHead, to_2tuple, to_ntuple, trunc_normal_, \ +from timm.layers import PatchEmbed, Mlp, DropPath, calculate_drop_path_rates, ClassifierHead, to_2tuple, to_ntuple, trunc_normal_, \ use_fused_attn, resize_rel_pos_bias_table, resample_patch_embed, ndgrid from ._builder import build_model_with_cfg from ._features import feature_take_indices @@ -689,7 +689,7 @@ def __init__( window_size = (window_size,) * self.num_layers assert len(window_size) == self.num_layers mlp_ratio = to_ntuple(self.num_layers)(mlp_ratio) - dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + dpr = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True) layers = [] in_dim = embed_dim[0] scale = 1 diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py index 35c0daa8ac..96e2be2730 100644 --- a/timm/models/swin_transformer_v2.py +++ b/timm/models/swin_transformer_v2.py @@ -20,7 +20,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, ClassifierHead,\ +from timm.layers import PatchEmbed, Mlp, DropPath, calculate_drop_path_rates, to_2tuple, trunc_normal_, ClassifierHead,\ resample_patch_embed, ndgrid, get_act_layer, LayerType from ._builder import build_model_with_cfg from ._features import feature_take_indices @@ -715,7 +715,7 @@ def __init__( ) grid_size = self.patch_embed.grid_size - dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + dpr = 
calculate_drop_path_rates(drop_path_rate, depths, stagewise=True) layers = [] in_dim = embed_dim[0] scale = 1 diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py index 1ef3164fd9..7ac7d7aef3 100644 --- a/timm/models/swin_transformer_v2_cr.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -36,7 +36,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import DropPath, Mlp, ClassifierHead, to_2tuple, _assert, ndgrid +from timm.layers import DropPath, calculate_drop_path_rates, Mlp, ClassifierHead, to_2tuple, _assert, ndgrid from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._features_fx import register_notrace_function @@ -718,7 +718,7 @@ def __init__( else: self.window_size = to_2tuple(window_size) - dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + dpr = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True) stages = [] in_dim = embed_dim in_scale = 1 diff --git a/timm/models/tiny_vit.py b/timm/models/tiny_vit.py index 39bacc850c..2da814ab45 100644 --- a/timm/models/tiny_vit.py +++ b/timm/models/tiny_vit.py @@ -18,7 +18,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import LayerNorm2d, NormMlpClassifierHead, DropPath,\ - trunc_normal_, resize_rel_pos_bias_table_levit, use_fused_attn + trunc_normal_, resize_rel_pos_bias_table_levit, use_fused_attn, calculate_drop_path_rates from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._features_fx import register_notrace_module @@ -449,7 +449,7 @@ def __init__( ) # stochastic depth rate rule - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + dpr = calculate_drop_path_rates(drop_path_rate, sum(depths)) # build stages self.stages = nn.Sequential() diff --git a/timm/models/tnt.py b/timm/models/tnt.py index 
0ecd8e72a4..a801062425 100644 --- a/timm/models/tnt.py +++ b/timm/models/tnt.py @@ -16,7 +16,7 @@ import torch.nn as nn from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from timm.layers import Mlp, DropPath, trunc_normal_, _assert, to_2tuple, resample_abs_pos_embed +from timm.layers import Mlp, DropPath, calculate_drop_path_rates, trunc_normal_, _assert, to_2tuple, resample_abs_pos_embed from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._manipulate import checkpoint @@ -273,7 +273,7 @@ def __init__( self.pixel_pos = nn.Parameter(torch.zeros(1, inner_dim, new_patch_size[0], new_patch_size[1])) self.pos_drop = nn.Dropout(p=pos_drop_rate) - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + dpr = calculate_drop_path_rates(drop_path_rate, depth) # stochastic depth decay rule blocks = [] for i in range(depth): blocks.append(Block( diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py index 2c452e4707..26d44c5f0c 100644 --- a/timm/models/tresnet.py +++ b/timm/models/tresnet.py @@ -12,7 +12,7 @@ import torch import torch.nn as nn -from timm.layers import SpaceToDepth, BlurPool2d, ClassifierHead, SEModule, ConvNormAct, DropPath +from timm.layers import SpaceToDepth, BlurPool2d, ClassifierHead, SEModule, ConvNormAct, DropPath, calculate_drop_path_rates from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._manipulate import checkpoint, checkpoint_seq @@ -136,7 +136,7 @@ def __init__( self.inplanes = self.inplanes // 8 * 8 self.planes = self.planes // 8 * 8 - dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)] + dpr = calculate_drop_path_rates(drop_path_rate, layers, stagewise=True) conv1 = ConvNormAct(in_chans * 16, self.planes, stride=1, kernel_size=3, act_layer=act_layer) layer1 = self._make_layer( Bottleneck if v2 else BasicBlock, diff --git a/timm/models/twins.py 
b/timm/models/twins.py index 74d22e24c1..2d1e23bd8d 100644 --- a/timm/models/twins.py +++ b/timm/models/twins.py @@ -20,7 +20,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import Mlp, DropPath, to_2tuple, trunc_normal_, use_fused_attn +from timm.layers import Mlp, DropPath, to_2tuple, trunc_normal_, use_fused_attn, calculate_drop_path_rates from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._features_fx import register_notrace_module @@ -326,7 +326,7 @@ def __init__( self.blocks = nn.ModuleList() self.feature_info = [] - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + dpr = calculate_drop_path_rates(drop_path_rate, sum(depths)) # stochastic depth decay rule cur = 0 for k in range(len(depths)): _block = nn.ModuleList([block_cls( diff --git a/timm/models/visformer.py b/timm/models/visformer.py index 2ed3be5da8..deadd98a6a 100644 --- a/timm/models/visformer.py +++ b/timm/models/visformer.py @@ -11,7 +11,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d, create_classifier, use_fused_attn +from timm.layers import to_2tuple, trunc_normal_, DropPath, calculate_drop_path_rates, PatchEmbed, LayerNorm2d, create_classifier, use_fused_attn from ._builder import build_model_with_cfg from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs @@ -202,7 +202,7 @@ def __init__( self.use_pos_embed = use_pos_embed self.grad_checkpointing = False - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + dpr = calculate_drop_path_rates(drop_path_rate, depth) # stage 1 if self.vit_stem: self.stem = None diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 1c21d4ffb2..d65174c535 100644 --- 
a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -55,6 +55,7 @@ LayerNorm, RmsNorm, DropPath, + calculate_drop_path_rates, PatchDropout, trunc_normal_, lecun_normal_, @@ -578,7 +579,7 @@ def __init__( self.patch_drop = nn.Identity() self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity() - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + dpr = calculate_drop_path_rates(drop_path_rate, depth) # stochastic depth decay rule self.blocks = nn.Sequential(*[ block_fn( dim=embed_dim, @@ -1736,7 +1737,9 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: 'vit_medium_patch16_gap_384.sw_in12k_ft_in1k': _cfg( hf_hub_id='timm/', input_size=(3, 384, 384), crop_pct=0.95, crop_mode='squash'), - 'vit_base_patch16_gap_224': _cfg(), + 'vit_betwixt_patch16_gap_256.untrained': _cfg( + input_size=(3, 256, 256), crop_pct=0.95), + 'vit_base_patch16_gap_224.untrained': _cfg(), # CLIP pretrained image tower and related fine-tuned weights 'vit_base_patch32_clip_224.laion2b_ft_in12k_in1k': _cfg( @@ -2965,7 +2968,7 @@ def vit_betwixt_patch16_gap_256(pretrained: bool = False, **kwargs) -> VisionTra patch_size=16, embed_dim=640, depth=12, num_heads=10, class_token=False, global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False) model = _create_vision_transformer( - 'vit_medium_patch16_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + 'vit_betwixt_patch16_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) return model diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py index dcccba73ba..8a5e5c6d65 100644 --- a/timm/models/vision_transformer_relpos.py +++ b/timm/models/vision_transformer_relpos.py @@ -19,7 +19,7 @@ from torch.jit import Final from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias, use_fused_attn, LayerType +from 
timm.layers import PatchEmbed, Mlp, DropPath, calculate_drop_path_rates, RelPosMlp, RelPosBias, use_fused_attn, LayerType from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._manipulate import named_apply, checkpoint @@ -316,7 +316,7 @@ def __init__( self.cls_token = nn.Parameter(torch.zeros(1, self.num_prefix_tokens, embed_dim)) if class_token else None - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + dpr = calculate_drop_path_rates(drop_path_rate, depth) # stochastic depth decay rule self.blocks = nn.ModuleList([ block_fn( dim=embed_dim, diff --git a/timm/models/vision_transformer_sam.py b/timm/models/vision_transformer_sam.py index df70f4a251..a8cce36c0c 100644 --- a/timm/models/vision_transformer_sam.py +++ b/timm/models/vision_transformer_sam.py @@ -17,7 +17,7 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from timm.layers import PatchEmbed, Mlp, DropPath, PatchDropout, LayerNorm2d, ClassifierHead, NormMlpClassifierHead, \ +from timm.layers import PatchEmbed, Mlp, DropPath, calculate_drop_path_rates, PatchDropout, LayerNorm2d, ClassifierHead, NormMlpClassifierHead, \ Format, resample_abs_pos_embed_nhwc, RotaryEmbeddingCat, apply_rot_embed_cat, to_2tuple, use_fused_attn from torch.jit import Final @@ -449,7 +449,7 @@ def __init__( self.rope_window = None # stochastic depth decay rule - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + dpr = calculate_drop_path_rates(drop_path_rate, depth) self.blocks = nn.Sequential(*[ block_fn( dim=embed_dim, diff --git a/timm/models/vitamin.py b/timm/models/vitamin.py index fa76275426..0effa5bbfc 100644 --- a/timm/models/vitamin.py +++ b/timm/models/vitamin.py @@ -310,7 +310,7 @@ def _cfg(url='', **kwargs): default_cfgs = generate_default_cfgs({ 'vitamin_small_224.datacomp1b_clip_ltt': _cfg( 
- hf_hub_id='jienengchen/ViTamin-S-LTT', num_classes=384), + hf_hub_id='jienengchen/ViTamin-S-LTT', num_classes=768), 'vitamin_small_224.datacomp1b_clip': _cfg( hf_hub_id='jienengchen/ViTamin-S', num_classes=384), 'vitamin_base_224.datacomp1b_clip_ltt': _cfg( diff --git a/timm/models/vovnet.py b/timm/models/vovnet.py index 8da9431f09..948ea501dc 100644 --- a/timm/models/vovnet.py +++ b/timm/models/vovnet.py @@ -18,7 +18,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import ConvNormAct, SeparableConvNormAct, BatchNormAct2d, ClassifierHead, DropPath, \ - create_attn, create_norm_act_layer + create_attn, create_norm_act_layer, calculate_drop_path_rates from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._manipulate import checkpoint_seq @@ -214,7 +214,7 @@ def __init__( current_stride = stem_stride # OSA stages - stage_dpr = torch.split(torch.linspace(0, drop_path_rate, sum(block_per_stage)), block_per_stage) + stage_dpr = calculate_drop_path_rates(drop_path_rate, block_per_stage, stagewise=True) in_ch_list = stem_chs[-1:] + stage_out_chs[:-1] stage_args = dict(residual=cfg["residual"], depthwise=cfg["depthwise"], attn=cfg["attn"], **conv_kwargs) stages = []