
Uses pretrained ResNet for U-Net encoder, closes #45 and #44 (#46)

Merged: 1 commit into master from issue/45, Jul 3, 2018

Conversation

@daniel-j-h (Collaborator) commented Jun 22, 2018

For #45 and #44.

This changeset switches out the standard U-Net encoder for a pre-trained ResNet50 encoder. In addition, it replaces the learned deconvolutions for upsampling with nearest neighbor upsampling followed by a convolution for refinement.
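For context, a minimal sketch of what such an upsampling block can look like (a hypothetical stand-in; the changeset's actual DecoderBlock may differ in details):

import torch.nn as nn

class DecoderBlock(nn.Module):
    # Upsamples by a factor of two with nearest neighbor interpolation and
    # refines the result with a learned 3x3 convolution, avoiding the
    # checkerboard artifacts learned deconvolutions are prone to.

    def __init__(self, num_in, num_out):
        super().__init__()
        self.block = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(num_in, num_out, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)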

With this approach training will be faster, prediction will be way faster and more accurate, memory usage will be lower, and general happiness will be obtained. The only downside: right now it only works for three-channel inputs (RGB), not for arbitrary channel counts. That's something for later down the line.
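For reference, one way the three-channel limitation could later be lifted (a hypothetical sketch, not part of this changeset): swap the pretrained stem convolution for a freshly initialized one with the desired channel count, at the cost of losing its pretrained weights.

import torch.nn as nn
from torchvision.models import resnet50

def make_backbone(num_channels=3, pretrained=True):
    resnet = resnet50(pretrained=pretrained)

    if num_channels != 3:
        # The pretrained stem expects RGB inputs; replace it for other
        # channel counts, sacrificing its ImageNet weights.
        resnet.conv1 = nn.Conv2d(num_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)

    return resnet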

Thanks to Alexander Buslaev and Dzianis Dus for letting me pick their brains about this over dinner.

@daniel-j-h force-pushed the issue/45 branch 2 times, most recently from 43e633d to 6f9957d on June 22, 2018 at 04:53
@daniel-j-h changed the title from "Uses pretrained ResNet for U-Net encoder, closes #45" to "Uses pretrained ResNet for U-Net encoder, closes #45 and #44" on Jun 22, 2018
@daniel-j-h force-pushed the issue/45 branch 2 times, most recently from ddf347c to f0f2e5c on June 22, 2018 at 20:29
@@ -1,65 +0,0 @@
import os
Contributor:

Would we need rs stats when we want to try different model architectures like PSPNet or YOLO?

Collaborator (Author):

Our PSPNet has a ResNet backbone, too; see:

# Backbone network we use to harvest convolutional image features from
self.resnet = resnet50(pretrained=pretrained)
# https://github.com/pytorch/vision/blob/c84aa9989f5256480487cafe280b521e50ddd113/torchvision/models/resnet.py#L101-L105
self.block0 = nn.Sequential(self.resnet.conv1, self.resnet.bn1, self.resnet.relu, self.resnet.maxpool)
# https://github.com/pytorch/vision/blob/c84aa9989f5256480487cafe280b521e50ddd113/torchvision/models/resnet.py#L106-L109
self.block1 = self.resnet.layer1
self.block2 = self.resnet.layer2
self.block3 = self.resnet.layer3
self.block4 = self.resnet.layer4
# See https://arxiv.org/abs/1606.02147v1 section 4: Information-preserving dimensionality changes
#
# "When downsampling, the first 1x1 projection of the convolutional branch is performed with a stride of 2
# in both dimensions, which effectively discards 75% of the input. Increasing the filter size to 2x2 allows
# to take the full input into consideration, and thus improves the information flow and accuracy."
#
# We cannot change the kernel_size on the fly, but we can change the stride instead from (2, 2) to (1, 1).
assert self.block3[0].downsample[0].stride == (2, 2)
assert self.block4[0].downsample[0].stride == (2, 2)
self.block3[0].downsample[0].stride = (1, 1)
self.block4[0].downsample[0].stride = (1, 1)
# See https://arxiv.org/abs/1511.07122 and https://arxiv.org/abs/1706.05587 for dilated convolutions.
# ResNets reduce spatial dimension too much for segmentation => patch in dilated convolutions.
for name, module in self.block3.named_modules():
    if "conv2" in name:
        module.dilation = (2, 2)
        module.padding = (2, 2)
        module.stride = (1, 1)

for name, module in self.block4.named_modules():
    if "conv2" in name:
        module.dilation = (4, 4)
        module.padding = (4, 4)
        module.stride = (1, 1)
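As a quick sanity check (hypothetical, not part of the changeset) that this patching preserves spatial resolution, one can apply the same edits to a bare torchvision ResNet50:

import torch
from torchvision.models import resnet50

net = resnet50(pretrained=False)

# Same patching as above for layer3: stride-one downsampling plus dilated
# 3x3 convolutions.
net.layer3[0].downsample[0].stride = (1, 1)
for name, module in net.layer3.named_modules():
    if "conv2" in name:
        module.dilation = (2, 2)
        module.padding = (2, 2)
        module.stride = (1, 1)

x = torch.randn(1, 512, 64, 64)  # shape of the feature map entering layer3
print(net.layer3(x).shape)       # torch.Size([1, 1024, 64, 64]) instead of 32x32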

For object detection I want us to implement RetinaNet.

Then we need a feature pyramid network, and we can use both the ResNet backbone and the feature pyramid network for segmentation and object detection.


'''

import torch
import torch.nn as nn

from torchvision.models import resnet50
Contributor:

Should this be part of the model configuration? I mean, would it help if users saw that by default RoboSat uses a ResNet50, and that they could choose other architectures available in torchvision if they want to?

Collaborator (Author):

Maybe. At the moment we also still have the model name in the config

# The model to use. Depending on the model different attributes might be available.
model = 'unet'

but don't use it at all and simply always use our U-Net adaptation. The PSPNet in multi-class mode will also not work with how we currently serialize and deserialize segmentation probabilities.
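(For illustration only, a hypothetical dispatch if the config value were honored; the registry and the UNet constructor signature below are assumptions, not current behavior:)

from robosat.unet import UNet

# Hypothetical registry mapping config names to model classes.
models = {"unet": UNet}

def model_from_config(config, num_classes):
    try:
        return models[config["model"]](num_classes)
    except KeyError:
        raise ValueError("unknown model: {}".format(config["model"]))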

I think we need to adapt a bit more code than just swapping out the ResNets:

robosat/unet.py, lines 94 to 108 in 5a1479c:

self.resnet = resnet50(pretrained=pretrained)
self.enc0 = nn.Sequential(self.resnet.conv1, self.resnet.bn1, self.resnet.relu, self.resnet.maxpool)
self.enc1 = self.resnet.layer1 # 256
self.enc2 = self.resnet.layer2 # 512
self.enc3 = self.resnet.layer3 # 1024
self.enc4 = self.resnet.layer4 # 2048
self.center = DecoderBlock(2048, num_filters * 8)
self.dec0 = DecoderBlock(2048 + num_filters * 8, num_filters * 8)
self.dec1 = DecoderBlock(1024 + num_filters * 8, num_filters * 8)
self.dec2 = DecoderBlock(512 + num_filters * 8, num_filters * 2)
self.dec3 = DecoderBlock(256 + num_filters * 2, num_filters * 2 * 2)
self.dec4 = DecoderBlock(num_filters * 2 * 2, num_filters)

What we could do is abstract away the feature extractor: our UNet class would then take such a FeatureExtractor, get feature maps of specific resolutions and channel counts from it, and construct a corresponding decoder. I don't think it's worth it right now, though.
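For the record, a rough sketch of that abstraction (all names hypothetical):

import torch.nn as nn
from torchvision.models import resnet50

class ResNetFeatureExtractor(nn.Module):
    # Channel counts of the per-stage feature maps, which a generic UNet
    # could use to construct a matching decoder.
    out_channels = [64, 256, 512, 1024, 2048]

    def __init__(self, pretrained=True):
        super().__init__()
        resnet = resnet50(pretrained=pretrained)
        self.stages = nn.ModuleList([
            nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool),
            resnet.layer1,
            resnet.layer2,
            resnet.layer3,
            resnet.layer4,
        ])

    def forward(self, x):
        # One feature map per stage, so the decoder can attach a skip
        # connection at every resolution.
        features = []
        for stage in self.stages:
            x = stage(x)
            features.append(x)
        return features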

robosat/unet.py (outdated diff):

self.block3 = Block(128, 256)
self.down3 = Downsample()
Also known as AlbuNet due to its inventor Alexander Buslaev.
Contributor:

🙏

@bkowshik (Contributor) commented:

Running predict.py with the ResNet encoder fails with the following error. Posting it here; my plan is to dig a little deeper and come back with more notes.

Traceback (most recent call last):
  File "/usr/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/app/robosat/robosat/tools/__main__.py", line 41, in <module>
    args.func(args)
  File "/app/robosat/robosat/tools/predict.py", line 83, in main
    outputs = net(images)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/parallel/data_parallel.py", line 112, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/app/robosat/robosat/unet.py", line 131, in forward
    dec0 = self.dec0(torch.cat([enc4, center], dim=1))
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 1. Got 16 and 17 in dimension 2 at /pytorch/aten/src/THC/generic/THCTensorMath.cu:111

@daniel-j-h (Collaborator, Author) commented:

Arbitrary sizes are no longer possible with the ResNet-based encoder. Image resolution needs to be a multiple of 32; e.g., an overlap of 0 or 32 works. This is due to how the ResNet downsamples by a factor of two at each stage (check out the ResNet paper or print the ResNet architecture from the PyTorch models).

Here are the tensor sizes after each encoder layer for 512x512 images with an overlap of 0 and 32, respectively (batch size, channels, height, width):

Overlap 0:
torch.Size([1, 64, 128, 128])
torch.Size([1, 256, 128, 128])
torch.Size([1, 512, 64, 64])
torch.Size([1, 1024, 32, 32])
torch.Size([1, 2048, 16, 16])

Overlap 32:
torch.Size([1, 64, 144, 144])
torch.Size([1, 256, 144, 144])
torch.Size([1, 512, 72, 72])
torch.Size([1, 1024, 36, 36])
torch.Size([1, 2048, 18, 18])
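A cheap guard along these lines could fail fast before prediction (hypothetical sketch): the encoder halves the resolution five times (conv1, maxpool, and layers 2 to 4), so height and width must be divisible by 2**5 = 32.

def assert_resolution(height, width, factor=32):
    # Inputs must be divisible by 32 so the decoder's upsampled feature
    # maps line up with the encoder's skip connections again.
    if height % factor != 0 or width % factor != 0:
        raise ValueError("resolution must be a multiple of {}, got {}x{}".format(factor, height, width))

# For 512x512 tiles: overlap 0 and 32 pass, overlap 20 (552x552) raises.
for overlap in (0, 32, 20):
    size = 512 + 2 * overlap
    try:
        assert_resolution(size, size)
        print("overlap {}: {}x{} is fine".format(overlap, size, size))
    except ValueError as e:
        print("overlap {}: {}".format(overlap, e))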

@bkowshik (Contributor) commented Jul 3, 2018:

Rebased this branch again with master in preparation to merge. 🚀
