Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Segmentation Pretrained Weights #1046

Conversation

isaaccorley
Copy link
Collaborator

@isaaccorley isaaccorley commented Jan 25, 2023

This PR addresses part of #1044 adds the ability to load pretrained weights from a backbone model e.g. ResNet into a semantic segmentation encoder. This works for the segmentation-models-pytorch Unet and DeepLabv3 implementations but not the FCN because we aren't using a backbone for that.

@isaaccorley isaaccorley self-assigned this Jan 25, 2023
@github-actions github-actions bot added the trainers PyTorch Lightning trainers label Jan 25, 2023
@adamjstewart
Copy link
Collaborator

This only supports ImageNet pretrained weights, which is pretty useless for us. We really want support for our weight enums like we have for regression/classification/byol. Could we manually load weights from state dict ourselves like we do with our timm-based models?

@adamjstewart
Copy link
Collaborator

Oh wait, I didn't read the full PR. Maybe this does cover all the features we want. You'll definitely want to update the docstring though.

@adamjstewart
Copy link
Collaborator

Also needs tests

@adamjstewart adamjstewart modified the milestone: 0.5.0 Jan 26, 2023
@adamjstewart adamjstewart added backwards-incompatible Changes that are not backwards compatible labels Jan 26, 2023
@nilsleh
Copy link
Collaborator

nilsleh commented Jan 27, 2023

Maybe also put a note in the docs that not all our pretrained models (i.e. the ViTs) will be compatible according to the smp docs or maybe there is a way around, not sure.

@isaaccorley
Copy link
Collaborator Author

The other thing I realized is that by default smp will use torchvision for resnet backbones e.g. resnet18. If you want to force using timm implementations you have to prefix the encoder names with tu- e.g. tu-resnet18.

@calebrob6
Copy link
Member

We could just only allow timm pretrained backbones here

@isaaccorley
Copy link
Collaborator Author

This actually needs #1049 to be solved otherwise this only supports resnet backbones.

@isaaccorley
Copy link
Collaborator Author

We could just only allow timm pretrained backbones here

So do we want to prefix tu- to all the backbones or do we want the user to do it?

@adamjstewart
Copy link
Collaborator

That is a tough question. I would be fine with only supporting timm backbones for now, which would suggest prepending tu- for the user so that it matches other trainers. But there may come a time in which we want to support non-timm backbones because pretrained models are not always made with timm.

calebrob6
calebrob6 previously approved these changes Feb 23, 2023
@calebrob6
Copy link
Member

We could prefix "tu-" if it results in a valid timm backbone to give ourselves room to grow.

@calebrob6 calebrob6 closed this Feb 23, 2023
@calebrob6 calebrob6 reopened this Feb 23, 2023
@calebrob6 calebrob6 force-pushed the trainers/segmentation-pretrained-weights branch from da2fb0c to 86a0716 Compare February 23, 2023 23:31
@isaaccorley isaaccorley force-pushed the trainers/segmentation-pretrained-weights branch 4 times, most recently from 6324e42 to a7695d7 Compare April 25, 2023 14:38
@isaaccorley isaaccorley force-pushed the trainers/segmentation-pretrained-weights branch from a7695d7 to fb488f4 Compare April 26, 2023 00:28
@github-actions github-actions bot added the testing Continuous integration testing label Apr 26, 2023
@calebrob6
Copy link
Member

Ugh, also the horrific model checkpoint tests

@calebrob6
Copy link
Member

We need this :)

@calebrob6
Copy link
Member

@isaaccorley does this work with our custom ResNet weights?

@isaaccorley
Copy link
Collaborator Author

I haven't tried using the weights enums but in theory it should since the encoders are just resnet models.

@adamjstewart
Copy link
Collaborator

It should be the same loading code and docstring description as all of the other trainers.

@isaaccorley
Copy link
Collaborator Author

isaaccorley commented May 3, 2023

@calebrob6 I just made it work with the ResNet weights

from torchgeo.trainers import SemanticSegmentationTask
from torchgeo.models import ResNet50_Weights

model = SemanticSegmentationTask(
    model="unet",
    backbone="resnet50",
    weights=ResNet50_Weights.SENTINEL2_RGB_MOCO,
    in_channels=3,
    num_classes=2,
    loss="ce",
    ignore_index=0,
    learning_rate=3e-4,
    learning_rate_schedule_patience=5,
    freeze_backbone=False,
    freeze_decoder=False,
)

@calebrob6
Copy link
Member

LGTM!

image

calebrob6
calebrob6 previously approved these changes May 3, 2023
tests/conf/inria.yaml Outdated Show resolved Hide resolved
tests/trainers/test_segmentation.py Show resolved Hide resolved
torchgeo/trainers/segmentation.py Outdated Show resolved Hide resolved
@adamjstewart adamjstewart enabled auto-merge (squash) May 3, 2023 20:54
@adamjstewart adamjstewart merged commit 1973d77 into microsoft:main May 3, 2023
@isaaccorley isaaccorley deleted the trainers/segmentation-pretrained-weights branch May 3, 2023 21:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backwards-incompatible Changes that are not backwards compatible testing Continuous integration testing trainers PyTorch Lightning trainers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants