In [9]:
import os
import sys
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm

# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("using {} device.".format(device))

using cuda:0 device.


In [11]:
# Dataset location: a "flower_data" directory under the current working
# directory. Its "train" and "val" subfolders are read by ImageFolder below.
data_root = os.getcwd()
image_path = os.path.join(data_root, "flower_data")
# ImageFolder needs a directory, so reject a missing path or a plain file up front.
assert os.path.isdir(image_path), "{} path does not exist".format(image_path)

In [12]:
batch_size = 16
# DataLoader worker count: capped by the CPU count, the batch size (0 workers,
# i.e. load in the main process, when batch_size <= 1) and a hard ceiling of 8.
# os.cpu_count() returns None when the count cannot be determined, which would
# make min() raise TypeError, so fall back to 1 in that case.
cpu_count = os.cpu_count() or 1
nw = min([cpu_count, batch_size if batch_size > 1 else 0, 8])
print('Using {} dataloader workers every process'.format(nw))

Using 8 dataloader workers every process


In [13]:
# Standard ImageNet per-channel mean/std, shared by both pipelines.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Training uses random crop + flip augmentation; validation is deterministic
# (resize the short side to 256, then take the central 224x224 crop).
train_pipeline = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
])
val_pipeline = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
])
data_transform = {"train": train_pipeline, "val": val_pipeline}

# One ImageFolder per split; the class labels come from the subfolder names.
train_dir = os.path.join(image_path, "train")
val_dir = os.path.join(image_path, "val")

train_dataset = datasets.ImageFolder(root=train_dir, transform=data_transform["train"])
train_num = len(train_dataset)

validate_dataset = datasets.ImageFolder(root=val_dir, transform=data_transform["val"])
val_num = len(validate_dataset)

# Both loaders share batch size and worker count; only training is shuffled.
loader_kwargs = dict(batch_size=batch_size, num_workers=nw)
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = torch.utils.data.DataLoader(validate_dataset, shuffle=False, **loader_kwargs)


In [14]:
# ImageFolder provides {class_name: index}; invert it to {index: class_name}.
flower_list = train_dataset.class_to_idx
cla_dict = {idx: name for name, idx in flower_list.items()}

# Persist the mapping to the working directory. JSON object keys are always
# strings, so the integer indices are written as "0", "1", ...
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as fh:
    fh.write(json_str)

In [34]:
class BasicBlock(nn.Module):
    """Residual block of two 3x3 convolutions (used by ``resnet34`` here).

    Args:
        in_channel: number of channels of the block input.
        out_channel: number of channels produced by both convolutions.
        stride: stride of the first convolution; values > 1 downsample spatially.
        downsample: optional module applied to the input so the shortcut matches
            the main branch's shape; ``None`` means an identity shortcut.
        **kwargs: ignored. Lets ``ResNet`` pass Bottleneck-only options
            (``groups``, ``width_per_group``) to every block type uniformly.
    """
    # Output channels = out_channel * expansion.
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super(BasicBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        # bn1 is used in forward(); it was previously never created.
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        # conv2 consumes conv1's output, so its input width is out_channel.
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        """conv-bn-relu -> conv-bn, add the (possibly projected) input, then relu."""
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out
    
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block (used by resnet50/101 and the resnext factories).

    Args:
        in_channel: number of channels of the block input.
        out_channel: base channel count; the block emits ``out_channel * expansion``.
        stride: stride of the 3x3 convolution; values > 1 downsample spatially.
        downsample: optional module applied to the input so the shortcut matches
            the main branch's shape; ``None`` means an identity shortcut.
        groups: number of groups in the 3x3 convolution (ResNeXt cardinality).
        width_per_group: per-group width; with the defaults (1, 64) the inner
            width equals ``out_channel``.
    """
    # Output channels = out_channel * expansion. ResNet reads this class
    # attribute (block.expansion) to size shortcuts and the final fc layer.
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None,
                 groups=1, width_per_group=64):
        super(Bottleneck, self).__init__()

        # Channel width of the inner 1x1/3x3 convolutions.
        width = int(out_channel * (width_per_group / 64.)) * groups

        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,
                               kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width)

        # conv2 consumes conv1's output, so its input width is `width`.
        self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(width)

        self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel * self.expansion,
                               kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channel * self.expansion)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        """Three conv-bn stages (relu after the first two), add shortcut, final relu."""
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)

        return out
    
class ResNet(nn.Module):
    """ResNet/ResNeXt backbone assembled from a residual block class.

    Args:
        block: residual block class exposing an ``expansion`` class attribute and
            accepting ``(in_channel, channel, downsample=, stride=, groups=,
            width_per_group=)``.
        blocks_num: four ints, the number of blocks in each stage.
        num_classes: output size of the final linear layer.
        include_top: when True, build and apply global average pooling + fc;
            when False, forward() returns the stage-4 feature map.
        groups, width_per_group: forwarded unchanged to every block.
    """

    def __init__(self, block, blocks_num, num_classes=1000, include_top=True, groups=1, width_per_group=64):
        super(ResNet, self).__init__()

        self.include_top = include_top
        # Running channel count of the next stage's input; _maker_layer advances it.
        self.in_channel = 64
        self.groups = groups
        self.width_per_group = width_per_group

        # Stem: 7x7 stride-2 convolution, BN, ReLU, 3x3 stride-2 max-pool.
        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Stages layer1..layer4; every stage after the first starts with stride 2.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for idx, (stage_channels, stage_stride) in enumerate(stage_specs, start=1):
            stage = self._maker_layer(block, stage_channels, blocks_num[idx - 1], stride=stage_stride)
            setattr(self, "layer{}".format(idx), stage)

        if include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        # He (fan_out) initialization for every convolution, in module order.
        conv_layers = (mod for mod in self.modules() if isinstance(mod, nn.Conv2d))
        for conv in conv_layers:
            nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')

    def _maker_layer(self, block, channel, block_num, stride=1):
        """Build one stage: a (possibly strided) first block plus block_num - 1 more.

        A 1x1 conv + BN projection shortcut is attached to the first block when
        the stride or the channel count changes. Updates ``self.in_channel``.
        """
        out_channels = channel * block.expansion
        shortcut = None
        if stride != 1 or self.in_channel != out_channels:
            shortcut = nn.Sequential(
                nn.Conv2d(self.in_channel, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels))

        head = block(self.in_channel, channel, downsample=shortcut, stride=stride,
                     groups=self.groups, width_per_group=self.width_per_group)
        self.in_channel = out_channels

        tail = [block(self.in_channel, channel, groups=self.groups,
                      width_per_group=self.width_per_group)
                for _ in range(block_num - 1)]
        return nn.Sequential(head, *tail)

    def forward(self, x):
        """Stem, four stages, then (if include_top) pool, flatten and classify."""
        pipeline = (self.conv1, self.bn1, self.relu, self.maxpool,
                    self.layer1, self.layer2, self.layer3, self.layer4)
        for module in pipeline:
            x = module(x)

        if not self.include_top:
            return x

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)
    
def resnet34(num_classes=1000, include_top=True):
    """ResNet-34: BasicBlock stages with 3, 4, 6 and 3 blocks.

    Matching torchvision checkpoint:
    https://download.pytorch.org/models/resnet34-333f7ec4.pth
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet(BasicBlock, stage_depths,
                  num_classes=num_classes, include_top=include_top)


def resnet50(num_classes=1000, include_top=True):
    """ResNet-50: Bottleneck stages with 3, 4, 6 and 3 blocks.

    Matching torchvision checkpoint:
    https://download.pytorch.org/models/resnet50-19c8e357.pth
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths,
                  num_classes=num_classes, include_top=include_top)


def resnet101(num_classes=1000, include_top=True):
    """ResNet-101: Bottleneck stages with 3, 4, 23 and 3 blocks.

    Matching torchvision checkpoint:
    https://download.pytorch.org/models/resnet101-5d3b4d8f.pth
    """
    stage_depths = [3, 4, 23, 3]
    return ResNet(Bottleneck, stage_depths,
                  num_classes=num_classes, include_top=include_top)


def resnext50_32x4d(num_classes=1000, include_top=True):
    """ResNeXt-50 (32x4d): Bottleneck stages 3, 4, 6, 3 with 32 groups of width 4.

    Matching torchvision checkpoint:
    https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth
    """
    return ResNet(Bottleneck, [3, 4, 6, 3],
                  num_classes=num_classes, include_top=include_top,
                  groups=32, width_per_group=4)


def resnext101_32x8d(num_classes=1000, include_top=True):
    """ResNeXt-101 (32x8d): Bottleneck stages 3, 4, 23, 3 with 32 groups of width 8.

    Matching torchvision checkpoint:
    https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth
    """
    return ResNet(Bottleneck, [3, 4, 23, 3],
                  num_classes=num_classes, include_top=include_top,
                  groups=32, width_per_group=8)

In [35]:
# Build ResNet-34 with its default 1000-way head so a 1000-class checkpoint
# can be loaded into it unchanged.
net = resnet34(num_classes=1000, include_top=True)

In [21]:
# Load the pretrained ResNet-34 weights from the working directory.
model_weight_path = './resnet34-pre.pth'
assert os.path.exists(model_weight_path), "file {} does not exist".format(model_weight_path)
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
net.load_state_dict(torch.load(model_weight_path, map_location='cpu'))

# Replace the 1000-way head with one sized to this dataset's class count
# (derived from the training folders rather than hard-coded).
in_channel = net.fc.in_features
net.fc = nn.Linear(in_channel, len(train_dataset.classes))
# Trailing ';' keeps Jupyter from printing the full model repr.
net.to(device);

AssertionError: file ./resnet34-pre.pth does not exist

In [None]:
# Multi-class classification loss on raw logits, averaged over the batch.
loss_function = nn.CrossEntropyLoss(reduction='mean')

In [None]:
# Hand Adam only the parameters that still require gradients, so any
# frozen (requires_grad=False) weights are skipped.
params = list(filter(lambda param: param.requires_grad, net.parameters()))
optimizer = optim.Adam(params, lr=1e-4)


In [None]:
epochs = 3
best_acc = 0.0
save_path = './resNet34.pth'
train_steps = len(train_loader)

# The loop variable must not shadow `epochs`; it is used in the progress text.
for epoch in range(epochs):
    # ---- training pass ----
    net.train()
    running_loss = 0.0
    train_bar = tqdm(train_loader, file=sys.stdout)
    for images, labels in train_bar:
        optimizer.zero_grad()
        logits = net(images.to(device))
        loss = loss_function(logits, labels.to(device))
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss.item())

    # ---- validation pass (no gradients) ----
    net.eval()
    acc = 0.0
    with torch.no_grad():
        val_bar = tqdm(val_loader, file=sys.stdout)
        for val_images, val_labels in val_bar:
            outputs = net(val_images.to(device))
            predict_y = torch.max(outputs, dim=1)[1]
            acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
            val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1, epochs)

    # Guard the averages so an empty split doesn't raise ZeroDivisionError.
    val_accurate = acc / val_num if val_num else 0.0
    print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
          (epoch + 1, running_loss / max(train_steps, 1), val_accurate))

    # Keep only the checkpoint with the best validation accuracy so far.
    if val_accurate > best_acc:
        best_acc = val_accurate
        torch.save(net.state_dict(), save_path)

print('Finished Training')