Skip to content

Commit

Permalink
Support PyTorch-1.2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
sublee authored and GitHub Enterprise committed Aug 16, 2019
1 parent df681f9 commit c724226
Show file tree
Hide file tree
Showing 41 changed files with 2,543 additions and 576 deletions.
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ python:

# Supported PyTorch versions
env:
- PYTORCH=1.2.0
- PYTORCH=1.1.0
- PYTORCH=1.0.1
- PYTORCH=1.0.0
Expand Down
20 changes: 9 additions & 11 deletions examples/amoebanet/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""An AmoebaNet-D implementation but using :class:`nn.Sequential`. :func:`amoebanetd`
returns a :class:`nn.Sequential`.
"""
from collections import OrderedDict
from typing import Tuple
Expand Down Expand Up @@ -32,15 +31,16 @@ class ReLUConvBN(nn.Module):
def __init__(self, in_planes: int, out_planes: int, kernel_size: int, stride: int,
padding: int, affine: bool = True):
super().__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(in_planes, out_planes, kernel_size,
stride=stride, padding=padding, bias=False),
nn.BatchNorm2d(out_planes, affine=affine)
)
self.relu = nn.ReLU(inplace=False)
self.conv = nn.Conv2d(in_planes, out_planes, kernel_size,
stride=stride, padding=padding, bias=False)
self.bn = nn.BatchNorm2d(out_planes, affine=affine)

def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore
return self.op(x)
x = self.relu(x)
x = self.conv(x)
x = self.bn(x)
return x


class FactorizedReduce(nn.Module):
Expand Down Expand Up @@ -141,14 +141,12 @@ class Classifier(nn.Module):

def __init__(self, channel_prev: int, num_classes: int):
super().__init__()

self.global_pooling = nn.AvgPool2d(7)
self.classifier = nn.Linear(channel_prev, num_classes)

def forward(self, x: torch.Tensor) -> nn.Linear: # type: ignore
def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore
s1 = self.global_pooling(x[1])
y = self.classifier(s1.view(s1.size(0), -1))

return y


Expand Down
9 changes: 7 additions & 2 deletions examples/resnet/bottleneck.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""A ResNet bottleneck implementation but using :class:`nn.Sequential`."""
from collections import OrderedDict
from typing import Dict, Optional, Tuple, Union
from typing import TYPE_CHECKING, Optional, Tuple, Union

from torch import Tensor
import torch.nn as nn
Expand All @@ -10,6 +10,11 @@
Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]

if TYPE_CHECKING:
NamedModules = OrderedDict[str, nn.Module]
else:
NamedModules = OrderedDict


def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
"""3x3 convolution with padding"""
Expand Down Expand Up @@ -72,7 +77,7 @@ def bottleneck(inplanes: int,
downsample: Optional[nn.Module] = None,
) -> nn.Sequential:
"""Creates a bottlenect block in ResNet as a :class:`nn.Sequential`."""
layers: Dict[str, nn.Module] = OrderedDict()
layers: NamedModules = OrderedDict()
layers['twin'] = Twin()

layers['conv1'] = Gutter(conv1x1(inplanes, planes))
Expand Down

0 comments on commit c724226

Please sign in to comment.