different layer names with official pytorch ResNet #33
I would like to keep the readability and minimality of the model definition with the abstraction, for Caffe-to-PyTorch porting and for comparison across versions. To leverage the torchvision ResNet, I think it is simpler to modify it dynamically into DeepLab:

```python
# PyTorch ResNet-101 -> DeepLab v2
# A bit different in the stride position
import torch
import torch.nn as nn
from torchvision import models
from collections import OrderedDict

from libs.models.deeplabv2 import _ASPPModule

model = models.resnet101(pretrained=True)

# Layer 3 (OS=16 -> OS=8)
model.layer3[0].conv2.stride = (1, 1)
model.layer3[0].downsample[0].stride = (1, 1)
for m in model.layer3[1:]:
    m.conv2.padding = (2, 2)
    m.conv2.dilation = (2, 2)

# Layer 4 (OS=32 -> OS=8)
model.layer4[0].conv2.stride = (1, 1)
model.layer4[0].downsample[0].stride = (1, 1)
for m in model.layer4[1:]:
    m.conv2.padding = (4, 4)
    m.conv2.dilation = (4, 4)

# Remove "avgpool" and "fc", and add ASPP
model = list(model.named_children())[:-2]
model += [("aspp", _ASPPModule(2048, 21, [6, 12, 18, 24]))]
model = nn.Sequential(OrderedDict(model))

# The model output is sub-sampled by 8 instead of 32
image = torch.randn(1, 3, 224, 224)
logit = model(image)
print(logit.shape)
```
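The core of the modification above is swapping a stride for a dilation. A minimal, self-contained sketch of that idea on a single 3×3 convolution (using only `torch.nn`, independent of the ResNet surgery above):

```python
import torch
import torch.nn as nn

# A stride-2 3x3 conv halves the spatial resolution.
conv = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1)
x = torch.randn(1, 64, 28, 28)
print(conv(x).shape)  # torch.Size([1, 64, 14, 14])

# The DeepLab trick: drop the stride and dilate instead. Padding grows
# with the dilation so the output size is preserved, and the dilated
# kernel keeps the same effective receptive field per tap.
conv.stride = (1, 1)
conv.padding = (2, 2)
conv.dilation = (2, 2)
print(conv(x).shape)  # torch.Size([1, 64, 28, 28])
```

Applied to `layer3` and `layer4` as above, this keeps the output stride at 8 instead of 32, which is why the final logit map is 28×28 for a 224×224 input.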
I haven't considered dynamic modification. Thanks.
Although the `_ConvBatchNormReLU` abstraction is very handy for building the model, it requires extra conversion when loading a pretrained ResNet model, as the parameter names differ. Could you use the official PyTorch ResNet code? Thanks.
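The "extra conversion" amounts to renaming `state_dict` keys before loading. A minimal sketch of that step, assuming a hypothetical rename table (the real target names depend on how `_ConvBatchNormReLU` names its submodules, so the entries below are illustrative only):

```python
# Hypothetical mapping from torchvision parameter names to this
# repo's naming scheme; the target names are placeholders.
rename = {
    "conv1.weight": "layer1.conv1.conv.weight",  # assumed target name
    "bn1.weight": "layer1.conv1.bn.weight",      # assumed target name
}

def convert(state_dict, rename):
    """Return a copy of state_dict with keys renamed where a rule exists;
    keys without a rule pass through unchanged."""
    return {rename.get(k, k): v for k, v in state_dict.items()}

official = {"conv1.weight": "W", "bn1.weight": "g", "fc.weight": "Wfc"}
print(convert(official, rename))
```

The converted dict could then be fed to `model.load_state_dict(...)`, typically with `strict=False` while the mapping is incomplete.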