Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
torch>=1.4.0
torchvision>=0.5.0
torch>=1.7
torchvision
pyyaml
huggingface_hub
9 changes: 6 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,15 @@
# 3 - Alpha
# 4 - Beta
# 5 - Production/Stable
'Development Status :: 3 - Alpha',
'Development Status :: 4 - Beta',
'Intended Audience :: Education',
'Intended Audience :: Science/Research',
'License :: OSI Approved :: Apache Software License',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: Software Development',
Expand All @@ -40,9 +42,10 @@
],

# Note that this is a string of words separated by whitespace, not a list.
keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet',
keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet resnet vision transformer vit',
packages=find_packages(exclude=['convert', 'tests', 'results']),
include_package_data=True,
install_requires=['torch >= 1.4', 'torchvision'],
install_requires=['torch >= 1.7', 'torchvision', 'pyyaml', 'huggingface_hub'],
python_requires='>=3.6',
)

2 changes: 2 additions & 0 deletions timm/data/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)
OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
8 changes: 7 additions & 1 deletion timm/models/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ def _resolve_pretrained_source(pretrained_cfg):
# hf-hub available as alternate weight source in default_cfg
load_from = 'hf-hub'
pretrained_loc = hf_hub_id
if load_from == 'hf-hub' and 'hf_hub_filename' in pretrained_cfg:
# if a filename override is set, return tuple for location w/ (hub_id, filename)
pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename']
return load_from, pretrained_loc


Expand Down Expand Up @@ -246,7 +249,10 @@ def load_pretrained(
pretrained_loc, map_location='cpu', progress=_DOWNLOAD_PROGRESS, check_hash=_CHECK_HASH)
elif load_from == 'hf-hub':
_logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
state_dict = load_state_dict_from_hf(pretrained_loc)
if isinstance(pretrained_loc, (list, tuple)):
state_dict = load_state_dict_from_hf(*pretrained_loc)
else:
state_dict = load_state_dict_from_hf(pretrained_loc)
else:
_logger.warning("No pretrained weights exist or were found for this model. Using random initialization.")
return
Expand Down
9 changes: 5 additions & 4 deletions timm/models/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from torch.hub import _get_torch_home as get_dir

from timm import __version__

try:
from huggingface_hub import HfApi, HfFolder, Repository, hf_hub_download, hf_hub_url
hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
Expand Down Expand Up @@ -55,7 +56,7 @@ def download_cached_file(url, check_hash=True, progress=False):

def has_hf_hub(necessary=False):
if not _has_hf_hub and necessary:
# if no HF Hub module installed and it is necessary to continue, raise error
# if no HF Hub module installed, and it is necessary to continue, raise error
raise RuntimeError(
'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
return _has_hf_hub
Expand All @@ -78,7 +79,7 @@ def load_cfg_from_json(json_file: Union[str, os.PathLike]):

def _download_from_hf(model_id: str, filename: str):
hf_model_id, hf_revision = hf_split(model_id)
return hf_hub_download(hf_model_id, filename, revision=hf_revision, cache_dir=get_cache_dir('hf'))
return hf_hub_download(hf_model_id, filename, revision=hf_revision)


def load_model_config_from_hf(model_id: str):
Expand All @@ -91,9 +92,9 @@ def load_model_config_from_hf(model_id: str):
return pretrained_cfg, model_name


def load_state_dict_from_hf(model_id: str):
def load_state_dict_from_hf(model_id: str, filename: str = 'pytorch_model.bin'):
assert has_hf_hub(True)
cached_file = _download_from_hf(model_id, 'pytorch_model.bin')
cached_file = _download_from_hf(model_id, filename)
state_dict = torch.load(cached_file, map_location='cpu')
return state_dict

Expand Down
13 changes: 11 additions & 2 deletions timm/models/layers/patch_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,16 @@
class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
embed_dim=768,
norm_layer=None,
flatten=True,
bias=True,
):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
Expand All @@ -25,7 +34,7 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten

self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

def forward(self, x):
Expand Down
Loading