Fix MAE Initialization #182

Merged
18 changes: 8 additions & 10 deletions scripts/load_mae_vit.py
@@ -1,13 +1,11 @@
-from collections import OrderedDict
-
 import torch
 from torch_em.model import UNETR
 
-checkpoint = "imagenet.pth"
-encoder_state = torch.load(checkpoint, map_location="cpu")["model"]
-encoder_state = OrderedDict({
-    k: v for k, v in encoder_state.items()
-    if (k != "mask_token" and not k.startswith("decoder"))
-})
-
-unetr_model = UNETR(backbone="mae", encoder="vit_l", encoder_checkpoint=encoder_state)
+def main():
+    checkpoint = "/home/nimanwai/mae_models/imagenet.pth"
+    unetr_model = UNETR(img_size=224, backbone="mae", encoder="vit_l", encoder_checkpoint=checkpoint)
+    print(unetr_model)
+
+
+if __name__ == "__main__":
+    main()
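
The state-dict pruning that this script previously did by hand now lives inside UNETR._load_encoder_from_checkpoint (see the next file), so the example shrinks to a plain constructor call. As a standalone illustration of that pruning step, here is a minimal sketch, assuming an MAE checkpoint that stores its weights under a top-level "model" key (the path is illustrative):

from collections import OrderedDict

import torch

checkpoint = "imagenet.pth"  # illustrative path to an MAE ImageNet checkpoint
state = torch.load(checkpoint, map_location="cpu")["model"]

# Keep only encoder weights: drop the mask token and all decoder.* tensors.
encoder_state = OrderedDict(
    (k, v) for k, v in state.items()
    if k != "mask_token" and not k.startswith("decoder")
)
print(f"kept {len(encoder_state)} of {len(state)} tensors")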
16 changes: 14 additions & 2 deletions torch_em/model/unetr.py
@@ -41,7 +41,18 @@ def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):
             encoder_state = torch.load(checkpoint)
 
         elif backbone == "mae":
-            encoder_state = torch.load(checkpoint)
+            # vit initialization hints from:
+            # - https://github.com/facebookresearch/mae/blob/main/main_finetune.py#L233-L242
+            encoder_state = torch.load(checkpoint)["model"]
+            encoder_state = OrderedDict({
+                k: v for k, v in encoder_state.items()
+                if (k != "mask_token" and not k.startswith("decoder"))
+            })
+
+            # remove the `head` from our current encoder (the MAE pretrained weights don't include it)
+            current_encoder_state = self.encoder.state_dict()
+            if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
+                del self.encoder.head
 
         else:
             encoder_state = checkpoint
@@ -50,6 +61,7 @@ def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):
 
     def __init__(
         self,
+        img_size: int = 1024,
         backbone: str = "sam",
         encoder: str = "vit_b",
         decoder: Optional[nn.Module] = None,
@@ -65,7 +77,7 @@ def __init__(
         self.use_mae_stats = use_mae_stats
 
         print(f"Using {encoder} from {backbone.upper()}")
-        self.encoder = get_vision_transformer(backbone=backbone, model=encoder)
+        self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder)
         if encoder_checkpoint is not None:
            self._load_encoder_from_checkpoint(backbone, encoder, encoder_checkpoint)
 
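
Taken together, loading MAE weights becomes a one-liner on the user side. A minimal sketch of the new calling convention, with an illustrative checkpoint path; img_size=224 matches the resolution the MAE ImageNet models were pretrained at, while the default of 1024 keeps the previous SAM behaviour:

from torch_em.model import UNETR

unetr_model = UNETR(
    img_size=224,                       # MAE pretraining resolution
    backbone="mae",
    encoder="vit_l",
    encoder_checkpoint="imagenet.pth",  # illustrative checkpoint path
)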
14 changes: 7 additions & 7 deletions torch_em/model/vit.py
@@ -123,7 +123,7 @@ def forward(self, x):
         return x, list_from_encoder
 
 
-def get_vision_transformer(backbone: str, model: str):
+def get_vision_transformer(backbone: str, model: str, img_size: int = 1024):
     if backbone == "sam":
         if model == "vit_b":
             encoder = ViT_Sam(
@@ -155,18 +155,18 @@ def get_vision_transformer(backbone: str, model: str):
     elif backbone == "mae":
         if model == "vit_b":
             encoder = ViT_MAE(
-                patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
-                norm_layer=partial(nn.LayerNorm, eps=1e-6)
+                img_size=img_size, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
+                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
             )
         elif model == "vit_l":
             encoder = ViT_MAE(
-                patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
-                norm_layer=partial(nn.LayerNorm, eps=1e-6)
+                img_size=img_size, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
+                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
             )
         elif model == "vit_h":
             encoder = ViT_MAE(
-                patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True,
-                norm_layer=partial(nn.LayerNorm, eps=1e-6)
+                img_size=img_size, patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
+                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
             )
         else:
             raise ValueError(f"{model} is not supported by MAE. Currently vit_b, vit_l, vit_h are supported.")
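
For completeness, the encoder can also be built directly. A short sketch, assuming this module is importable as torch_em.model.vit and that forward returns a (features, intermediates) pair as in the hunk above:

import torch
from torch_em.model.vit import get_vision_transformer

# MAE ViT-L at 224x224; with patch_size=16 this yields a 14x14 patch grid.
encoder = get_vision_transformer(backbone="mae", model="vit_l", img_size=224)
with torch.no_grad():
    features, from_encoder = encoder(torch.randn(1, 3, 224, 224))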