17 changes: 9 additions & 8 deletions tests/test_models.py
@@ -49,8 +49,9 @@

# models with forward_intermediates() and support for FeatureGetterNet features_only wrapper
FEAT_INTER_FILTERS = [
'vit_*', 'twins_*', 'deit*', 'beit*', 'mvitv2*', 'eva*', 'samvit_*', 'flexivit*',
'cait_*', 'xcit_*', 'volo_*',
'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet'
]

# transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
@@ -388,13 +389,12 @@ def test_model_forward_features(model_name, batch_size):

@pytest.mark.features
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS))
@pytest.mark.parametrize('model_name', list_models(module=FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS + ['*pruned*']))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_intermediates_features(model_name, batch_size):
"""Run a single forward pass with each model in feature extraction mode"""
model = create_model(model_name, pretrained=False, features_only=True)
model = create_model(model_name, pretrained=False, features_only=True, feature_cls='getter')
model.eval()
print(model.feature_info.out_indices)
expected_channels = model.feature_info.channels()
expected_reduction = model.feature_info.reduction()

@@ -420,7 +420,7 @@ def test_model_forward_intermediates_features(model_name, batch_size):

@pytest.mark.features
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS))
@pytest.mark.parametrize('model_name', list_models(module=FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS + ['*pruned*']))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_intermediates(model_name, batch_size):
"""Run a single forward pass with each model in feature extraction mode"""
@@ -429,18 +429,19 @@ def test_model_forward_intermediates(model_name, batch_size):
feature_info = timm.models.FeatureInfo(model.feature_info, len(model.feature_info))
expected_channels = feature_info.channels()
expected_reduction = feature_info.reduction()
assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6
assert len(expected_channels) >= 3 # all models here should have at least 3 feature levels

input_size = _get_input_size(model=model, target=TARGET_FFEAT_SIZE)
if max(input_size) > MAX_FFEAT_SIZE:
pytest.skip("Fixed input size model > limit.")
output_fmt = getattr(model, 'output_fmt', 'NCHW')
output_fmt = 'NCHW' # NOTE output_fmt determined by forward_intermediates() arg, not model attribute
feat_axis = get_channel_dim(output_fmt)
spatial_axis = get_spatial_dim(output_fmt)
import math

output, intermediates = model.forward_intermediates(
torch.randn((batch_size, *input_size)),
output_fmt=output_fmt,
)
assert len(expected_channels) == len(intermediates)
spatial_size = input_size[-2:]
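Taken together, the two updated tests exercise the new getter-based features_only path. A minimal sketch of that usage, not part of the diff, with the model name chosen arbitrarily from the modules listed in FEAT_INTER_FILTERS:

```python
import torch
import timm

# features_only with feature_cls='getter' wraps the model in FeatureGetterNet,
# which pulls features via forward_intermediates() rather than module hooks.
model = timm.create_model(
    'vit_small_patch16_224', pretrained=False,
    features_only=True, feature_cls='getter',
)
model.eval()

feats = model(torch.randn(1, 3, 224, 224))
assert len(feats) == len(model.feature_info.channels())
for f, c in zip(feats, model.feature_info.channels()):
    assert f.shape[1] == c  # channel dim matches feature_info (NCHW output)
```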
9 changes: 7 additions & 2 deletions timm/layers/adaptive_avgmax_pool.py
@@ -134,6 +134,7 @@ def __init__(
super(SelectAdaptivePool2d, self).__init__()
assert input_fmt in ('NCHW', 'NHWC')
self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing
pool_type = pool_type.lower()
if not pool_type:
self.pool = nn.Identity() # pass through
self.flatten = nn.Flatten(1) if flatten else nn.Identity()
@@ -145,8 +146,10 @@ def __init__(
self.pool = FastAdaptiveAvgMaxPool(flatten, input_fmt=input_fmt)
elif pool_type.endswith('max'):
self.pool = FastAdaptiveMaxPool(flatten, input_fmt=input_fmt)
else:
elif pool_type == 'fast' or pool_type.endswith('avg'):
self.pool = FastAdaptiveAvgPool(flatten, input_fmt=input_fmt)
else:
assert False, 'Invalid pool type: %s' % pool_type
self.flatten = nn.Identity()
else:
assert input_fmt == 'NCHW'
@@ -156,8 +159,10 @@ def __init__(
self.pool = AdaptiveCatAvgMaxPool2d(output_size)
elif pool_type == 'max':
self.pool = nn.AdaptiveMaxPool2d(output_size)
else:
elif pool_type == 'avg':
self.pool = nn.AdaptiveAvgPool2d(output_size)
else:
assert False, 'Invalid pool type: %s' % pool_type
self.flatten = nn.Flatten(1) if flatten else nn.Identity()

def is_identity(self):
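A short sketch of what the stricter pool_type handling above changes (illustrative only, assumes the default output_size of 1):

```python
import torch
from timm.layers import SelectAdaptivePool2d

x = torch.randn(2, 64, 7, 7)

# 'catavgmax' concatenates avg- and max-pooled features along the channel dim.
pool = SelectAdaptivePool2d(pool_type='catavgmax', flatten=True)
print(pool(x).shape)  # torch.Size([2, 128])

# Unknown pool types now fail loudly instead of silently falling back to avg pooling.
try:
    SelectAdaptivePool2d(pool_type='median')
except AssertionError as err:
    print(err)  # Invalid pool type: median
```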
42 changes: 28 additions & 14 deletions timm/models/_builder.py
@@ -2,7 +2,7 @@
import logging
import os
from copy import deepcopy
from typing import Optional, Dict, Callable, Any, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple

from torch import nn as nn
from torch.hub import load_state_dict_from_url
@@ -359,15 +359,15 @@ def build_model_with_cfg(
* pruning config / model adaptation

Args:
model_cls (nn.Module): model class
variant (str): model variant name
pretrained (bool): load pretrained weights
pretrained_cfg (dict): model's pretrained weight/task config
model_cfg (Optional[Dict]): model's architecture config
feature_cfg (Optional[Dict]: feature extraction adapter config
pretrained_strict (bool): load pretrained weights strictly
pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights
kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model
model_cls: model class
variant: model variant name
pretrained: load pretrained weights
pretrained_cfg: model's pretrained weight/task config
model_cfg: model's architecture config
feature_cfg: feature extraction adapter config
pretrained_strict: load pretrained weights strictly
pretrained_filter_fn: filter callable for pretrained weights
kwargs_filter: kwargs to filter before passing to model
**kwargs: model args passed through to model __init__
"""
pruned = kwargs.pop('pruned', False)
@@ -392,6 +392,8 @@
feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
if 'out_indices' in kwargs:
feature_cfg['out_indices'] = kwargs.pop('out_indices')
if 'feature_cls' in kwargs:
feature_cfg['feature_cls'] = kwargs.pop('feature_cls')

# Instantiate the model
if model_cfg is None:
@@ -418,24 +420,36 @@

# Wrap the model in a feature extraction module if enabled
if features:
feature_cls = FeatureListNet
output_fmt = getattr(model, 'output_fmt', None)
if output_fmt is not None:
feature_cfg.setdefault('output_fmt', output_fmt)
use_getter = False
if 'feature_cls' in feature_cfg:
feature_cls = feature_cfg.pop('feature_cls')
if isinstance(feature_cls, str):
feature_cls = feature_cls.lower()

# flatten_sequential only valid for some feature extractors
if feature_cls not in ('dict', 'list', 'hook'):
feature_cfg.pop('flatten_sequential', None)

if 'hook' in feature_cls:
feature_cls = FeatureHookNet
elif feature_cls == 'list':
feature_cls = FeatureListNet
elif feature_cls == 'dict':
feature_cls = FeatureDictNet
elif feature_cls == 'fx':
feature_cls = FeatureGraphNet
elif feature_cls == 'getter':
use_getter = True
feature_cls = FeatureGetterNet
else:
assert False, f'Unknown feature class {feature_cls}'
else:
feature_cls = FeatureListNet

output_fmt = getattr(model, 'output_fmt', None)
if output_fmt is not None and not use_getter: # don't set default for intermediate feat getter
feature_cfg.setdefault('output_fmt', output_fmt)

model = feature_cls(model, **feature_cfg)
model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg) # add back pretrained cfg
model.default_cfg = model.pretrained_cfg # alias for rename backwards compat (default_cfg -> pretrained_cfg)
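For context, a hedged sketch of how the feature_cls passthrough added above is used from create_model; the model names are illustrative:

```python
import timm

# Default behaviour is unchanged: FeatureListNet rewrites the model into a flat module dict.
cnn_feats = timm.create_model('resnet50', features_only=True)

# feature_cls='getter' selects FeatureGetterNet, which calls the model's own
# forward_intermediates(), so it also works for ViT-style models without hookable stages.
vit_feats = timm.create_model('vit_base_patch16_224', features_only=True, feature_cls='getter')

print(type(cnn_feats).__name__)   # FeatureListNet
print(type(vit_feats).__name__)   # FeatureGetterNet
print(vit_feats.feature_info.reduction())
```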
7 changes: 4 additions & 3 deletions timm/models/_features.py
@@ -363,7 +363,7 @@ def __init__(
out_map: Optional[Sequence[Union[int, str]]] = None,
return_dict: bool = False,
output_fmt: str = 'NCHW',
no_rewrite: bool = False,
no_rewrite: Optional[bool] = None,
flatten_sequential: bool = False,
default_hook_type: str = 'forward',
):
@@ -385,7 +385,8 @@ def __init__(
self.return_dict = return_dict
self.output_fmt = Format(output_fmt)
self.grad_checkpointing = False

if no_rewrite is None:
no_rewrite = not flatten_sequential
layers = OrderedDict()
hooks = []
if no_rewrite:
@@ -467,7 +468,7 @@ def __init__(
self.out_indices = out_indices
self.out_map = out_map
self.return_dict = return_dict
self.output_fmt = output_fmt
self.output_fmt = Format(output_fmt)
self.norm = norm

def forward(self, x):
4 changes: 3 additions & 1 deletion timm/models/_features_fx.py
@@ -15,7 +15,7 @@
has_fx_feature_extraction = False

# Layers we went to treat as leaf modules
from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame
from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame, Format
from timm.layers.non_local_attn import BilinearAttnTransform
from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
from timm.layers.norm_act import (
@@ -108,12 +108,14 @@ def __init__(
model: nn.Module,
out_indices: Tuple[int, ...],
out_map: Optional[Dict] = None,
output_fmt: str = 'NCHW',
):
super().__init__()
assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
self.feature_info = _get_feature_info(model, out_indices)
if out_map is not None:
assert len(out_map) == len(out_indices)
self.output_fmt = Format(output_fmt)
return_nodes = _get_return_layers(self.feature_info, out_map)
self.graph_module = create_feature_extractor(model, return_nodes)

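A quick illustration of the Format handling these feature wrappers now share; this assumes torchvision's FX feature extraction is available for the 'fx' path:

```python
import timm
from timm.layers import Format

# FeatureGetterNet and FeatureGraphNet both coerce output_fmt to the Format enum now.
fx_net = timm.create_model('resnet50', features_only=True, feature_cls='fx')
print(fx_net.output_fmt)               # Format.NCHW
print(Format('NHWC') is Format.NHWC)   # True, plain strings are accepted and normalized
```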
13 changes: 11 additions & 2 deletions timm/models/_registry.py
@@ -184,7 +184,7 @@ def _expand_filter(filter: str):

def list_models(
filter: Union[str, List[str]] = '',
module: str = '',
module: Union[str, List[str]] = '',
pretrained: bool = False,
exclude_filters: Union[str, List[str]] = '',
name_matches_cfg: bool = False,
@@ -217,7 +217,16 @@
# FIXME should this be default behaviour? or default to include_tags=True?
include_tags = pretrained

all_models: Set[str] = _module_to_models[module] if module else set(_model_entrypoints.keys())
if not module:
all_models: Set[str] = set(_model_entrypoints.keys())
else:
if isinstance(module, str):
all_models: Set[str] = _module_to_models[module]
else:
assert isinstance(module, Sequence)
all_models: Set[str] = set()
for m in module:
all_models.update(_module_to_models[m])
all_models = all_models - _deprecated_models.keys() # remove deprecated models from listings

if include_tags:
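The module argument of list_models now also accepts a list of submodule names, which is what the updated tests rely on. A small sketch:

```python
import timm

# A single module name still works as before...
beit_models = timm.list_models(module='beit')

# ...and a list of modules is merged into one listing, with the usual filters applied.
vit_like = timm.list_models(
    module=['beit', 'cait', 'eva'],
    exclude_filters=['*pruned*'],
)
print(len(beit_models), len(vit_like))
```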
11 changes: 6 additions & 5 deletions timm/models/beit.py
@@ -407,7 +407,7 @@ def forward_intermediates(
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
return_prefix_tokens: bool = False,
norm: bool = False,
stop_early: bool = True,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
@@ -424,7 +424,7 @@
Returns:

"""
assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.'
assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
reshape = output_fmt == 'NCHW'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
@@ -436,6 +436,7 @@
if self.pos_embed is not None:
x = x + self.pos_embed
x = self.pos_drop(x)

rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
blocks = self.blocks
@@ -469,19 +470,19 @@

def prune_intermediate_layers(
self,
n: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int], Tuple[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.blocks), n)
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
self.blocks = self.blocks[:max_index + 1] # truncate blocks
if prune_norm:
self.norm = nn.Identity()
if prune_head:
self.fc_norm = nn.Identity()
self.head = nn.Identity()
self.reset_classifier(0, '')
return take_indices

def forward_features(self, x):
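For reference, a sketch of the forward_intermediates() API these BEiT changes adjust; the stop_early=False and output_fmt='NCHW' defaults follow the diff, while the model name and shapes are illustrative:

```python
import torch
import timm

model = timm.create_model('beit_base_patch16_224', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)

# Final (pre-pooling) tokens plus the last two block outputs reshaped to NCHW.
final, intermediates = model.forward_intermediates(x, indices=2)
print(final.shape)                        # e.g. torch.Size([1, 197, 768])
print([t.shape for t in intermediates])   # e.g. two tensors of shape (1, 768, 14, 14)

# Intermediates only, keeping the token (NLC) layout.
feats = model.forward_intermediates(
    x, indices=[-2, -1], output_fmt='NLC', intermediates_only=True)
```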
11 changes: 6 additions & 5 deletions timm/models/cait.py
@@ -343,7 +343,7 @@ def forward_intermediates(
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
norm: bool = False,
stop_early: bool = True,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
@@ -357,7 +357,7 @@ def forward_intermediates(
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
"""
assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.'
assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
reshape = output_fmt == 'NCHW'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
@@ -367,6 +367,7 @@
x = self.patch_embed(x)
x = x + self.pos_embed
x = self.pos_drop(x)

if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
blocks = self.blocks
else:
@@ -397,19 +398,19 @@

def prune_intermediate_layers(
self,
n: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int], Tuple[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.blocks), n)
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
self.blocks = self.blocks[:max_index + 1] # truncate blocks
if prune_norm:
self.norm = nn.Identity()
if prune_head:
self.blocks_token_only = nn.ModuleList() # prune token blocks with head
self.head = nn.Identity()
self.reset_classifier(0, '')
return take_indices

def forward_features(self, x):
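And a matching sketch for the renamed prune_intermediate_layers() argument (n renamed to indices), with Cait as the illustrative model:

```python
import timm

model = timm.create_model('cait_xxs24_224', pretrained=False)

# Keep only what is needed for the last two intermediate outputs; the head and the
# class-attention (token-only) blocks are pruned, the final norm is kept here.
kept = model.prune_intermediate_layers(indices=2, prune_norm=False, prune_head=True)
print(kept)               # indices of the retained feature blocks
print(len(model.blocks))  # truncated to max(kept) + 1
```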