Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion avg_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import os
import glob
import hashlib
from timm.models.helpers import load_state_dict
from timm.models import load_state_dict

parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager')
parser.add_argument('--input', default='', type=str, metavar='PATH',
Expand Down
3 changes: 2 additions & 1 deletion benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
import torch.nn.parallel

from timm.data import resolve_data_config
from timm.models import create_model, is_model, list_models, set_fast_norm
from timm.layers import set_fast_norm
from timm.models import create_model, is_model, list_models
from timm.optim import create_optimizer_v2
from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry

Expand Down
2 changes: 1 addition & 1 deletion clean_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import hashlib
import shutil
from collections import OrderedDict
from timm.models.helpers import load_state_dict
from timm.models import load_state_dict

parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
Expand Down
5 changes: 2 additions & 3 deletions hubconf.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
dependencies = ['torch']
from timm.models import registry

globals().update(registry._model_entrypoints)
import timm
globals().update(timm.models._registry._model_entrypoints)
9 changes: 4 additions & 5 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,23 @@

Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""
import os
import time
import argparse
import json
import logging
import os
import time
from contextlib import suppress
from functools import partial

import numpy as np
import pandas as pd
import torch

from timm.models import create_model, apply_test_time_pool, load_checkpoint
from timm.data import create_dataset, create_loader, resolve_data_config
from timm.layers import apply_test_time_pool
from timm.models import create_model
from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser



try:
from apex import amp
has_apex = True
Expand Down
5 changes: 1 addition & 4 deletions tests/test_layers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import pytest
import torch
import torch.nn as nn
import platform
import os

from timm.models.layers import create_act_layer, get_act_layer, set_layer_config
from timm.layers import create_act_layer, set_layer_config


class MLP(nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import timm
from timm import list_models, create_model, set_scriptable, get_pretrained_cfg_value
from timm.models.fx_features import _leaf_modules, _autowrap_functions
from timm.models._features_fx import _leaf_modules, _autowrap_functions

if hasattr(torch._C, '_jit_set_profiling_executor'):
# legacy executor is too slow to compile large models for unit tests
Expand Down
2 changes: 1 addition & 1 deletion timm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .version import __version__
from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable
from .models import create_model, list_models, list_pretrained, is_model, list_modules, model_entrypoint, \
is_scriptable, is_exportable, set_scriptable, set_exportable, \
is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value
3 changes: 2 additions & 1 deletion timm/data/readers/class_map.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import pickle


def load_class_map(map_or_filename, root=''):
if isinstance(map_or_filename, dict):
assert dict, 'class_map dict must be non-empty'
Expand All @@ -14,7 +15,7 @@ def load_class_map(map_or_filename, root=''):
with open(class_map_path) as f:
class_to_idx = {v.strip(): k for k, v in enumerate(f)}
elif class_map_ext == '.pkl':
with open(class_map_path,'rb') as f:
with open(class_map_path, 'rb') as f:
class_to_idx = pickle.load(f)
else:
assert False, f'Unsupported class map file extension ({class_map_ext}).'
Expand Down
44 changes: 44 additions & 0 deletions timm/layers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from .activations import *
from .adaptive_avgmax_pool import \
adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
from .blur_pool import BlurPool2d
from .classifier import ClassifierHead, create_classifier
from .cond_conv2d import CondConv2d, get_condconv_initializer
from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
set_layer_config
from .conv2d_same import Conv2dSame, conv2d_same
from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct
from .create_act import create_act_layer, get_act_layer, get_act_fn
from .create_attn import get_attn, create_attn
from .create_conv2d import create_conv2d
from .create_norm import get_norm_layer, create_norm_layer
from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\
EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm
from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
from .gather_excite import GatherExcite
from .global_context import GlobalContext
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
from .inplace_abn import InplaceAbn
from .linear import Linear
from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
from .padding import get_padding, get_same_padding, pad_same
from .patch_embed import PatchEmbed
from .pool2d_same import AvgPool2dSame, create_pool2d
from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
from .selective_kernel import SelectiveKernel
from .separable_conv import SeparableConv2d, SeparableConvNormAct
from .space_to_depth import SpaceToDepthModule
from .split_attn import SplitAttn
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
from .trace_utils import _assert, _float_to_int
from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
22 changes: 14 additions & 8 deletions timm/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,18 @@
from .xception_aligned import *
from .xcit import *

from .factory import create_model, parse_model_name, safe_model_name
from .helpers import load_checkpoint, resume_checkpoint, model_parameters
from .layers import TestTimePoolHead, apply_test_time_pool
from .layers import convert_splitbn_model, convert_sync_batchnorm
from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
from .layers import set_fast_norm
from .pretrained import PretrainedCfg, filter_pretrained_cfg, generate_default_cfgs, split_model_name_tag
from .registry import register_model, model_entrypoint, list_models, list_pretrained, is_model, list_modules,\
from ._builder import build_model_with_cfg, load_pretrained, load_custom_pretrained, resolve_pretrained_cfg, \
set_pretrained_download_progress, set_pretrained_check_hash
from ._factory import create_model, parse_model_name, safe_model_name
from ._features import FeatureInfo, FeatureHooks, FeatureHookNet, FeatureListNet, FeatureDictNet
from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, \
register_notrace_module, register_notrace_function
from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_checkpoint, resume_checkpoint
from ._hub import load_model_config_from_hf, load_state_dict_from_hf, push_to_hf_hub
from ._manipulate import model_parameters, named_apply, named_modules, named_modules_with_params, \
group_modules, group_parameters, checkpoint_seq, adapt_input_conv
from ._pretrained import PretrainedCfg, DefaultCfg, \
filter_pretrained_cfg, generate_default_cfgs, split_model_name_tag
from ._prune import adapt_model_from_string
from ._registry import register_model, model_entrypoint, list_models, list_pretrained, is_model, list_modules, \
is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value
Loading