Skip to content

Commit

Permalink
experimental xresnet changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeremy Howard committed Apr 6, 2019
1 parent c136e27 commit d3cc045
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 125 deletions.
3 changes: 2 additions & 1 deletion examples/train_imagenette.py
Expand Up @@ -49,10 +49,11 @@ def main(
bs_rat = bs/256
lr *= bs_rat

learn = (Learner(data, models.xresnet50(),
learn = (Learner(data, models.resnet50(),
metrics=[accuracy,top_k_accuracy], wd=1e-3, opt_func=opt_func,
bn_wd=False, true_wd=True, loss_func = LabelSmoothingCrossEntropy())
)
#print(learn.model); exit()
if mixup: learn = learn.mixup(alpha=0.2)
learn = learn.to_fp16(dynamic=True)
if gpu is None: learn.to_parallel()
Expand Down
6 changes: 3 additions & 3 deletions examples/train_imagenette_adv.py
Expand Up @@ -65,7 +65,7 @@ def bn_and_final(m):

def on_step(self, p, group, group_idx):
st = self.state[p]
alpha = (st['alpha_buffer'].sqrt() + group['eps']
alpha = ((st['alpha_buffer'] + group['eps']).sqrt()
) if 'alpha_buffer' in st else mom.new_tensor(1.)
clip = group['clip'] if 'clip' in group else 1e9
alr = (st['alpha_buffer']).clamp_min_(clip)
Expand All @@ -79,12 +79,12 @@ def main(
debias_mom: Param("Debias statistics", bool)=False,
debias_sqr: Param("Debias statistics", bool)=False,
opt: Param("Optimizer: 'adam','genopt','rms','sgd'", str)='genopt',
alpha: Param("Alpha", float)=0.9,
alpha: Param("Alpha", float)=0.99,
mom: Param("Momentum", float)=0.9,
eps: Param("epsilon", float)=1e-7,
decay: Param("Decay AvgStatistic (momentum)", bool)=False,
epochs: Param("Number of epochs", int)=5,
bs: Param("Batch size", int)=256,
bs: Param("Batch size", int)=128,
):
"""Distributed training of Imagenette.
Fastest multi-gpu speed is if you run with: python -m fastai.launch"""
Expand Down
192 changes: 71 additions & 121 deletions fastai/vision/models/xresnet.py
Expand Up @@ -3,128 +3,110 @@
import math
import torch.utils.model_zoo as model_zoo


__all__ = ['XResNet', 'xresnet18', 'xresnet34', 'xresnet50', 'xresnet101', 'xresnet152']

def init_cnn(m):
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
if getattr(m, 'bias', None) is not None: nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
for l in m.children(): init_cnn(l)

def conv(ni, nf, ks=3, stride=1, bias=False):
return nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, padding=ks//2, bias=bias)

def conv_relu_bn_(ni, nf, ks=3, stride=1, rev=False):
layers = [conv(ni, nf, ks, stride=stride),
nn.ReLU(inplace=True),
nn.BatchNorm2d(ni if rev else nf)]
if rev: layers = reversed(layers)
return layers

def conv3x3(in_planes, out_planes, stride=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
def conv_bn_relu(ni, nf, ks=3, stride=1):
return nn.Sequential(*conv_relu_bn_(ni, nf, ks=ks, stride=stride))

def bn_relu_conv(ni, nf, ks=3, stride=1):
return nn.Sequential(*conv_relu_bn_(ni, nf, ks=ks, stride=stride, rev=True))

class BasicBlock(nn.Module):
expansion = 1

def __init__(self, inplanes, planes, stride=1, downsample=None):
def __init__(self, ni, nf, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.conv1 = bn_relu_conv(ni, nf, stride=stride)
self.conv2 = bn_relu_conv(nf, nf)
self.downsample = downsample
self.stride = stride

def forward(self, x):
residual = x

out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)

out = self.conv2(out)
out = self.bn2(out)

if self.downsample is not None: residual = self.downsample(x)

out += residual
out = self.relu(out)

return out

identity = x if self.downsample is None else self.downsample(x)
x = self.conv1(x)
x = self.conv2(x)
x += identity
return x

class Bottleneck(nn.Module):
expansion = 4

def __init__(self, inplanes, planes, stride=1, downsample=None):
def __init__(self, ni, nf, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.conv1 = bn_relu_conv(ni, nf, 1)
self.conv2 = bn_relu_conv(nf, nf, stride=stride)
self.conv3 = bn_relu_conv(nf, nf * self.expansion, 1)
self.downsample = downsample
self.stride = stride

def forward(self, x):
residual = x

out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)

out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)

out = self.conv3(out)
out = self.bn3(out)

if self.downsample is not None: residual = self.downsample(x)

out += residual
out = self.relu(out)

return out

def conv2d(ni, nf, stride):
return nn.Sequential(nn.Conv2d(ni, nf, kernel_size=3, stride=stride, padding=1, bias=False),
nn.BatchNorm2d(nf), nn.ReLU(inplace=True))
identity = x if self.downsample is None else self.downsample(x)
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x += identity
return x

class XResNet(nn.Module):

def __init__(self, block, layers, num_classes=1000):
self.inplanes = 64
super(XResNet, self).__init__()
self.conv1 = conv2d(3, 32, 2)
self.conv2 = conv2d(32, 32, 1)
self.conv3 = conv2d(32, 64, 1)
self.ni = 64
super().__init__()
self.conv1 = conv_bn_relu(3, 32, stride=2)
self.conv2 = conv_bn_relu(32, 32)
self.conv3 = conv(32, 64)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
ni = 512*block.expansion
self.avgpool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(512 * block.expansion, num_classes)
self.fc = nn.Sequential(
nn.ReLU(inplace=True),
nn.BatchNorm1d(ni),
nn.Linear(ni, num_classes))

for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
init_cnn(self)

for m in self.modules():
if isinstance(m, BasicBlock): m.bn2.weight = nn.Parameter(torch.zeros_like(m.bn2.weight))
if isinstance(m, Bottleneck): m.bn3.weight = nn.Parameter(torch.zeros_like(m.bn3.weight))
if isinstance(m, BasicBlock): nn.init.constant_(m.conv2[0].weight, 0.)
if isinstance(m, Bottleneck): nn.init.constant_(m.conv3[0].weight, 0.)
if isinstance(m, nn.Linear): m.weight.data.normal_(0, 0.01)

def _make_layer(self, block, planes, blocks, stride=1):
def _make_layer(self, block, nf, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
if stride != 1 or self.ni != nf*block.expansion:
layers = []
if stride==2: layers.append(nn.AvgPool2d(kernel_size=2, stride=2))
layers += [
nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=1, bias=False),
nn.BatchNorm2d(planes * block.expansion) ]
conv(self.ni, nf*block.expansion, 1),
nn.BatchNorm2d(nf * block.expansion) ]
downsample = nn.Sequential(*layers)

layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks): layers.append(block(self.inplanes, planes))
layers.append(block(self.ni, nf, stride, downsample))
self.ni = nf * block.expansion
for i in range(1, blocks): layers.append(block(self.ni, nf))
return nn.Sequential(*layers)

def forward(self, x):
Expand All @@ -144,58 +126,26 @@ def forward(self, x):

return x

model_urls = dict(xresnet34='xresnet34', xresnet50='xresnet50')

def xresnet18(pretrained=False, **kwargs):
"""Constructs a XResNet-18 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = XResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['xresnet18']))
def xresnet(block, n_layers, name, pre=False, **kwargs):
model = XResNet(block, n_layers, **kwargs)
#if pre: model.load_state_dict(model_zoo.load_url(model_urls[name]))
if pre: model.load_state_dict(torch.load(model_urls[name]))
return model

def xresnet18(pretrained=False, **kwargs):
return xresnet(BasicBlock, [2, 2, 2, 2], 'xresnet18', pre=pretrained, **kwargs)

def xresnet34(pretrained=False, **kwargs):
"""Constructs a XResNet-34 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = XResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['xresnet34']))
return model

return xresnet(BasicBlock, [3, 4, 6, 3], 'xresnet34', pre=pretrained, **kwargs)

def xresnet50(pretrained=False, **kwargs):
"""Constructs a XResNet-50 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = XResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['xresnet50']))
return model

return xresnet(Bottleneck, [3, 4, 6, 3], 'xresnet50', pre=pretrained, **kwargs)

def xresnet101(pretrained=False, **kwargs):
"""Constructs a XResNet-101 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = XResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['xresnet101']))
return model

return xresnet(Bottleneck, [3, 4, 23, 3], 'xresnet101', pre=pretrained, **kwargs)

def xresnet152(pretrained=False, **kwargs):
"""Constructs a XResNet-152 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = XResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['xresnet152']))
return model
return xresnet(Bottleneck, [3, 8, 36, 3], 'xresnet152', pre=pretrained, **kwargs)

0 comments on commit d3cc045

Please sign in to comment.