
torch.hub: add support for DOFA and Swin models #2052

Merged · 3 commits · May 12, 2024
hubconf.py: 18 changes (16 additions, 2 deletions)
@@ -7,8 +7,22 @@
 * https://pytorch.org/docs/stable/hub.html
 """

-from torchgeo.models import resnet18, resnet50, vit_small_patch16_224
+from torchgeo.models import (
+    dofa_base_patch16_224,
+    dofa_large_patch16_224,
+    resnet18,
+    resnet50,
+    swin_v2_b,
+    vit_small_patch16_224,
+)

-__all__ = ('resnet18', 'resnet50', 'vit_small_patch16_224')
+__all__ = (
+    'dofa_base_patch16_224',
+    'dofa_large_patch16_224',
+    'resnet18',
+    'resnet50',
+    'swin_v2_b',
+    'vit_small_patch16_224',
+)

 dependencies = ['timm']
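
With these exports in hubconf.py, the new models become loadable directly through torch.hub. A minimal usage sketch, not part of this diff; the 'microsoft/torchgeo' repo slug is an assumption based on where the project lived at the time of this PR:

import torch

# hubconf.py exposes the builders above, so torch.hub can construct
# them by name. Passing no weights argument constructs an untrained model.
dofa = torch.hub.load('microsoft/torchgeo', 'dofa_base_patch16_224')
swin = torch.hub.load('microsoft/torchgeo', 'swin_v2_b')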
tests/models/test_api.py: 22 changes (20 additions, 2 deletions)
@@ -9,10 +9,14 @@
 from torchvision.models._api import WeightsEnum

 from torchgeo.models import (
+    DOFABase16_Weights,
+    DOFALarge16_Weights,
     ResNet18_Weights,
     ResNet50_Weights,
     Swin_V2_B_Weights,
     ViTSmall16_Weights,
+    dofa_base_patch16_224,
+    dofa_large_patch16_224,
     get_model,
     get_model_weights,
     get_weight,
@@ -23,8 +27,22 @@
     vit_small_patch16_224,
 )

-builders = [resnet18, resnet50, vit_small_patch16_224, swin_v2_b]
-enums = [ResNet18_Weights, ResNet50_Weights, ViTSmall16_Weights, Swin_V2_B_Weights]
+builders = [
+    dofa_base_patch16_224,
+    dofa_large_patch16_224,
+    resnet18,
+    resnet50,
+    swin_v2_b,
+    vit_small_patch16_224,
+]
+enums = [
+    DOFABase16_Weights,
+    DOFALarge16_Weights,
+    ResNet18_Weights,
+    ResNet50_Weights,
+    Swin_V2_B_Weights,
+    ViTSmall16_Weights,
+]


 @pytest.mark.parametrize('builder', builders)
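
For orientation, the builders list above feeds parametrized tests like the sketch below. This is illustrative only; the real test bodies in test_api.py sit outside this hunk and assert more than this:

import pytest
import torch.nn as nn

@pytest.mark.parametrize('builder', builders)
def test_builder_constructs_module(builder) -> None:
    # Constructing without a weights argument downloads nothing.
    model = builder()
    assert isinstance(model, nn.Module)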
torchgeo/models/api.py: 18 changes (15 additions, 3 deletions)
@@ -16,26 +16,38 @@
 import torch.nn as nn
 from torchvision.models._api import WeightsEnum

+from .dofa import (
+    DOFABase16_Weights,
+    DOFALarge16_Weights,
+    dofa_base_patch16_224,
+    dofa_large_patch16_224,
+)
 from .resnet import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50
 from .swin import Swin_V2_B_Weights, swin_v2_b
 from .vit import ViTSmall16_Weights, vit_small_patch16_224

 _model = {
+    'dofa_base_patch16_224': dofa_base_patch16_224,
+    'dofa_large_patch16_224': dofa_large_patch16_224,
     'resnet18': resnet18,
     'resnet50': resnet50,
-    'vit_small_patch16_224': vit_small_patch16_224,
     'swin_v2_b': swin_v2_b,
+    'vit_small_patch16_224': vit_small_patch16_224,
 }

 _model_weights = {
+    dofa_base_patch16_224: DOFABase16_Weights,
+    dofa_large_patch16_224: DOFALarge16_Weights,
     resnet18: ResNet18_Weights,
     resnet50: ResNet50_Weights,
-    vit_small_patch16_224: ViTSmall16_Weights,
     swin_v2_b: Swin_V2_B_Weights,
+    vit_small_patch16_224: ViTSmall16_Weights,
+    'dofa_base_patch16_224': DOFABase16_Weights,
+    'dofa_large_patch16_224': DOFALarge16_Weights,
     'resnet18': ResNet18_Weights,
     'resnet50': ResNet50_Weights,
-    'vit_small_patch16_224': ViTSmall16_Weights,
     'swin_v2_b': Swin_V2_B_Weights,
+    'vit_small_patch16_224': ViTSmall16_Weights,
 }

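These two tables back the public lookup helpers imported in the test diff above. A short usage sketch of that API:

from torchgeo.models import get_model, get_model_weights

# get_model resolves a name through _model and calls the builder;
# get_model_weights returns the matching WeightsEnum from _model_weights.
model = get_model('dofa_base_patch16_224')
weights = get_model_weights('dofa_base_patch16_224')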
torchgeo/models/dofa.py: 24 changes (16 additions, 8 deletions)
@@ -417,7 +417,7 @@ class DOFALarge16_Weights(WeightsEnum):  # type: ignore[misc]
 )


-def dofa_small_patch16_224(**kwargs: Any) -> DOFA:
+def dofa_small_patch16_224(*args: Any, **kwargs: Any) -> DOFA:
     """Dynamic One-For-All (DOFA) small patch size 16 model.

     If you use this model in your research, please cite the following paper:
@@ -427,17 +427,19 @@ def dofa_small_patch16_224(**kwargs: Any) -> DOFA:
     .. versionadded:: 0.6

     Args:
+        *args: Additional arguments to pass to :class:`DOFA`.
         **kwargs: Additional keyword arguments to pass to :class:`DOFA`.

     Returns:
         A DOFA small 16 model.
     """
-    model = DOFA(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
+    kwargs |= {'patch_size': 16, 'embed_dim': 384, 'depth': 12, 'num_heads': 6}
+    model = DOFA(*args, **kwargs)
     return model


 def dofa_base_patch16_224(
-    weights: DOFABase16_Weights | None = None, **kwargs: Any
+    weights: DOFABase16_Weights | None = None, *args: Any, **kwargs: Any
 ) -> DOFA:
     """Dynamic One-For-All (DOFA) base patch size 16 model.

@@ -449,12 +451,14 @@ def dofa_base_patch16_224(
     Args:
         weights: Pre-trained model weights to use.
+        *args: Additional arguments to pass to :class:`DOFA`.
         **kwargs: Additional keyword arguments to pass to :class:`DOFA`.

     Returns:
         A DOFA base 16 model.
     """
-    model = DOFA(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+    kwargs |= {'patch_size': 16, 'embed_dim': 768, 'depth': 12, 'num_heads': 12}
+    model = DOFA(*args, **kwargs)

     if weights:
         missing_keys, unexpected_keys = model.load_state_dict(
@@ -473,7 +477,7 @@


 def dofa_large_patch16_224(
-    weights: DOFALarge16_Weights | None = None, **kwargs: Any
+    weights: DOFALarge16_Weights | None = None, *args: Any, **kwargs: Any
 ) -> DOFA:
     """Dynamic One-For-All (DOFA) large patch size 16 model.

@@ -485,12 +489,14 @@ def dofa_large_patch16_224(
     Args:
         weights: Pre-trained model weights to use.
+        *args: Additional arguments to pass to :class:`DOFA`.
         **kwargs: Additional keyword arguments to pass to :class:`DOFA`.

     Returns:
         A DOFA large 16 model.
     """
-    model = DOFA(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
+    kwargs |= {'patch_size': 16, 'embed_dim': 1024, 'depth': 24, 'num_heads': 16}
+    model = DOFA(*args, **kwargs)

     if weights:
         missing_keys, unexpected_keys = model.load_state_dict(
@@ -508,7 +514,7 @@ def dofa_large_patch16_224(
     return model


-def dofa_huge_patch16_224(**kwargs: Any) -> DOFA:
+def dofa_huge_patch16_224(*args: Any, **kwargs: Any) -> DOFA:
     """Dynamic One-For-All (DOFA) huge patch size 16 model.

     If you use this model in your research, please cite the following paper:
@@ -518,10 +524,12 @@ def dofa_huge_patch16_224(**kwargs: Any) -> DOFA:
     .. versionadded:: 0.6

     Args:
+        *args: Additional arguments to pass to :class:`DOFA`.
         **kwargs: Additional keyword arguments to pass to :class:`DOFA`.

     Returns:
         A DOFA huge 16 model.
     """
-    model = DOFA(patch_size=14, embed_dim=1280, depth=32, num_heads=16, **kwargs)
+    kwargs |= {'patch_size': 14, 'embed_dim': 1280, 'depth': 32, 'num_heads': 16}
Inline review comment on the `kwargs |=` line above:

Member: Should we warn the user if these overwrite kwargs?

Collaborator (author): I think it's worth adding, just not in this PR. Torchvision does issue a warning as of pytorch/vision#5618.
+    model = DOFA(*args, **kwargs)
     return model
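
To make the review suggestion concrete: with `dict |= other` the right-hand side wins, so a user-passed `embed_dim` is silently replaced by the variant's fixed value. Below is one possible shape of the warning discussed above, as a hedged sketch; `_merge_fixed_params` is an invented name, not part of this PR or of torchvision:

import warnings
from typing import Any

def _merge_fixed_params(
    kwargs: dict[str, Any], fixed: dict[str, Any]
) -> dict[str, Any]:
    # Warn when a user-supplied value collides with a variant-defining one.
    overlap = sorted(fixed.keys() & kwargs.keys())
    if overlap:
        warnings.warn(f'Overriding user-specified arguments: {overlap}')
    return kwargs | fixed

# Silent behavior of the current `kwargs |=` lines, for comparison:
kwargs = {'embed_dim': 512}
kwargs |= {'patch_size': 14, 'embed_dim': 1280, 'depth': 32, 'num_heads': 16}
assert kwargs['embed_dim'] == 1280  # the fixed value wins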