Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add: fine tune using alexnet and resnet
- Loading branch information
1 parent
4853242
commit dd942b1
Showing
11 changed files
with
527 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import torch.nn as nn | ||
import torch.utils.model_zoo as model_zoo | ||
import torch | ||
|
||
__all__ = ['AlexNet', 'alexnet'] | ||
|
||
model_urls = { | ||
'alexnet': 'http://download.pytorch.org/models/alexnet-owt-4df8aa71.pth', | ||
} | ||
|
||
|
||
class AlexNet(nn.Module): | ||
|
||
def __init__(self, num_classes=1000): | ||
super(AlexNet, self).__init__() | ||
self.features = nn.Sequential( | ||
nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), | ||
nn.ReLU(inplace=True), | ||
nn.MaxPool2d(kernel_size=3, stride=2), | ||
nn.Conv2d(64, 192, kernel_size=5, padding=2), | ||
nn.ReLU(inplace=True), | ||
nn.MaxPool2d(kernel_size=3, stride=2), | ||
nn.Conv2d(192, 384, kernel_size=3, padding=1), | ||
nn.ReLU(inplace=True), | ||
nn.Conv2d(384, 256, kernel_size=3, padding=1), | ||
nn.ReLU(inplace=True), | ||
nn.Conv2d(256, 256, kernel_size=3, padding=1), | ||
nn.ReLU(inplace=True), | ||
nn.MaxPool2d(kernel_size=3, stride=2), | ||
) | ||
self.classifier = nn.Sequential( | ||
nn.Dropout(), | ||
nn.Linear(256 * 6 * 6, 4096), | ||
nn.ReLU(inplace=True), | ||
nn.Dropout(), | ||
nn.Linear(4096, 4096), | ||
nn.ReLU(inplace=True), | ||
nn.Linear(4096, num_classes), | ||
) | ||
|
||
def forward(self, x): | ||
x = self.features(x) | ||
x = x.view(x.size(0), 256 * 6 * 6) | ||
x = self.classifier(x) | ||
return x | ||
|
||
|
||
class AlexNetFc(nn.Module): | ||
def __init__(self, pretrained=False, num_classes=1000): | ||
super(AlexNetFc, self).__init__() | ||
model_alexnet = alexnet(pretrained=pretrained) | ||
self.features = model_alexnet.features | ||
self.classifier = nn.Sequential() | ||
for i in range(6): | ||
self.classifier.add_module("classifier" + str(i), model_alexnet.classifier[i]) | ||
self.__in_features = model_alexnet.classifier[6].in_features | ||
self.nfc = nn.Linear(4096, num_classes) | ||
|
||
def forward(self, x): | ||
x = self.features(x) | ||
x = x.view(x.size(0), 256 * 6 * 6) | ||
x = self.classifier(x) | ||
x = self.nfc(x) | ||
return x | ||
|
||
def output_num(self): | ||
return self.__in_features | ||
|
||
|
||
def alexnet(pretrained=False, **kwargs): | ||
r"""AlexNet model architecture from the | ||
`"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper. | ||
Args: | ||
pretrained (bool): If True, returns a model pre-trained on ImageNet | ||
""" | ||
model = AlexNet(**kwargs) | ||
if pretrained: | ||
model.load_state_dict(torch.load('alexnet-owt-4df8aa71.pth')) | ||
return model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,237 @@ | ||
import torch.nn as nn | ||
import math | ||
import torch.utils.model_zoo as model_zoo | ||
import torch | ||
|
||
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', | ||
'resnet152'] | ||
|
||
model_urls = { | ||
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', | ||
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', | ||
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', | ||
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', | ||
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', | ||
} | ||
|
||
|
||
def conv3x3(in_planes, out_planes, stride=1): | ||
"""3x3 convolution with padding""" | ||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, | ||
padding=1, bias=False) | ||
|
||
|
||
class BasicBlock(nn.Module): | ||
expansion = 1 | ||
|
||
def __init__(self, inplanes, planes, stride=1, downsample=None): | ||
super(BasicBlock, self).__init__() | ||
self.conv1 = conv3x3(inplanes, planes, stride) | ||
self.bn1 = nn.BatchNorm2d(planes) | ||
self.relu = nn.ReLU(inplace=True) | ||
self.conv2 = conv3x3(planes, planes) | ||
self.bn2 = nn.BatchNorm2d(planes) | ||
self.downsample = downsample | ||
self.stride = stride | ||
|
||
def forward(self, x): | ||
residual = x | ||
|
||
out = self.conv1(x) | ||
out = self.bn1(out) | ||
out = self.relu(out) | ||
|
||
out = self.conv2(out) | ||
out = self.bn2(out) | ||
|
||
if self.downsample is not None: | ||
residual = self.downsample(x) | ||
|
||
out += residual | ||
out = self.relu(out) | ||
|
||
return out | ||
|
||
|
||
class Bottleneck(nn.Module): | ||
expansion = 4 | ||
|
||
def __init__(self, inplanes, planes, stride=1, downsample=None): | ||
super(Bottleneck, self).__init__() | ||
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) | ||
self.bn1 = nn.BatchNorm2d(planes) | ||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, | ||
padding=1, bias=False) | ||
self.bn2 = nn.BatchNorm2d(planes) | ||
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) | ||
self.bn3 = nn.BatchNorm2d(planes * 4) | ||
self.relu = nn.ReLU(inplace=True) | ||
self.downsample = downsample | ||
self.stride = stride | ||
|
||
def forward(self, x): | ||
residual = x | ||
|
||
out = self.conv1(x) | ||
out = self.bn1(out) | ||
out = self.relu(out) | ||
|
||
out = self.conv2(out) | ||
out = self.bn2(out) | ||
out = self.relu(out) | ||
|
||
out = self.conv3(out) | ||
out = self.bn3(out) | ||
|
||
if self.downsample is not None: | ||
residual = self.downsample(x) | ||
|
||
out += residual | ||
out = self.relu(out) | ||
|
||
return out | ||
|
||
|
||
class ResNet(nn.Module): | ||
|
||
def __init__(self, block, layers, num_classes=1000): | ||
self.inplanes = 64 | ||
super(ResNet, self).__init__() | ||
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, | ||
bias=False) | ||
self.bn1 = nn.BatchNorm2d(64) | ||
self.relu = nn.ReLU(inplace=True) | ||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | ||
self.layer1 = self._make_layer(block, 64, layers[0]) | ||
self.layer2 = self._make_layer(block, 128, layers[1], stride=2) | ||
self.layer3 = self._make_layer(block, 256, layers[2], stride=2) | ||
self.layer4 = self._make_layer(block, 512, layers[3], stride=2) | ||
self.avgpool = nn.AvgPool2d(7, stride=1) | ||
self.fc = nn.Linear(512 * block.expansion, num_classes) | ||
self.in_feature = 512 * block.expansion | ||
|
||
# for m in self.modules(): | ||
# if isinstance(m, nn.Conv2d): | ||
# n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels | ||
# m.weight.data.normal_(0, math.sqrt(2. / n)) | ||
# elif isinstance(m, nn.BatchNorm2d): | ||
# m.weight.data.fill_(1) | ||
# m.bias.data.zero_() | ||
|
||
def _make_layer(self, block, planes, blocks, stride=1): | ||
downsample = None | ||
if stride != 1 or self.inplanes != planes * block.expansion: | ||
downsample = nn.Sequential( | ||
nn.Conv2d(self.inplanes, planes * block.expansion, | ||
kernel_size=1, stride=stride, bias=False), | ||
nn.BatchNorm2d(planes * block.expansion), | ||
) | ||
|
||
layers = [] | ||
layers.append(block(self.inplanes, planes, stride, downsample)) | ||
self.inplanes = planes * block.expansion | ||
for i in range(1, blocks): | ||
layers.append(block(self.inplanes, planes)) | ||
|
||
return nn.Sequential(*layers) | ||
|
||
def forward(self, x): | ||
x = self.conv1(x) | ||
x = self.bn1(x) | ||
x = self.relu(x) | ||
x = self.maxpool(x) | ||
|
||
x = self.layer1(x) | ||
x = self.layer2(x) | ||
x = self.layer3(x) | ||
x = self.layer4(x) | ||
|
||
x = self.avgpool(x) | ||
x = x.view(x.size(0), -1) | ||
x = self.fc(x) | ||
|
||
return x | ||
|
||
|
||
def resnet18(pretrained=False, **kwargs): | ||
"""Constructs a ResNet-18 model. | ||
Args: | ||
pretrained (bool): If True, returns a model pre-trained on ImageNet | ||
""" | ||
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) | ||
if pretrained: | ||
model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) | ||
return model | ||
|
||
|
||
def resnet34(pretrained=False, **kwargs): | ||
"""Constructs a ResNet-34 model. | ||
Args: | ||
pretrained (bool): If True, returns a model pre-trained on ImageNet | ||
""" | ||
model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) | ||
if pretrained: | ||
model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) | ||
return model | ||
|
||
|
||
def resnet50(pretrained=False, **kwargs): | ||
"""Constructs a ResNet-50 model. | ||
Args: | ||
pretrained (bool): If True, returns a model pre-trained on ImageNet | ||
""" | ||
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) | ||
if pretrained: | ||
model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) | ||
return model | ||
|
||
|
||
def resnet101(pretrained=False, **kwargs): | ||
"""Constructs a ResNet-101 model. | ||
Args: | ||
pretrained (bool): If True, returns a model pre-trained on ImageNet | ||
""" | ||
model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) | ||
if pretrained: | ||
model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) | ||
return model | ||
|
||
|
||
def resnet152(pretrained=False, **kwargs): | ||
"""Constructs a ResNet-152 model. | ||
Args: | ||
pretrained (bool): If True, returns a model pre-trained on ImageNet | ||
""" | ||
model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) | ||
if pretrained: | ||
model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) | ||
return model | ||
|
||
|
||
class myresnet(nn.Module): | ||
def __init__(self, pretrained=True, num_classes=1000): | ||
super(myresnet,self).__init__() | ||
model_resnet = resnet50(pretrained=pretrained) | ||
self.features = nn.Sequential( | ||
model_resnet.conv1, | ||
model_resnet.bn1, | ||
model_resnet.relu, | ||
model_resnet.maxpool, | ||
model_resnet.layer1, | ||
model_resnet.layer2, | ||
model_resnet.layer3, | ||
model_resnet.layer4, | ||
model_resnet.avgpool | ||
) | ||
self.nfc = nn.Linear(model_resnet.in_feature,num_classes) | ||
|
||
def forward(self, x): | ||
x = self.features(x) | ||
x = x.view(x.size(0), -1) | ||
x = self.nfc(x) | ||
return x |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
Download the Office-31 dataset (raw images) and extract it into this directory. | ||
This directory should look like: | ||
data | ||
--OFFICE31 | ||
----amazon | ||
------class1 | ||
------class2 | ||
... | ||
----webcam | ||
------(same as amazon) | ||
----dslr | ||
------(same as amazon) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from torchvision import datasets, transforms | ||
import torch | ||
|
||
def load_training(root_path, dir, batch_size): | ||
transform = transforms.Compose( | ||
[transforms.Resize([256,256]), | ||
transforms.RandomResizedCrop(224), | ||
transforms.RandomHorizontalFlip(), | ||
transforms.ToTensor(), | ||
transforms.Normalize(mean=[0.485, 0.456, 0.406], | ||
std=[0.229, 0.224, 0.225]), | ||
]) | ||
data = datasets.ImageFolder(root=root_path + dir, transform=transform) | ||
train_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, drop_last=True,num_workers=4) | ||
return train_loader | ||
|
||
def load_testing(root_path, dir, batch_size): | ||
transform = transforms.Compose( | ||
[transforms.Resize([256, 256]), | ||
transforms.RandomResizedCrop(224), | ||
transforms.ToTensor(), | ||
transforms.Normalize(mean=[0.485, 0.456, 0.406], | ||
std=[0.229, 0.224, 0.225]), | ||
]) | ||
data = datasets.ImageFolder(root=root_path + dir, transform=transform) | ||
test_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=False,num_workers=4) | ||
return test_loader |
Oops, something went wrong.