Skip to content

Commit

Permalink
torch.hub: add support for DOFA and Swin models (#2052)
Browse files Browse the repository at this point in the history
* torch.hub: add support for DOFA and Swin models

* Fix tests

* Add *args support to DOFA to match other models
  • Loading branch information
adamjstewart committed May 12, 2024
1 parent 94bd5c7 commit dbfe7fa
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 15 deletions.
18 changes: 16 additions & 2 deletions hubconf.py
Expand Up @@ -7,8 +7,22 @@
* https://pytorch.org/docs/stable/hub.html
"""

from torchgeo.models import resnet18, resnet50, vit_small_patch16_224
from torchgeo.models import (
dofa_base_patch16_224,
dofa_large_patch16_224,
resnet18,
resnet50,
swin_v2_b,
vit_small_patch16_224,
)

__all__ = ('resnet18', 'resnet50', 'vit_small_patch16_224')
__all__ = (
'dofa_base_patch16_224',
'dofa_large_patch16_224',
'resnet18',
'resnet50',
'swin_v2_b',
'vit_small_patch16_224',
)

dependencies = ['timm']
22 changes: 20 additions & 2 deletions tests/models/test_api.py
Expand Up @@ -9,10 +9,14 @@
from torchvision.models._api import WeightsEnum

from torchgeo.models import (
DOFABase16_Weights,
DOFALarge16_Weights,
ResNet18_Weights,
ResNet50_Weights,
Swin_V2_B_Weights,
ViTSmall16_Weights,
dofa_base_patch16_224,
dofa_large_patch16_224,
get_model,
get_model_weights,
get_weight,
Expand All @@ -23,8 +27,22 @@
vit_small_patch16_224,
)

builders = [resnet18, resnet50, vit_small_patch16_224, swin_v2_b]
enums = [ResNet18_Weights, ResNet50_Weights, ViTSmall16_Weights, Swin_V2_B_Weights]
builders = [
dofa_base_patch16_224,
dofa_large_patch16_224,
resnet18,
resnet50,
swin_v2_b,
vit_small_patch16_224,
]
enums = [
DOFABase16_Weights,
DOFALarge16_Weights,
ResNet18_Weights,
ResNet50_Weights,
Swin_V2_B_Weights,
ViTSmall16_Weights,
]


@pytest.mark.parametrize('builder', builders)
Expand Down
18 changes: 15 additions & 3 deletions torchgeo/models/api.py
Expand Up @@ -16,26 +16,38 @@
import torch.nn as nn
from torchvision.models._api import WeightsEnum

from .dofa import (
DOFABase16_Weights,
DOFALarge16_Weights,
dofa_base_patch16_224,
dofa_large_patch16_224,
)
from .resnet import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50
from .swin import Swin_V2_B_Weights, swin_v2_b
from .vit import ViTSmall16_Weights, vit_small_patch16_224

_model = {
'dofa_base_patch16_224': dofa_base_patch16_224,
'dofa_large_patch16_224': dofa_large_patch16_224,
'resnet18': resnet18,
'resnet50': resnet50,
'vit_small_patch16_224': vit_small_patch16_224,
'swin_v2_b': swin_v2_b,
'vit_small_patch16_224': vit_small_patch16_224,
}

_model_weights = {
dofa_base_patch16_224: DOFABase16_Weights,
dofa_large_patch16_224: DOFALarge16_Weights,
resnet18: ResNet18_Weights,
resnet50: ResNet50_Weights,
vit_small_patch16_224: ViTSmall16_Weights,
swin_v2_b: Swin_V2_B_Weights,
vit_small_patch16_224: ViTSmall16_Weights,
'dofa_base_patch16_224': DOFABase16_Weights,
'dofa_large_patch16_224': DOFALarge16_Weights,
'resnet18': ResNet18_Weights,
'resnet50': ResNet50_Weights,
'vit_small_patch16_224': ViTSmall16_Weights,
'swin_v2_b': Swin_V2_B_Weights,
'vit_small_patch16_224': ViTSmall16_Weights,
}


Expand Down
24 changes: 16 additions & 8 deletions torchgeo/models/dofa.py
Expand Up @@ -415,7 +415,7 @@ class DOFALarge16_Weights(WeightsEnum): # type: ignore[misc]
)


def dofa_small_patch16_224(**kwargs: Any) -> DOFA:
def dofa_small_patch16_224(*args: Any, **kwargs: Any) -> DOFA:
"""Dynamic One-For-All (DOFA) small patch size 16 model.
If you use this model in your research, please cite the following paper:
Expand All @@ -425,17 +425,19 @@ def dofa_small_patch16_224(**kwargs: Any) -> DOFA:
.. versionadded:: 0.6
Args:
*args: Additional arguments to pass to :class:`DOFA`.
**kwargs: Additional keywork arguments to pass to :class:`DOFA`.
Returns:
A DOFA small 16 model.
"""
model = DOFA(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
kwargs |= {'patch_size': 16, 'embed_dim': 384, 'depth': 12, 'num_heads': 6}
model = DOFA(*args, **kwargs)
return model


def dofa_base_patch16_224(
weights: DOFABase16_Weights | None = None, **kwargs: Any
weights: DOFABase16_Weights | None = None, *args: Any, **kwargs: Any
) -> DOFA:
"""Dynamic One-For-All (DOFA) base patch size 16 model.
Expand All @@ -447,12 +449,14 @@ def dofa_base_patch16_224(
Args:
weights: Pre-trained model weights to use.
*args: Additional arguments to pass to :class:`DOFA`.
**kwargs: Additional keywork arguments to pass to :class:`DOFA`.
Returns:
A DOFA base 16 model.
"""
model = DOFA(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
kwargs |= {'patch_size': 16, 'embed_dim': 768, 'depth': 12, 'num_heads': 12}
model = DOFA(*args, **kwargs)

if weights:
missing_keys, unexpected_keys = model.load_state_dict(
Expand All @@ -471,7 +475,7 @@ def dofa_base_patch16_224(


def dofa_large_patch16_224(
weights: DOFALarge16_Weights | None = None, **kwargs: Any
weights: DOFALarge16_Weights | None = None, *args: Any, **kwargs: Any
) -> DOFA:
"""Dynamic One-For-All (DOFA) large patch size 16 model.
Expand All @@ -483,12 +487,14 @@ def dofa_large_patch16_224(
Args:
weights: Pre-trained model weights to use.
*args: Additional arguments to pass to :class:`DOFA`.
**kwargs: Additional keywork arguments to pass to :class:`DOFA`.
Returns:
A DOFA large 16 model.
"""
model = DOFA(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
kwargs |= {'patch_size': 16, 'embed_dim': 1024, 'depth': 24, 'num_heads': 16}
model = DOFA(*args, **kwargs)

if weights:
missing_keys, unexpected_keys = model.load_state_dict(
Expand All @@ -506,7 +512,7 @@ def dofa_large_patch16_224(
return model


def dofa_huge_patch16_224(**kwargs: Any) -> DOFA:
def dofa_huge_patch16_224(*args: Any, **kwargs: Any) -> DOFA:
"""Dynamic One-For-All (DOFA) huge patch size 16 model.
If you use this model in your research, please cite the following paper:
Expand All @@ -516,10 +522,12 @@ def dofa_huge_patch16_224(**kwargs: Any) -> DOFA:
.. versionadded:: 0.6
Args:
*args: Additional arguments to pass to :class:`DOFA`.
**kwargs: Additional keywork arguments to pass to :class:`DOFA`.
Returns:
A DOFA huge 16 model.
"""
model = DOFA(patch_size=14, embed_dim=1280, depth=32, num_heads=16, **kwargs)
kwargs |= {'patch_size': 14, 'embed_dim': 1280, 'depth': 32, 'num_heads': 16}
model = DOFA(*args, **kwargs)
return model

0 comments on commit dbfe7fa

Please sign in to comment.