Add drop path schedule #1835

Open

wants to merge 19 commits into main
Changes from all commits
4 changes: 2 additions & 2 deletions .github/workflows/build_documentation.yml
@@ -14,8 +14,8 @@ jobs:
commit_sha: ${{ github.sha }}
package: pytorch-image-models
package_name: timm
repo_owner: rwightman
path_to_docs: pytorch-image-models/hfdocs/source
version_tag_suffix: ""
secrets:
token: ${{ secrets.HUGGINGFACE_PUSH }}
token: ${{ secrets.HUGGINGFACE_PUSH }}
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
1 change: 0 additions & 1 deletion .github/workflows/build_pr_documentation.yml
@@ -15,6 +15,5 @@ jobs:
pr_number: ${{ github.event.number }}
package: pytorch-image-models
package_name: timm
repo_owner: rwightman
path_to_docs: pytorch-image-models/hfdocs/source
version_tag_suffix: ""
14 changes: 7 additions & 7 deletions .github/workflows/delete_doc_comment.yml
@@ -1,13 +1,13 @@
name: Delete dev documentation
name: Delete doc comment

on:
pull_request:
types: [ closed ]

workflow_run:
workflows: ["Delete doc comment trigger"]
types:
- completed

jobs:
delete:
uses: huggingface/doc-builder/.github/workflows/delete_doc_comment.yml@main
with:
pr_number: ${{ github.event.number }}
package: timm
secrets:
comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}
12 changes: 12 additions & 0 deletions .github/workflows/delete_doc_comment_trigger.yml
@@ -0,0 +1,12 @@
name: Delete doc comment trigger

on:
pull_request:
types: [ closed ]


jobs:
delete:
uses: huggingface/doc-builder/.github/workflows/delete_doc_comment_trigger.yml@main
with:
pr_number: ${{ github.event.number }}
16 changes: 16 additions & 0 deletions .github/workflows/upload_pr_documentation.yml
@@ -0,0 +1,16 @@
name: Upload PR Documentation

on:
workflow_run:
workflows: ["Build PR Documentation"]
types:
- completed

jobs:
build:
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
with:
package_name: timm
secrets:
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}
2 changes: 1 addition & 1 deletion timm/layers/__init__.py
@@ -14,7 +14,7 @@
from .create_conv2d import create_conv2d
from .create_norm import get_norm_layer, create_norm_layer
from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path, efficient_drop_path
from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\
EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
138 changes: 102 additions & 36 deletions timm/layers/drop.py
@@ -14,15 +14,22 @@

Hacked together by / Copyright 2020 Ross Wightman
"""
from typing import Callable, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F


def drop_block_2d(
x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0,
with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
""" DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
x,
drop_prob: float = 0.1,
block_size: int = 7,
gamma_scale: float = 1.0,
with_noise: bool = False,
inplace: bool = False,
batchwise: bool = False,
):
"""DropBlock. See https://arxiv.org/pdf/1810.12890.pdf

DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
runs with success, but needs further validation and possibly optimization for lower runtime impact.
@@ -31,13 +38,15 @@ def drop_block_2d(
total_size = W * H
clipped_block_size = min(block_size, min(W, H))
# seed_drop_rate, the gamma parameter
gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
(W - block_size + 1) * (H - block_size + 1))
gamma = (
gamma_scale * drop_prob * total_size / clipped_block_size**2 / ((W - block_size + 1) * (H - block_size + 1))
)

# Forces the block to be inside the feature map.
w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device))
valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \
((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & (
(h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2)
)
valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)

if batchwise:
@@ -47,10 +56,8 @@ def drop_block_2d(
uniform_noise = torch.rand_like(x)
block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
block_mask = -F.max_pool2d(
-block_mask,
kernel_size=clipped_block_size, # block_size,
stride=1,
padding=clipped_block_size // 2)
-block_mask, kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2 # block_size,
)

if with_noise:
normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
@@ -68,29 +75,36 @@


def drop_block_fast_2d(
x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7,
gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False):
""" DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
x: torch.Tensor,
drop_prob: float = 0.1,
block_size: int = 7,
gamma_scale: float = 1.0,
with_noise: bool = False,
inplace: bool = False,
):
"""DropBlock. See https://arxiv.org/pdf/1810.12890.pdf

DropBlock with an experimental gaussian noise option. Simplified from above without concern for valid
block mask at edges.
"""
B, C, H, W = x.shape
total_size = W * H
clipped_block_size = min(block_size, min(W, H))
gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
(W - block_size + 1) * (H - block_size + 1))
gamma = (
gamma_scale * drop_prob * total_size / clipped_block_size**2 / ((W - block_size + 1) * (H - block_size + 1))
)

block_mask = torch.empty_like(x).bernoulli_(gamma)
block_mask = F.max_pool2d(
block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2)
block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2
)

if with_noise:
normal_noise = torch.empty_like(x).normal_()
if inplace:
x.mul_(1. - block_mask).add_(normal_noise * block_mask)
x.mul_(1.0 - block_mask).add_(normal_noise * block_mask)
else:
x = x * (1. - block_mask) + normal_noise * block_mask
x = x * (1.0 - block_mask) + normal_noise * block_mask
else:
block_mask = 1 - block_mask
normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-6)).to(dtype=x.dtype)
@@ -102,18 +116,18 @@ def drop_block_fast_2d(


class DropBlock2d(nn.Module):
""" DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
"""
"""DropBlock. See https://arxiv.org/pdf/1810.12890.pdf"""

def __init__(
self,
drop_prob: float = 0.1,
block_size: int = 7,
gamma_scale: float = 1.0,
with_noise: bool = False,
inplace: bool = False,
batchwise: bool = False,
fast: bool = True):
self,
drop_prob: float = 0.1,
block_size: int = 7,
gamma_scale: float = 1.0,
with_noise: bool = False,
inplace: bool = False,
batchwise: bool = False,
fast: bool = True,
):
super(DropBlock2d, self).__init__()
self.drop_prob = drop_prob
self.gamma_scale = gamma_scale
@@ -128,13 +142,15 @@ def forward(self, x):
return x
if self.fast:
return drop_block_fast_2d(
x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace)
x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace
)
else:
return drop_block_2d(
x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise
)


def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
@@ -144,7 +160,7 @@ def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: b
'survival rate' as the argument.

"""
if drop_prob == 0. or not training:
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
@@ -155,9 +171,9 @@ def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: b


class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
self.scale_by_keep = scale_by_keep
@@ -167,3 +183,53 @@ def forward(self, x):

def extra_repr(self):
return f'drop_prob={round(self.drop_prob,3):0.3f}'


def get_subset_index_and_scale_factor(x: torch.Tensor, drop_ratio: float = 0.0) -> Tuple[torch.Tensor, float]:
# random selection of the subset of the batch
B, _, _ = x.shape
selected_subset_size = max(int(B * (1 - drop_ratio)), 1)
selected_indicies = (torch.randperm(B, device=x.device))[:selected_subset_size]

return selected_indicies, B / selected_subset_size


def apply_residual(
x: torch.Tensor, selected_indicies: torch.Tensor, residual: torch.Tensor, residual_scale_factor: float
) -> torch.Tensor:
residual = residual.to(dtype=x.dtype)
x_flat, residual_flat = x.flatten(1), residual.flatten(1)

return torch.index_add(x_flat, 0, selected_indicies, residual_flat, alpha=residual_scale_factor).view_as(x)


def efficient_drop_path(
x: torch.Tensor, func: Callable[[torch.Tensor], torch.Tensor], drop_ratio: float = 0.0, training: bool = False
) -> torch.Tensor:
"""Efficient Drop Path implementation.
Ref impl: https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/block.py

Args:
x (torch.Tensor): input tensor
func (Callable[[torch.Tensor], torch.Tensor]): function to calculate residual
drop_ratio (float, optional): Drop ratio. Defaults to 0.0.
training (bool, optional): training mode. Defaults to False.

Returns:
torch.Tensor: output tensor
"""

if not training or drop_ratio == 0.0:
return func(x)

if drop_ratio <= 0.1:
# there is an overhead of using fast drop block for small drop ratio
return drop_path(func(x), drop_ratio, training=training)

# extract subset of the batch
selected_indicies, residual_scale_factor = get_subset_index_and_scale_factor(x, drop_ratio=drop_ratio)

# apply residual
residual = func(x[selected_indicies])

return apply_residual(x, selected_indicies, residual, residual_scale_factor)
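
For context, a minimal sketch of how the new `efficient_drop_path` helper behaves; the residual module, shapes, and drop ratio below are illustrative and not taken from this PR. With `drop_ratio` above the 0.1 threshold and `training=True`, the helper runs `func` only on a randomly kept subset of the batch and adds the rescaled residual back into `x` via `index_add`; outside training it simply returns `func(x)`.

```python
import torch
import torch.nn as nn

from timm.layers import efficient_drop_path  # exported by this PR

torch.manual_seed(0)

B, N, C = 8, 16, 32                                            # batch, tokens, channels (3D input expected)
x = torch.randn(B, N, C)
residual_fn = nn.Sequential(nn.LayerNorm(C), nn.Linear(C, C))  # stand-in residual branch

# drop_ratio > 0.1 and training=True: a random subset of max(int(B * (1 - drop_ratio)), 1)
# samples is kept, residual_fn runs only on that subset, and the result is index_add'ed
# back into x scaled by B / subset_size.
out = efficient_drop_path(x, residual_fn, drop_ratio=0.5, training=True)
print(out.shape)  # torch.Size([8, 16, 32]), same as the input

# Not training (or drop_ratio == 0.0): the helper simply returns residual_fn(x).
eval_out = efficient_drop_path(x, residual_fn, drop_ratio=0.5, training=False)
```

For `drop_ratio <= 0.1` the helper falls back to the existing `drop_path` applied to `func(x)` over the full batch, since the subset bookkeeping costs more than it saves at small ratios.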
8 changes: 4 additions & 4 deletions timm/models/_efficientnet_blocks.py
@@ -75,7 +75,7 @@ def feature_info(self, location):
if location == 'expansion': # output of conv after act, same as block output
return dict(module='bn1', hook_type='forward', num_chs=self.conv.out_channels)
else: # location == 'bottleneck', block output
return dict(module='', hook_type='', num_chs=self.conv.out_channels)
return dict(module='', num_chs=self.conv.out_channels)

def forward(self, x):
shortcut = x
@@ -116,7 +116,7 @@ def feature_info(self, location):
if location == 'expansion': # after SE, input to PW
return dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels)
else: # location == 'bottleneck', block output
return dict(module='', hook_type='', num_chs=self.conv_pw.out_channels)
return dict(module='', num_chs=self.conv_pw.out_channels)

def forward(self, x):
shortcut = x
@@ -173,7 +173,7 @@ def feature_info(self, location):
if location == 'expansion': # after SE, input to PWL
return dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
else: # location == 'bottleneck', block output
return dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels)
return dict(module='', num_chs=self.conv_pwl.out_channels)

def forward(self, x):
shortcut = x
@@ -266,7 +266,7 @@ def feature_info(self, location):
if location == 'expansion': # after SE, before PWL
return dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
else: # location == 'bottleneck', block output
return dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels)
return dict(module='', num_chs=self.conv_pwl.out_channels)

def forward(self, x):
shortcut = x
16 changes: 10 additions & 6 deletions timm/models/_efficientnet_builder.py
@@ -370,9 +370,7 @@ def __call__(self, in_chs, model_block_args):
stages = []
if model_block_args[0][0]['stride'] > 1:
# if the first block starts with a stride, we need to extract first level feat from stem
feature_info = dict(
module='act1', num_chs=in_chs, stage=0, reduction=current_stride,
hook_type='forward' if self.feature_location != 'bottleneck' else '')
feature_info = dict(module='bn1', num_chs=in_chs, stage=0, reduction=current_stride)
self.features.append(feature_info)

# outer list of block_args defines the stacks
@@ -418,10 +416,16 @@ def __call__(self, in_chs, model_block_args):
# stash feature module name and channel info for model feature extraction
if extract_features:
feature_info = dict(
stage=stack_idx + 1, reduction=current_stride, **block.feature_info(self.feature_location))
module_name = f'blocks.{stack_idx}.{block_idx}'
stage=stack_idx + 1,
reduction=current_stride,
**block.feature_info(self.feature_location),
)
leaf_name = feature_info.get('module', '')
feature_info['module'] = '.'.join([module_name, leaf_name]) if leaf_name else module_name
if leaf_name:
feature_info['module'] = '.'.join([f'blocks.{stack_idx}.{block_idx}', leaf_name])
else:
assert last_block
feature_info['module'] = f'blocks.{stack_idx}'
self.features.append(feature_info)

total_block_idx += 1 # incr global block idx (across all stacks)
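
To make the builder change above concrete, a hedged sketch of what the recorded feature_info entries roughly look like afterwards; channel counts, strides, and stage numbers are invented for illustration.

```python
# Illustrative only: shape of the dicts the builder accumulates in self.features.
feature_info_bottleneck = [
    dict(module='bn1', num_chs=32, stage=0, reduction=2),        # stem entry when the first block strides
    dict(module='blocks.1', num_chs=24, stage=2, reduction=4),   # 'bottleneck' location: whole stack, no hook_type key
]
feature_info_hooks = [
    dict(module='blocks.2.3.conv_pwl', hook_type='forward_pre',  # hook-based location: leaf module name kept
         num_chs=40, stage=3, reduction=8),
]
```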
3 changes: 2 additions & 1 deletion timm/models/_features.py
@@ -27,12 +27,13 @@ class FeatureInfo:

def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):
prev_reduction = 1
for fi in feature_info:
for i, fi in enumerate(feature_info):
# sanity check the mandatory fields, there may be additional fields depending on the model
assert 'num_chs' in fi and fi['num_chs'] > 0
assert 'reduction' in fi and fi['reduction'] >= prev_reduction
prev_reduction = fi['reduction']
assert 'module' in fi
fi.setdefault('index', i)
self.out_indices = out_indices
self.info = feature_info

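
Finally, a small hedged example of the new `setdefault('index', i)` behaviour in `FeatureInfo`; the entries and channel counts are made up.

```python
from timm.models._features import FeatureInfo

# Each entry that lacks an explicit 'index' now gets its list position stamped in.
infos = [
    dict(module='stem', num_chs=32, reduction=2),
    dict(module='blocks.1', num_chs=64, reduction=4),
]
fi = FeatureInfo(infos, out_indices=(0, 1))
print([e['index'] for e in fi.info])  # [0, 1]
```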