Commit e98cb57
Attach fully sequential ResNet-101 example
sublee authored and GitHub Enterprise committed Jun 12, 2019
1 parent 38248d0 commit e98cb57
Showing 5 changed files with 297 additions and 1 deletion.
31 changes: 31 additions & 0 deletions examples/resnet/NOTICE
@@ -0,0 +1,31 @@
=======================================================================
torchvision's BSD 3-Clause License
=======================================================================

Copyright (c) Soumith Chintala 2016,
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
102 changes: 102 additions & 0 deletions examples/resnet/__init__.py
@@ -0,0 +1,102 @@
"""A ResNet implementation but using :class:`nn.Sequential`. :func:`resnet101`
returns a :class:`nn.Sequential` instead of ``ResNet``.
This code is transformed :mod:`torchvision.models.resnet`.
"""
from collections import OrderedDict
from typing import Any, List

from torch import Tensor
import torch.nn as nn

from resnet.bottleneck import bottleneck
from resnet.flatten import flatten

__all__ = ['resnet101']


class Flatten(nn.Module):
    """Flattens each example in the batch into a 1-d tensor."""

    def forward(self, x: Tensor):  # type: ignore
        return x.view(x.size(0), -1)


def build_resnet(layers: List[int],
                 num_classes: int = 1000,
                 ) -> nn.Sequential:
    """Builds a ResNet as a simple sequential model.

    Note:
        The implementation is copied from :mod:`torchvision.models.resnet`.

    """
    inplanes = 64

    def make_layer(planes: int, blocks: int, stride: int = 1) -> nn.Sequential:
        nonlocal inplanes

        downsample = None
        if stride != 1 or inplanes != planes * 4:
            downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes * 4,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * 4),
            )

        # NOTE: This local list shadows the ``layers`` argument of build_resnet.
        layers = []
        layers.append(bottleneck(inplanes, planes, stride, downsample))
        inplanes = planes * 4
        for _ in range(1, blocks):
            layers.append(bottleneck(inplanes, planes))

        return nn.Sequential(*layers)

    # Build ResNet as a sequential model.
    model = nn.Sequential(OrderedDict([
        ('conv1', nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)),
        ('bn1', nn.BatchNorm2d(64)),
        ('relu', nn.ReLU()),
        ('maxpool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),

        ('layer1', make_layer(64, layers[0])),
        ('layer2', make_layer(128, layers[1], stride=2)),
        ('layer3', make_layer(256, layers[2], stride=2)),
        ('layer4', make_layer(512, layers[3], stride=2)),

        ('avgpool', nn.AdaptiveAvgPool2d((1, 1))),
        ('flat', Flatten()),
        ('fc', nn.Linear(512 * 4, num_classes)),
    ]))

    # Flatten nested sequentials so every layer is a direct child.
    model = flatten(model)

    # Initialize weights for Conv2d and BatchNorm2d layers.
    def init_weight(m: nn.Module) -> None:
        if isinstance(m, nn.Conv2d):
            assert isinstance(m.kernel_size, tuple)
            # He initialization: std = sqrt(2 / fan_out), as in torchvision.
            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels

            # Toggle requires_grad around the in-place update: in-place
            # operations on a leaf tensor that requires grad raise an error.
            m.weight.requires_grad = False
            m.weight.normal_(0, (2. / n) ** 0.5)
            m.weight.requires_grad = True

        elif isinstance(m, nn.BatchNorm2d):
            m.weight.requires_grad = False
            m.weight.fill_(1)
            m.weight.requires_grad = True

            m.bias.requires_grad = False
            m.bias.zero_()
            m.bias.requires_grad = True

    model.apply(init_weight)

    return model


def resnet101(**kwargs: Any) -> nn.Sequential:
    """Constructs a ResNet-101 model."""
    return build_resnet([3, 4, 23, 3], **kwargs)
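
Because the builder returns one flat :class:`nn.Sequential`, the model can be indexed, sliced, and rebalanced like a list of layers, which is what a GPipe-style pipeline needs. A minimal usage sketch (the batch shape and the import path are illustrative assumptions, not part of this commit):

import torch
from resnet import resnet101

model = resnet101(num_classes=1000)

# Every layer is a direct child of a single flat nn.Sequential.
print(len(list(model.children())))

# A dummy ImageNet-sized batch flows through the flat pipeline.
x = torch.rand(8, 3, 224, 224)
y = model(x)
print(y.shape)  # torch.Size([8, 1000])
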
77 changes: 77 additions & 0 deletions examples/resnet/bottleneck.py
@@ -0,0 +1,77 @@
"""A ResNet bottleneck implementation but using :class:`nn.Sequential`."""
from collections import OrderedDict
from typing import Dict, Optional, Tuple, Union

from torch import Tensor
import torch.nn as nn

__all__ = ['bottleneck']

Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]


def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class Twin(nn.Module):
    """Duplicates the input tensor to feed both the main and skip branches."""

    def forward(self,  # type: ignore
                tensor: Tensor,
                ) -> Tuple[Tensor, Tensor]:
        return tensor, tensor


class Gutter(nn.Module):
    """Applies a module to the first tensor while carrying the second along."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(self,  # type: ignore
                input_and_skip: Tuple[Tensor, Tensor],
                ) -> Tuple[Tensor, Tensor]:
        input, skip = input_and_skip
        output = self.module(input)
        return output, skip


class Residual(nn.Module):
    """Adds the identity (optionally downsampled) back to the main branch."""

    def __init__(self, downsample: Optional[nn.Module] = None):
        super().__init__()
        self.downsample = downsample

    def forward(self,  # type: ignore
                input_and_identity: Tuple[Tensor, Tensor],
                ) -> Tensor:
        input, identity = input_and_identity
        if self.downsample is not None:
            identity = self.downsample(identity)
        return input + identity


def bottleneck(inplanes: int,
               planes: int,
               stride: int = 1,
               downsample: Optional[nn.Module] = None,
               ) -> nn.Sequential:
    """Creates a bottleneck block in ResNet as a :class:`nn.Sequential`."""
    layers: Dict[str, nn.Module] = OrderedDict()
    layers['twin'] = Twin()

    layers['conv1'] = Gutter(conv1x1(inplanes, planes))
    layers['bn1'] = Gutter(nn.BatchNorm2d(planes))
    layers['relu1'] = Gutter(nn.ReLU())  # as in torchvision's Bottleneck
    layers['conv2'] = Gutter(conv3x3(planes, planes, stride))
    layers['bn2'] = Gutter(nn.BatchNorm2d(planes))
    layers['relu2'] = Gutter(nn.ReLU())  # as in torchvision's Bottleneck
    layers['conv3'] = Gutter(conv1x1(planes, planes * 4))
    layers['bn3'] = Gutter(nn.BatchNorm2d(planes * 4))
    layers['residual'] = Residual(downsample)
    layers['relu'] = nn.ReLU()

    return nn.Sequential(layers)
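
The Twin/Gutter/Residual trio is how the skip connection survives in a strictly sequential module: ``Twin`` forks the tensor into a (main, skip) pair, each ``Gutter`` transforms the main branch while carrying the skip branch through, and ``Residual`` sums the pair back into one tensor. A small sketch of one block's data flow (the shapes are illustrative):

import torch
from resnet.bottleneck import bottleneck

# With inplanes == planes * 4 and stride == 1, no downsample is needed.
block = bottleneck(inplanes=256, planes=64)

x = torch.rand(2, 256, 56, 56)
# Inside: Twin -> (x, x); each Gutter transforms the first element only;
# Residual adds the carried identity back to a single tensor.
y = block(x)
print(y.shape)  # torch.Size([2, 256, 56, 56])
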
22 changes: 22 additions & 0 deletions examples/resnet/flatten.py
@@ -0,0 +1,22 @@
from collections import OrderedDict
from typing import Iterator, Tuple

from torch import nn


def flatten(module: nn.Sequential) -> nn.Sequential:
    """Flattens a nested sequential module."""
    if not isinstance(module, nn.Sequential):
        raise TypeError('not sequential')

    return nn.Sequential(OrderedDict(_flatten(module)))


def _flatten(module: nn.Sequential) -> Iterator[Tuple[str, nn.Module]]:
    for name, child in module.named_children():
        # Flatten child sequential layers only.
        if isinstance(child, nn.Sequential):
            for sub_name, sub_child in _flatten(child):
                yield ('%s_%s' % (name, sub_name), sub_child)
        else:
            yield (name, child)
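
A quick illustration of what ``flatten`` does to nested child names (the toy layers below are arbitrary):

from collections import OrderedDict
import torch.nn as nn
from resnet.flatten import flatten

nested = nn.Sequential(OrderedDict([
    ('head', nn.Sequential(OrderedDict([
        ('conv', nn.Conv2d(3, 8, 3)),
        ('relu', nn.ReLU()),
    ]))),
    ('fc', nn.Linear(8, 1)),
]))

flat = flatten(nested)
print([name for name, _ in flat.named_children()])
# ['head_conv', 'head_relu', 'fc']
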
66 changes: 65 additions & 1 deletion stubs/torch/nn/__init__.pyi
@@ -1,5 +1,5 @@
#MODIFIED BY TORCHGPIPE
-from typing import Any, Callable, Iterable, Iterator, Optional, Tuple, TypeVar, Union
+from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Tuple, TypeVar, Union, overload

from torch import Tensor, device

@@ -45,6 +45,10 @@ class Module:


class Sequential(Module):
    @overload
    def __init__(self, args: Dict[str, Module]) -> None: ...

    @overload
    def __init__(self, *args: Module) -> None: ...

    def __iter__(self) -> Iterator[Module]: ...
@@ -61,4 +65,64 @@ class ModuleList(Module):
    def __len__(self) -> int: ...
    def __getitem__(self, index: int) -> Module: ...


class Linear(Module):
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 ) -> None: ...


class Conv2d(Module):
    in_channels: int
    out_channels: int
    kernel_size: Union[int, Tuple[int, ...]]

    weight: Tensor
    bias: Tensor

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, ...]],
                 stride: int = 1,
                 padding: int = 0,
                 dilation: int = 1,
                 groups: int = 1,
                 bias: bool = True,
                 padding_mode: str = 'zeros',
                 ) -> None: ...


class BatchNorm2d(Module):
    weight: Tensor
    bias: Tensor

    def __init__(self,
                 num_features: int,
                 eps: float = 1e-5,
                 momentum: Optional[float] = 0.1,
                 affine: bool = True,
                 track_running_stats: bool = True,
                 ) -> None: ...


class MaxPool2d(Module):
    def __init__(self,
                 kernel_size: Union[int, Tuple[int, ...]],
                 stride: Optional[int] = None,
                 padding: int = 0,
                 dilation: int = 1,
                 return_indices: bool = False,
                 ceil_mode: bool = False) -> None: ...


class AdaptiveAvgPool2d(Module):
    def __init__(self, output_size: Union[int, Tuple[int, ...]]) -> None: ...


class ReLU(Module):
    def __init__(self, inplace: bool = False) -> None: ...

#END
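
These stubs only need to be rich enough for mypy to accept the example code; in particular, the new ``Dict`` overload lets the ``OrderedDict`` construction in bottleneck.py type-check. A tiny sketch of both accepted call forms (purely illustrative):

from collections import OrderedDict
import torch.nn as nn

# Positional child modules:
seq = nn.Sequential(nn.ReLU(), nn.ReLU())

# Named child modules via an ordered mapping, matching the new overload:
named = nn.Sequential(OrderedDict([('relu', nn.ReLU())]))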
