
different layer names with official pytorch ResNet #33

Closed
Keson96 opened this issue Sep 16, 2018 · 2 comments

Keson96 commented Sep 16, 2018

Although the _ConvBatchNormReLU abstraction is very handy for building the model, it requires extra conversion when loading a pretrained ResNet model, since the parameter names differ from those in the official torchvision ResNet. Could you use the official PyTorch ResNet code instead? Thanks.
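For reference, one way to deal with the name mismatch without touching the model code is to remap the torchvision state_dict keys before loading. This is only a sketch: the regex rule below is a placeholder, and the actual _ConvBatchNormReLU parameter names in this repo would have to be filled in.

```python
import re

def remap_keys(state_dict, rules):
    """Rename state_dict keys by applying regex (pattern, replacement)
    rules in order; values are carried over unchanged."""
    remapped = {}
    for key, value in state_dict.items():
        for pattern, repl in rules:
            key = re.sub(pattern, repl, key)
        remapped[key] = value
    return remapped

# Illustrative rule only -- the real mapping depends on the repo's
# _ConvBatchNormReLU naming scheme.
rules = [(r"^layer1\.", "layer2.")]
print(remap_keys({"layer1.0.conv1.weight": 0}, rules))
# {'layer2.0.conv1.weight': 0}
```

The remapped dict can then be passed to `model.load_state_dict(...)`, optionally with `strict=False` while debugging the rules.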


kazuto1011 (Owner) commented Sep 16, 2018

I would like to keep the readability and minimality of the model definition with the abstraction, for Caffe-to-PyTorch porting and for comparison across versions.

To leverage the torchvision ResNet, I think it is simpler to dynamically modify it into DeepLab:

# PyTorch ResNet-101 -> DeepLab v2
# Note: the stride position differs slightly from the original Caffe model
import torch
import torch.nn as nn
from torchvision import models
from collections import OrderedDict

from libs.models.deeplabv2 import _ASPPModule

model = models.resnet101(pretrained=True)

# Layer 3 (OS=16 -> OS=8)
model.layer3[0].conv2.stride = (1, 1)
model.layer3[0].downsample[0].stride = (1, 1)
for m in model.layer3[1:]:
    m.conv2.padding = (2, 2)
    m.conv2.dilation = (2, 2)

# Layer 4 (OS=32 -> OS=8)
model.layer4[0].conv2.stride = (1, 1)
model.layer4[0].downsample[0].stride = (1, 1)
for m in model.layer4[1:]:
    m.conv2.padding = (4, 4)
    m.conv2.dilation = (4, 4)

# Remove "avgpool" and "fc", and add ASPP
model = list(model.named_children())[:-2]
model += [("aspp", _ASPPModule(2048, 21, [6, 12, 18, 24]))]
model = nn.Sequential(OrderedDict(model))

# The model output is sub-sampled by 8 instead of 32
image = torch.randn(1, 3, 224, 224)
logit = model(image)

print(logit.shape)
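As a quick sanity check of the "sub-sampled by 8" claim, the output stride can be traced stage by stage. This small sketch only mirrors the stride bookkeeping of the down-sampling stages (conv1, maxpool, layer2, layer3, layer4), not the network itself.

```python
def output_size(input_size, strides):
    """Spatial size after a chain of down-sampling stages
    (ceil division, as in strided conv with 'same'-style padding)."""
    size = input_size
    for s in strides:
        size = (size + s - 1) // s
    return size

original = [2, 2, 2, 2, 2]   # conv1, maxpool, layer2, layer3, layer4
modified = [2, 2, 2, 1, 1]   # layer3/layer4 strides set to 1 above

print(output_size(224, original))  # 7  (OS=32)
print(output_size(224, modified))  # 28 (OS=8)
```

So for a 224x224 input, the modified model's logit map should be 28x28 rather than 7x7.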


Keson96 commented Sep 17, 2018

I hadn't considered dynamic modification. Thanks.
