-
Notifications
You must be signed in to change notification settings - Fork 573
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
shaoniangu
committed
Jul 24, 2019
1 parent
4bf49c2
commit 618d39b
Showing
6 changed files
with
203 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# Experiment all tricks with center loss : 256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on-triplet_centerloss0_0005 | ||
# Dataset 1: market1501 | ||
# imagesize: 256x128 | ||
# batchsize: 16x4 | ||
# warmup_step 10 | ||
# random erase prob 0.5 | ||
# labelsmooth: on | ||
# last stride 1 | ||
# bnneck on | ||
# with center loss | ||
python3 tools/train.py --config_file='configs/softmax_triplet_with_center.yml' MODEL.DEVICE_ID "('1')" MODEL.NAME "('resnet50_ibn_a')" MODEL.PRETRAIN_PATH "('/home/haoluo/gu/ibn_a.pth')" DATASETS.NAMES "('market1501')" DATASETS.ROOT_DIR "('/home/haoluo/data')" OUTPUT_DIR "('/home/haoluo/log/gu/reid_baseline_review/Opensource_test/market1501/Experiment-resnet50_ibn_a-all-tricks-tri_center-256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on-triplet_centerloss0_0005')" |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
import torch | ||
import torch.nn as nn | ||
import math | ||
import torch.utils.model_zoo as model_zoo | ||
|
||
|
||
__all__ = ['ResNet_IBN', 'resnet50_ibn_a', 'resnet101_ibn_a', | ||
'resnet152_ibn_a'] | ||
|
||
|
||
model_urls = { | ||
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', | ||
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', | ||
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', | ||
} | ||
|
||
|
||
class IBN(nn.Module): | ||
def __init__(self, planes): | ||
super(IBN, self).__init__() | ||
half1 = int(planes/2) | ||
self.half = half1 | ||
half2 = planes - half1 | ||
self.IN = nn.InstanceNorm2d(half1, affine=True) | ||
self.BN = nn.BatchNorm2d(half2) | ||
|
||
def forward(self, x): | ||
split = torch.split(x, self.half, 1) | ||
out1 = self.IN(split[0].contiguous()) | ||
out2 = self.BN(split[1].contiguous()) | ||
out = torch.cat((out1, out2), 1) | ||
return out | ||
|
||
|
||
class Bottleneck_IBN(nn.Module): | ||
expansion = 4 | ||
|
||
def __init__(self, inplanes, planes, ibn=False, stride=1, downsample=None): | ||
super(Bottleneck_IBN, self).__init__() | ||
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) | ||
if ibn: | ||
self.bn1 = IBN(planes) | ||
else: | ||
self.bn1 = nn.BatchNorm2d(planes) | ||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, | ||
padding=1, bias=False) | ||
self.bn2 = nn.BatchNorm2d(planes) | ||
self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) | ||
self.bn3 = nn.BatchNorm2d(planes * self.expansion) | ||
self.relu = nn.ReLU(inplace=True) | ||
self.downsample = downsample | ||
self.stride = stride | ||
|
||
def forward(self, x): | ||
residual = x | ||
|
||
out = self.conv1(x) | ||
out = self.bn1(out) | ||
out = self.relu(out) | ||
|
||
out = self.conv2(out) | ||
out = self.bn2(out) | ||
out = self.relu(out) | ||
|
||
out = self.conv3(out) | ||
out = self.bn3(out) | ||
|
||
if self.downsample is not None: | ||
residual = self.downsample(x) | ||
|
||
out += residual | ||
out = self.relu(out) | ||
|
||
return out | ||
|
||
|
||
class ResNet_IBN(nn.Module): | ||
|
||
def __init__(self, last_stride, block, layers, num_classes=1000): | ||
scale = 64 | ||
self.inplanes = scale | ||
super(ResNet_IBN, self).__init__() | ||
self.conv1 = nn.Conv2d(3, scale, kernel_size=7, stride=2, padding=3, | ||
bias=False) | ||
self.bn1 = nn.BatchNorm2d(scale) | ||
self.relu = nn.ReLU(inplace=True) | ||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | ||
self.layer1 = self._make_layer(block, scale, layers[0]) | ||
self.layer2 = self._make_layer(block, scale*2, layers[1], stride=2) | ||
self.layer3 = self._make_layer(block, scale*4, layers[2], stride=2) | ||
self.layer4 = self._make_layer(block, scale*8, layers[3], stride=last_stride) | ||
self.avgpool = nn.AvgPool2d(7) | ||
self.fc = nn.Linear(scale * 8 * block.expansion, num_classes) | ||
|
||
for m in self.modules(): | ||
if isinstance(m, nn.Conv2d): | ||
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels | ||
m.weight.data.normal_(0, math.sqrt(2. / n)) | ||
elif isinstance(m, nn.BatchNorm2d): | ||
m.weight.data.fill_(1) | ||
m.bias.data.zero_() | ||
elif isinstance(m, nn.InstanceNorm2d): | ||
m.weight.data.fill_(1) | ||
m.bias.data.zero_() | ||
|
||
def _make_layer(self, block, planes, blocks, stride=1): | ||
downsample = None | ||
if stride != 1 or self.inplanes != planes * block.expansion: | ||
downsample = nn.Sequential( | ||
nn.Conv2d(self.inplanes, planes * block.expansion, | ||
kernel_size=1, stride=stride, bias=False), | ||
nn.BatchNorm2d(planes * block.expansion), | ||
) | ||
|
||
layers = [] | ||
ibn = True | ||
if planes == 512: | ||
ibn = False | ||
layers.append(block(self.inplanes, planes, ibn, stride, downsample)) | ||
self.inplanes = planes * block.expansion | ||
for i in range(1, blocks): | ||
layers.append(block(self.inplanes, planes, ibn)) | ||
|
||
return nn.Sequential(*layers) | ||
|
||
def forward(self, x): | ||
x = self.conv1(x) | ||
x = self.bn1(x) | ||
x = self.relu(x) | ||
x = self.maxpool(x) | ||
|
||
x = self.layer1(x) | ||
x = self.layer2(x) | ||
x = self.layer3(x) | ||
x = self.layer4(x) | ||
|
||
# x = self.avgpool(x) | ||
# x = x.view(x.size(0), -1) | ||
# x = self.fc(x) | ||
|
||
return x | ||
|
||
def load_param(self, model_path): | ||
param_dict = torch.load(model_path) | ||
for i in param_dict: | ||
if 'fc' in i: | ||
continue | ||
self.state_dict()[i].copy_(param_dict[i]) | ||
|
||
|
||
def resnet50_ibn_a(last_stride, pretrained=False, **kwargs): | ||
"""Constructs a ResNet-50 model. | ||
Args: | ||
pretrained (bool): If True, returns a model pre-trained on ImageNet | ||
""" | ||
model = ResNet_IBN(last_stride, Bottleneck_IBN, [3, 4, 6, 3], **kwargs) | ||
if pretrained: | ||
model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) | ||
return model | ||
|
||
|
||
def resnet101_ibn_a(last_stride, pretrained=False, **kwargs): | ||
"""Constructs a ResNet-101 model. | ||
Args: | ||
pretrained (bool): If True, returns a model pre-trained on ImageNet | ||
""" | ||
model = ResNet_IBN(last_stride, Bottleneck_IBN, [3, 4, 23, 3], **kwargs) | ||
if pretrained: | ||
model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) | ||
return model | ||
|
||
|
||
def resnet152_ibn_a(last_stride, pretrained=False, **kwargs): | ||
"""Constructs a ResNet-152 model. | ||
Args: | ||
pretrained (bool): If True, returns a model pre-trained on ImageNet | ||
""" | ||
model = ResNet_IBN(last_stride, Bottleneck_IBN, [3, 8, 36, 3], **kwargs) | ||
if pretrained: | ||
model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) | ||
return model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters