In [39]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.autograd as autograd
import torchvision.transforms as transforms
import torchvision.datasets as datasets

In [40]:
class LambdaLayer(nn.Module):
    def __init__(self, lambda_func):
        super(LambdaLayer, self).__init__()
        self.lambda_func = lambda_func

    def forward(self, x):
        return self.lambda_func(x)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option="A"):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.shortcut = nn.Sequential()

        if stride != 1 or in_planes != planes:
            if option == "A":
                # パッディングを使ってチャンネル数を合わせる
                self.shortcut = LambdaLayer(lambda x: F.pad(x[:, :, ::2, ::2],
                                                            (0, 0, 0, 0, planes//4, planes//4),
                                                            "constant",
                                                            0))
            elif option == "B":
                self.shortcut = nn.Sequential(
                    nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(self.expansion * planes)
                )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

In [41]:
class ResNet(nn.Module):
    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def _weights_init(self, m):
        if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight)

    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 16 # 入力のチャンネル数
        
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.linear = nn.Linear(64, num_classes)

        self.apply(self._weights_init)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)

        # out.size()[3] は画像の縦横のサイズで画像のサイズが1x1になる
        out = F.avg_pool2d(out, out.size()[3])

        # out.size(0) はバッチサイズ
        # 64チャンネルの1x1画像を64次元のベクトルに変換
        out = out.view(out.size(0), -1)

        # 64チャンネルを10分類に変換
        out = self.linear(out)
        return out

    def resnet20(): return ResNet(BasicBlock, [3, 3, 3])
    def resnet32(): return ResNet(BasicBlock, [5, 5, 5])
    def resnet44(): return ResNet(BasicBlock, [7, 7, 7])
    def resnet56(): return ResNet(BasicBlock, [9, 9, 9])
    def renet110(): return ResNet(BasicBlock, [18, 18, 18])
    def resnet1202(): return ResNet(BasicBlock, [200, 200, 200])



In [42]:
train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10("./data", train=True, transform=transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, 4),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.228, 0.224, 0.225]),
    ]), download=True),
    batch_size=128,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

val_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10("./data", train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.228, 0.224, 0.225]),
    ])),
    batch_size=128,
    num_workers=4,
    pin_memory=True
)

Files already downloaded and verified


In [43]:
class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [44]:
def accuracy(output, target, topk=(1,)):
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))

    return res

In [45]:
model = ResNet.resnet20().cuda()

criterion = nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150])

best_prec1 = 0

for epoch in range(1, 201):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    model.train()
    for i, (input, target) in enumerate(train_loader):
        target = target.cuda()
        input_var = input.cuda()
        target_var = target.cuda()

        output = model(input_var)
        loss = criterion(output, target_var)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        output = output.float()
        loss = loss.float()

        prec1 = accuracy(output.data, target)[0]
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))

        if i % 100 == 0:
            print(f"[{i}/{len(train_loader)}] 損失: {losses.avg:.4f} 精度: {top1.avg:.4f}%")

    print(f"訓練 エポック: {epoch} 損失: {losses.avg:.4f} 精度: {top1.avg}%")

    lr_scheduler.step()

    model.eval()
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
            target = target.cuda()
            input_var = input.cuda()
            target_var = target.cuda()

            output = model(input_var)
            loss = criterion(output, target_var)

            output = output.float()
            loss = loss.float()

            prec1 = accuracy(output.data, target)[0]
            losses.update(loss.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))

    print(f"検証 エポック: {epoch} 損失: {losses.avg:.4f} 精度: {top1.avg:.4f}%")

    prec1 = top1.avg
    if prec1 > best_prec1:
        print(f"精度が向上したためモデルを保存します: {best_prec1:.4f}% -> {prec1:.4f}%")
        best_prec1 = prec1
        torch.save(model.state_dict(), "models/resnet.pth")

[0/391] 損失: 3.7490 精度: 6.2500%
[100/391] 損失: 2.0308 精度: 26.9802%
[200/391] 損失: 1.8262 精度: 32.7931%
[300/391] 損失: 1.6913 精度: 37.7206%
訓練 エポック: 1 損失: 1.6014 精度: 41.136%
検証 エポック: 1 損失: 1.5613 精度: 42.8733%
精度が向上したためモデルを保存します: 0.0000% -> 42.8733%
[0/391] 損失: 1.2872 精度: 53.9062%
[100/391] 損失: 1.1700 精度: 57.2324%
[200/391] 損失: 1.1319 精度: 58.7376%
[300/391] 損失: 1.1061 精度: 59.9642%
訓練 エポック: 2 損失: 1.0813 精度: 60.998%
検証 エポック: 2 損失: 1.1035 精度: 60.9133%
精度が向上したためモデルを保存します: 42.8733% -> 60.9133%
[0/391] 損失: 0.9096 精度: 67.1875%
[100/391] 損失: 0.9459 精度: 66.1278%
[200/391] 損失: 0.9308 精度: 66.8338%
[300/391] 損失: 0.9138 精度: 67.5405%
訓練 エポック: 3 損失: 0.8981 精度: 68.152%
検証 エポック: 3 損失: 0.9059 精度: 68.0917%
精度が向上したためモデルを保存します: 60.9133% -> 68.0917%
[0/391] 損失: 0.8848 精度: 67.9688%
[100/391] 損失: 0.8240 精度: 70.9081%
[200/391] 損失: 0.7995 精度: 71.8633%
[300/391] 損失: 0.7925 精度: 72.1839%
訓練 エポック: 4 損失: 0.7822 精度: 72.66%
検証 エポック: 4 損失: 0.8120 精度: 71.9333%
精度が向上したためモデルを保存します: 68.0917% -> 71.9333%
[0/391] 損失: 1.0116 精度: 69.5