Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion timm/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from .create_conv2d import create_conv2d
from .create_norm import get_norm_layer, create_norm_layer
from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path, calculate_drop_path_rates
from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
from .evo_norm import (
EvoNorm2dB0,
Expand Down
43 changes: 43 additions & 0 deletions timm/layers/drop.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

Hacked together by / Copyright 2020 Ross Wightman
"""
from typing import List, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -180,3 +182,44 @@ def forward(self, x):

def extra_repr(self):
return f'drop_prob={round(self.drop_prob,3):0.3f}'


def calculate_drop_path_rates(
        drop_path_rate: float,
        depths: Union[int, List[int]],
        stagewise: bool = False,
) -> Union[List[float], List[List[float]]]:
    """Generate per-block drop path rates for stochastic depth.

    Rates increase linearly from 0 at the first block to ``drop_path_rate``
    at the last block (the standard stochastic depth decay rule). When
    ``stagewise=True`` the same linear schedule is returned grouped into
    consecutive sub-lists, one per stage; rates still increase block-by-block
    within each stage.

    Args:
        drop_path_rate: Maximum drop path rate (reached at the final block).
        depths: Total number of blocks (int), or a list of per-stage depths.
        stagewise: If True, group the per-block rates into one sub-list per
            stage. Requires ``depths`` to be a list of stage depths.

    Returns:
        A flat list of per-block drop rates, or, when ``stagewise=True``,
        a list of per-stage lists covering the same linear schedule.

    Raises:
        ValueError: If ``stagewise=True`` but ``depths`` is a single int.
    """
    if isinstance(depths, int):
        if stagewise:
            raise ValueError("stagewise=True requires depths to be a list of stage depths")
        # Flat per-block schedule over the total depth.
        return [x.item() for x in torch.linspace(0, drop_path_rate, depths, device='cpu')]

    # List of per-stage depths: build one linear schedule over the total depth.
    total_depth = sum(depths)
    rates = torch.linspace(0, drop_path_rate, total_depth, device='cpu')
    if stagewise:
        # Same schedule, split into consecutive per-stage chunks.
        return [x.tolist() for x in rates.split(depths)]
    return [x.item() for x in rates]
2 changes: 1 addition & 1 deletion timm/models/_efficientnet_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def __init__(
if s2d == 1:
sd_chs = int(in_chs * 4)
self.conv_s2d = create_conv2d(in_chs, sd_chs, kernel_size=2, stride=2, padding='same')
self.bn_s2d = norm_act_layer(sd_chs, sd_chs)
self.bn_s2d = norm_act_layer(sd_chs)
dw_kernel_size = (dw_kernel_size + 1) // 2
dw_pad_type = 'same' if dw_kernel_size == 2 else pad_type
in_chs = sd_chs
Expand Down
4 changes: 2 additions & 2 deletions timm/models/beit.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, trunc_normal_, use_fused_attn
from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, calculate_drop_path_rates, trunc_normal_, use_fused_attn
from timm.layers import resample_patch_embed, resample_abs_pos_embed, resize_rel_pos_bias_table, ndgrid

from ._builder import build_model_with_cfg
Expand Down Expand Up @@ -448,7 +448,7 @@ def __init__(
else:
self.rel_pos_bias = None

dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
dpr = calculate_drop_path_rates(drop_path_rate, depth) # stochastic depth decay rule
self.blocks = nn.ModuleList([
Block(
dim=embed_dim,
Expand Down
4 changes: 2 additions & 2 deletions timm/models/byobnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import (
ClassifierHead, NormMlpClassifierHead, ConvNormAct, BatchNormAct2d, EvoNorm2dS0a,
AttentionPool2d, RotAttentionPool2d, DropPath, AvgPool2dSame,
AttentionPool2d, RotAttentionPool2d, DropPath, calculate_drop_path_rates, AvgPool2dSame,
create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple,
)
from ._builder import build_model_with_cfg
Expand Down Expand Up @@ -1212,7 +1212,7 @@ def create_byob_stages(
feature_info = []
block_cfgs = [expand_blocks_cfg(s) for s in cfg.blocks]
depths = [sum([bc.d for bc in stage_bcs]) for stage_bcs in block_cfgs]
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
dpr = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True)
dilation = 1
net_stride = stem_feat['reduction']
prev_chs = stem_feat['num_chs']
Expand Down
2 changes: 0 additions & 2 deletions timm/models/coat.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,9 +417,7 @@ def __init__(
self.crpe3 = ConvRelPosEnc(head_chs=embed_dims[2] // num_heads, num_heads=num_heads, window=crpe_window)
self.crpe4 = ConvRelPosEnc(head_chs=embed_dims[3] // num_heads, num_heads=num_heads, window=crpe_window)

# Disable stochastic depth.
dpr = drop_path_rate
assert dpr == 0.0
skwargs = dict(
num_heads=num_heads,
qkv_bias=qkv_bias,
Expand Down
4 changes: 2 additions & 2 deletions timm/models/convit.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, trunc_normal_, PatchEmbed, Mlp, LayerNorm, HybridEmbed
from timm.layers import DropPath, calculate_drop_path_rates, trunc_normal_, PatchEmbed, Mlp, LayerNorm, HybridEmbed
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_module
from ._registry import register_model, generate_default_cfgs
Expand Down Expand Up @@ -292,7 +292,7 @@ def __init__(
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
trunc_normal_(self.pos_embed, std=.02)

dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
dpr = calculate_drop_path_rates(drop_path_rate, depth) # stochastic depth decay rule
self.blocks = nn.ModuleList([
Block(
dim=embed_dim,
Expand Down
4 changes: 2 additions & 2 deletions timm/models/convnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalResponseNormMlp, \
from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, calculate_drop_path_rates, Mlp, GlobalResponseNormMlp, \
LayerNorm2d, LayerNorm, RmsNorm2d, RmsNorm, create_conv2d, get_act_layer, get_norm_layer, make_divisible, to_ntuple
from timm.layers import SimpleNorm2d, SimpleNorm
from timm.layers import NormMlpClassifierHead, ClassifierHead
Expand Down Expand Up @@ -377,7 +377,7 @@ def __init__(
stem_stride = 4

self.stages = nn.Sequential()
dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
dp_rates = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True)
stages = []
prev_chs = dims[0]
curr_stride = stem_stride
Expand Down
4 changes: 2 additions & 2 deletions timm/models/crossvit.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, to_2tuple, trunc_normal_, _assert
from timm.layers import DropPath, calculate_drop_path_rates, to_2tuple, trunc_normal_, _assert
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_function
from ._registry import register_model, generate_default_cfgs
Expand Down Expand Up @@ -346,7 +346,7 @@ def __init__(
self.pos_drop = nn.Dropout(p=pos_drop_rate)

total_depth = sum([sum(x[-2:]) for x in depth])
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, total_depth)] # stochastic depth decay rule
dpr = calculate_drop_path_rates(drop_path_rate, total_depth) # stochastic depth decay rule
dpr_ptr = 0
self.blocks = nn.ModuleList()
for idx, block_cfg in enumerate(depth):
Expand Down
4 changes: 2 additions & 2 deletions timm/models/cspnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import ClassifierHead, ConvNormAct, DropPath, get_attn, create_act_layer, make_divisible
from timm.layers import ClassifierHead, ConvNormAct, DropPath, calculate_drop_path_rates, get_attn, create_act_layer, make_divisible
from ._builder import build_model_with_cfg
from ._manipulate import named_apply, MATCH_PREV_GROUP
from ._registry import register_model, generate_default_cfgs
Expand Down Expand Up @@ -569,7 +569,7 @@ def create_csp_stages(
cfg_dict = asdict(cfg.stages)
num_stages = len(cfg.stages.depth)
cfg_dict['block_dpr'] = [None] * num_stages if not drop_path_rate else \
[x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.stages.depth)).split(cfg.stages.depth)]
calculate_drop_path_rates(drop_path_rate, cfg.stages.depth, stagewise=True)
stage_args = [dict(zip(cfg_dict.keys(), values)) for values in zip(*cfg_dict.values())]
block_kwargs = dict(
act_layer=cfg.act_layer,
Expand Down
4 changes: 2 additions & 2 deletions timm/models/davit.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from torch import Tensor

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, to_2tuple, trunc_normal_, Mlp, LayerNorm2d, get_norm_layer, use_fused_attn
from timm.layers import DropPath, calculate_drop_path_rates, to_2tuple, trunc_normal_, Mlp, LayerNorm2d, get_norm_layer, use_fused_attn
from timm.layers import NormMlpClassifierHead, ClassifierHead
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
Expand Down Expand Up @@ -555,7 +555,7 @@ def __init__(
self.stem = Stem(in_chans, embed_dims[0], norm_layer=norm_layer)
in_chs = embed_dims[0]

dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
dpr = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True)
stages = []
for i in range(num_stages):
out_chs = embed_dims[i]
Expand Down
4 changes: 2 additions & 2 deletions timm/models/edgenext.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from torch import nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, create_conv2d, \
from timm.layers import trunc_normal_tf_, DropPath, calculate_drop_path_rates, LayerNorm2d, Mlp, create_conv2d, \
NormMlpClassifierHead, ClassifierHead
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
Expand Down Expand Up @@ -343,7 +343,7 @@ def __init__(

curr_stride = 4
stages = []
dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
dp_rates = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True)
in_chs = dims[0]
for i in range(4):
stride = 2 if curr_stride == 2 or i > 0 else 1
Expand Down
4 changes: 2 additions & 2 deletions timm/models/efficientformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, trunc_normal_, to_2tuple, Mlp, ndgrid
from timm.layers import DropPath, calculate_drop_path_rates, trunc_normal_, to_2tuple, Mlp, ndgrid
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
Expand Down Expand Up @@ -385,7 +385,7 @@ def __init__(
# stochastic depth decay rule
self.num_stages = len(depths)
last_stage = self.num_stages - 1
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
dpr = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True)
downsamples = downsamples or (False,) + (True,) * (self.num_stages - 1)
stages = []
self.feature_info = []
Expand Down
4 changes: 2 additions & 2 deletions timm/models/efficientformer_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import create_conv2d, create_norm_layer, get_act_layer, get_norm_layer, ConvNormAct
from timm.layers import DropPath, trunc_normal_, to_2tuple, to_ntuple, ndgrid
from timm.layers import DropPath, calculate_drop_path_rates, trunc_normal_, to_2tuple, to_ntuple, ndgrid
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
Expand Down Expand Up @@ -543,7 +543,7 @@ def __init__(
stride = 4

num_stages = len(depths)
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
dpr = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True)
downsamples = downsamples or (False,) + (True,) * (len(depths) - 1)
mlp_ratios = to_ntuple(num_stages)(mlp_ratios)
stages = []
Expand Down
12 changes: 9 additions & 3 deletions timm/models/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1434,6 +1434,12 @@ def _cfg(url='', **kwargs):
'efficientnet_b3_g8_gn.untrained': _cfg(
input_size=(3, 288, 288), pool_size=(9, 9), test_input_size=(3, 320, 320), crop_pct=1.0),
'efficientnet_blur_b0.untrained': _cfg(),
'efficientnet_h_b5.untrained': _cfg(
url='', input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0),
'efficientnet_x_b3.untrained': _cfg(
url='', input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=0.95),
'efficientnet_x_b5.untrained': _cfg(
url='', input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0),

'efficientnet_es.ra_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_ra-f111e99c.pth',
Expand Down Expand Up @@ -2708,23 +2714,23 @@ def efficientnet_x_b3(pretrained=False, **kwargs) -> EfficientNet:
""" EfficientNet-B3 """
# NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2
model = _gen_efficientnet_x(
'efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
'efficientnet_x_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
return model


@register_model
def efficientnet_x_b5(pretrained=False, **kwargs) -> EfficientNet:
""" EfficientNet-B5 """
model = _gen_efficientnet_x(
'efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs)
'efficientnet_x_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs)
return model


@register_model
def efficientnet_h_b5(pretrained=False, **kwargs) -> EfficientNet:
""" EfficientNet-B5 """
model = _gen_efficientnet_x(
'efficientnet_b5', channel_multiplier=1.92, depth_multiplier=2.2, version=2, pretrained=pretrained, **kwargs)
'efficientnet_h_b5', channel_multiplier=1.92, depth_multiplier=2.2, version=2, pretrained=pretrained, **kwargs)
return model


Expand Down
7 changes: 4 additions & 3 deletions timm/models/eva.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
GluMlp,
SwiGLU,
LayerNorm,
DropPath,
DropPath, calculate_drop_path_rates,
PatchDropoutWithIndices,
create_rope_embed,
apply_rot_embed_cat,
Expand Down Expand Up @@ -365,7 +365,7 @@ def __init__(
attn_type: str = 'eva',
rotate_half: bool = False,
swiglu_mlp: bool = False,
swiglu_aligh_to: int = 0,
swiglu_align_to: int = 0,
scale_mlp: bool = False,
scale_attn_inner: bool = False,
num_prefix_tokens: int = 1,
Expand Down Expand Up @@ -425,6 +425,7 @@ def __init__(
hidden_features=hidden_features,
norm_layer=norm_layer if scale_mlp else None,
drop=proj_drop,
align_to=swiglu_align_to,
)
else:
# w/o any extra norm, an impl with packed fc1 weights is used, matches existing GluMLP
Expand Down Expand Up @@ -637,7 +638,7 @@ def __init__(

self.norm_pre = norm_layer(embed_dim) if activate_pre_norm else nn.Identity()

dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
dpr = calculate_drop_path_rates(drop_path_rate, depth) # stochastic depth decay rule
block_fn = EvaBlockPostNorm if use_post_norm else EvaBlock
self.blocks = nn.ModuleList([
block_fn(
Expand Down
4 changes: 2 additions & 2 deletions timm/models/fasternet.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import SelectAdaptivePool2d, Linear, DropPath, trunc_normal_, LayerType
from timm.layers import SelectAdaptivePool2d, Linear, DropPath, trunc_normal_, LayerType, calculate_drop_path_rates
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
Expand Down Expand Up @@ -216,7 +216,7 @@ def __init__(
norm_layer=norm_layer if patch_norm else nn.Identity,
)
# stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
dpr = calculate_drop_path_rates(drop_path_rate, sum(depths))

# build layers
stages_list = []
Expand Down
4 changes: 2 additions & 2 deletions timm/models/fastvit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import (
DropPath,
DropPath, calculate_drop_path_rates,
trunc_normal_,
create_conv2d,
ConvNormAct,
Expand Down Expand Up @@ -1156,7 +1156,7 @@ def __init__(
# Build the main stages of the network architecture
prev_dim = embed_dims[0]
scale = 1
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)]
dpr = calculate_drop_path_rates(drop_path_rate, layers, stagewise=True)
stages = []
for i in range(len(layers)):
downsample = downsamples[i] or prev_dim != embed_dims[i]
Expand Down
4 changes: 2 additions & 2 deletions timm/models/focalnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import Mlp, DropPath, LayerNorm2d, trunc_normal_, ClassifierHead, NormMlpClassifierHead
from timm.layers import Mlp, DropPath, LayerNorm2d, trunc_normal_, ClassifierHead, NormMlpClassifierHead, calculate_drop_path_rates
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import named_apply, checkpoint
Expand Down Expand Up @@ -378,7 +378,7 @@ def __init__(
)
in_dim = embed_dim[0]

dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
dpr = calculate_drop_path_rates(drop_path_rate, sum(depths)) # stochastic depth decay rule
layers = []
for i_layer in range(self.num_layers):
out_dim = embed_dim[i_layer]
Expand Down
Loading