CNNによる画像分類

In [1]:
import torch
from torch import nn, optim
from torch.utils.data import (Dataset,
                             DataLoader,
                             TensorDataset)
import tqdm

In [2]:
from torchvision.datasets import FashionMNIST
from torchvision import transforms

# 訓練用のデータを取得
# そのままだとPIL(Python Imaging Library)の画像形式でDatasetを作ってしまうのでtransform.toTensorでtensorに変換する
fashion_mnist_train = FashionMNIST("ch4data/FashionMNIST",
                                  train=True, download=True,
                                  transform=transforms.ToTensor())
fashion_mnist_test = FashionMNIST("ch4data/FashionMNIST",
                                  train=False, download=True,
                                  transform=transforms.ToTensor())

# バッチサイズが128のDataLoaderをそれぞれ作成
batch_size = 128
train_loader = DataLoader(fashion_mnist_train,
                          batch_size=batch_size, shuffle=True)
test_loader = DataLoader(fashion_mnist_test,
                        batch_size=batch_size, shuffle=False)

CNNの構築と学習

In [3]:
"""
2層の畳み込み層と2層のMLPをつなげたCNNを作成
"""
# 畳み込み層の出力をMLPに渡す際に必要
class FlattenLayer(nn.Module):
    def forward(self, x):
        sizes = x.size()
        return x.view(sizes[0], -1)

# 5x5のカーネルを使用し最初に32個、次に64個のチャンネルを作成する
# BatchNorm2dは画像形式用のBatch Normalization
# Dropout2dは画像形式用のDropout
# 最後にFlattenLayerを挟む
conv_net = nn.Sequential(
    nn.Conv2d(1, 32, 5),
    nn.MaxPool2d(2),
    nn.ReLU(),
    nn.BatchNorm2d(32),
    nn.Dropout2d(0.25),
    nn.Conv2d(32, 64, 5),
    nn.MaxPool2d(2),
    nn.ReLU(),
    nn.BatchNorm2d(64),
    nn.Dropout2d(0.25),
    FlattenLayer()
)

# 畳み込みによって最終的にどのようなサイズになっているかを、実際にダミーデータを入れてみて確認する
test_input = torch.ones(1, 1, 28, 28)
conv_output_size = conv_net(test_input).size()[-1]

# 2層のMLP
mlp = nn.Sequential(
    nn.Linear(conv_output_size, 200),
    nn.ReLU(),
    nn.BatchNorm1d(200),
    nn.Dropout(0.25),
    nn.Linear(200, 10)
)

# 最終的なCNN
net = nn.Sequential(
    conv_net,
    mlp
)

In [4]:
"""
評価と訓練のヘルパー関数を作成
"""
def eval_net(net, data_loader, device="cpu"):
    # DropoutやBatchNormを無効化
    net.eval()
    ys = []
    ypreds = []
    for x, y in data_loader:
        # toメソッドで計算を実行するデバイスに転送する
        x = x.to(device)
        y = y.to(device)
        # 確率が最大のクラスを予測
        # ここではforward(推論)の計算だけなので自動微分に必要な処理はoffにして余計な計算を省く
        with torch.no_grad():
            _, y_pred = net(x).max(1)
        ys.append(y)
        ypreds.append(y_pred)
        # ミニバッチごと予測結果などを1つにまとめる
        ys = torch.cat(ys)
        ypreds = torch.cat(ypreds)
        # 予測精度を計算
        acc = (ys == ypreds).float().sum() / len(ys)
        return acc.item()

def train_net(net, train_loader, test_loader, optimizer_cls=optim.Adam, loss_fn=nn.CrossEntropyLoss(), n_iter=10, device="cpu"):
    train_losses = []
    train_acc = []
    val_acc = []
    optimizer = optimizer_cls(net.parameters())
    for epoch in range(n_iter):
        running_loss = 0.0
        # ネットワークを訓練モードにする
        net.train()
        n = 0
        n_acc = 0
        # 非常に時間がかかるのでtqdmを使用してプログレスバーを出す
        for i, (xx, yy) in tqdm.tqdm(enumerate(train_loader),
                                total=len(train_loader)):
            xx = xx.to(device)
            yy = yy.to(device)
            h = net(xx)
            loss = loss_fn(h, yy)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            n += len(xx)
            _, y_pred = h.max(1)
            n_acc += (yy == y_pred).float().sum().item()
        train_losses.append(running_loss / i)
        # 訓練データの予測精度
        train_acc.append(n_acc / n)
        # 検証データの予測精度
        val_acc.append(eval_net(net, test_loader, device))
        # このepochでの結果を表示
        print(epoch, train_losses[-1], train_acc[-1], val_acc[-1], flush=True)

In [5]:
"""
全パラメータをGPUに転送して訓練を実行
"""
# ネットワークの全パラメータをGPUに転送
net.to("cuda:0")

# 訓練を実行
train_net(net, train_loader, test_loader, n_iter=20, device="cuda:0")

100%|██████████| 469/469 [00:10<00:00, 45.23it/s]

0 0.47064782612216777 0.8344 0.875



100%|██████████| 469/469 [00:09<00:00, 49.53it/s]

1 0.31681714394790494 0.8841166666666667 0.8828125



100%|██████████| 469/469 [00:09<00:00, 49.45it/s]

2 0.28175940294551033 0.8974333333333333 0.890625



100%|██████████| 469/469 [00:09<00:00, 49.61it/s]


3 0.2619724556421622 0.9038 0.890625


100%|██████████| 469/469 [00:09<00:00, 49.29it/s]

4 0.24339861164872462 0.9096166666666666 0.8828125



100%|██████████| 469/469 [00:09<00:00, 50.26it/s]

5 0.23181136737330857 0.9132166666666667 0.8984375



100%|██████████| 469/469 [00:09<00:00, 49.41it/s]


6 0.22108536813822058 0.9172333333333333 0.90625


100%|██████████| 469/469 [00:09<00:00, 49.45it/s]

7 0.21302990042246306 0.9199833333333334 0.921875



100%|██████████| 469/469 [00:09<00:00, 49.61it/s]

8 0.2043916278351576 0.92365 0.921875



100%|██████████| 469/469 [00:09<00:00, 49.78it/s]

9 0.19479505950187007 0.92795 0.9140625



100%|██████████| 469/469 [00:09<00:00, 50.56it/s]

10 0.1879618960376988 0.929 0.90625



100%|██████████| 469/469 [00:09<00:00, 49.65it/s]

11 0.18398999964070117 0.9298833333333333 0.890625



100%|██████████| 469/469 [00:09<00:00, 49.93it/s]

12 0.1781123783120997 0.9331166666666667 0.8984375



100%|██████████| 469/469 [00:09<00:00, 49.84it/s]


13 0.17488930761240995 0.9345166666666667 0.8984375


100%|██████████| 469/469 [00:09<00:00, 50.13it/s]

14 0.16761300081594122 0.9369166666666666 0.9140625



100%|██████████| 469/469 [00:09<00:00, 49.86it/s]

15 0.16348915565440542 0.9385 0.9140625



100%|██████████| 469/469 [00:09<00:00, 49.88it/s]

16 0.16008517254366833 0.9385166666666667 0.9296875



100%|██████████| 469/469 [00:09<00:00, 49.56it/s]

17 0.1557827016386466 0.9408333333333333 0.8984375



100%|██████████| 469/469 [00:09<00:00, 49.39it/s]


18 0.15423270310155857 0.94115 0.90625


100%|██████████| 469/469 [00:09<00:00, 49.43it/s]

19 0.15150482640561894 0.9430333333333333 0.90625





# 転移学習

データの準備

In [5]:
"""
DataLoaderを作成
"""
from torchvision.datasets import ImageFolder
from torchvision import transforms

# ImageFolder関数を使用してDatasetを作成する
train_imgs = ImageFolder(
    "ch4data/taco_and_burrito/train/",
    transform=transforms.Compose([
        transforms.RandomCrop(224),
        transforms.ToTensor()
    ])
)
test_imgs = ImageFolder(
    "ch4data/taco_and_burrito/test/",
    transform=transforms.Compose([
        transforms.RandomCrop(224),
        transforms.ToTensor()
    ])
)

# DataLoaderを作成
train_loader = DataLoader(
    train_imgs, batch_size=32, shuffle=True
)
test_loader = DataLoader(
    test_imgs, batch_size=32, shuffle=False
)

In [6]:
print(train_imgs.classes)

['burrito', 'taco']


In [7]:
print(train_imgs.class_to_idx)

{'burrito': 0, 'taco': 1}


PyTorchで転移学習

In [8]:
"""
事前学習済みモデルのロードと定義
"""
from torchvision import models

# 事前学習済みのresten18をロード
net = models.resnet18(pretrained=True)

# 全てのパラメータを微分対象外にする
for p in net.parameters():
    p.requires_grad = False

# 最後の線形層を付け替える
fc_input_dim = net.fc.in_features
net.fc = nn.Linear(fc_input_dim, 2)

In [9]:
"""
モデルの訓練関数の記述
"""
def eval_net(net, data_loader, device="cpu"):
    # DropoutやBatchNormを無効化
    net.eval()
    ys = []
    ypreds = []
    for x, y in data_loader:
        # toメソッドで計算を実行するデバイスに転送する
        x = x.to(device)
        y = y.to(device)
        # 確率が最大のクラスを予測
        # ここではforward(推論)の計算だけなので自動微分に必要な処理はoffにして余計な計算を省く
        with torch.no_grad():
            _, y_pred = net(x).max(1)
        ys.append(y)
        ypreds.append(y_pred)
    # ミニバッチごとの予測結果などを1つにまとめる
    ys = torch.cat(ys)
    ypreds = torch.cat(ypreds)
    # 予測精度を計算
    acc = (ys == ypreds).float().sum() / len(ys)
    return acc.item()

def train_net(net, train_loader, test_loader,
             only_fc=True,
             optimizer_cls=optim.Adam,
             loss_fn=nn.CrossEntropyLoss(),
             n_iter=10,
             device="cpu"):
    train_losses = []
    train_acc = []
    val_acc = []
    if only_fc:
        # 最後の線形層のパラメータのみをoptimizerに渡す
        optimizer = optimizer_cls(net.fc.parameters())
    else:
        optimizer = optimizer_cls(net.parameters())
    for epoch in range(n_iter):
        running_loss = 0.0
        # ネットワークを訓練モードにする
        net.train()
        n = 0
        n_acc = 0
        # 非常に時間がかかるのでtqdmを使用してプログレスバーをだす
        for i, (xx, yy) in tqdm.tqdm(enumerate(train_loader),
                                    total=len(train_loader)):
            xx = xx.to(device)
            yy = yy.to(device)
            h = net(xx)
            loss = loss_fn(h, yy)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            n += len(xx)
            _, y_pred = h.max(1)
            n_acc += (yy == y_pred).float().sum().item()
        train_losses.append(running_loss / i)
        # 訓練データの予測精度
        train_acc.append(n_acc / n)
        # 検証データの予測精度
        val_acc.append(eval_net(net, test_loader, device))
        # このepochでの結果を表示
        print(epoch, train_losses[-1], train_acc[-1], val_acc[-1], flush=True)

In [11]:
"""
全パラメータをGPUに転送
"""
net.to("cuda:0")

# 訓練を実行
train_net(net, train_loader, test_loader, n_iter=20, device="cuda:0")

100%|██████████| 23/23 [00:06<00:00,  3.59it/s]


0 0.7590137069875543 0.5435393258426966 0.6666666865348816


100%|██████████| 23/23 [00:03<00:00,  5.83it/s]


1 0.5655451159585606 0.7429775280898876 0.7500000596046448


100%|██████████| 23/23 [00:03<00:00,  5.90it/s]


2 0.5198699263009158 0.7865168539325843 0.8333333730697632


100%|██████████| 23/23 [00:03<00:00,  6.10it/s]


3 0.47599869153716345 0.8286516853932584 0.7833333611488342


100%|██████████| 23/23 [00:03<00:00,  6.04it/s]


4 0.4362903345714916 0.8117977528089888 0.8333333730697632


100%|██████████| 23/23 [00:03<00:00,  6.25it/s]


5 0.44816620647907257 0.8230337078651685 0.8500000238418579


100%|██████████| 23/23 [00:03<00:00,  6.25it/s]


6 0.37510093098337 0.8539325842696629 0.7333333492279053


100%|██████████| 23/23 [00:03<00:00,  6.23it/s]


7 0.37435702099041507 0.8525280898876404 0.8000000715255737


100%|██████████| 23/23 [00:03<00:00,  6.07it/s]


8 0.3973151452162049 0.8314606741573034 0.8666667342185974


100%|██████████| 23/23 [00:03<00:00,  6.24it/s]


9 0.358208029107614 0.8693820224719101 0.9000000357627869


100%|██████████| 23/23 [00:03<00:00,  6.24it/s]


10 0.3556080037219958 0.8525280898876404 0.8333333730697632


100%|██████████| 23/23 [00:03<00:00,  6.24it/s]


11 0.337276735089042 0.8665730337078652 0.8000000715255737


100%|██████████| 23/23 [00:03<00:00,  6.05it/s]


12 0.324612785469402 0.8778089887640449 0.8666667342185974


100%|██████████| 23/23 [00:03<00:00,  6.11it/s]


13 0.34486527470025147 0.8707865168539326 0.8833333849906921


100%|██████████| 23/23 [00:03<00:00,  5.94it/s]


14 0.29963084919886157 0.8946629213483146 0.8500000238418579


100%|██████████| 23/23 [00:03<00:00,  6.15it/s]


15 0.319154022092169 0.8792134831460674 0.8000000715255737


100%|██████████| 23/23 [00:03<00:00,  6.13it/s]


16 0.34083300761201163 0.8553370786516854 0.8333333730697632


100%|██████████| 23/23 [00:03<00:00,  6.26it/s]


17 0.33010095019232144 0.8735955056179775 0.8000000715255737


100%|██████████| 23/23 [00:03<00:00,  6.28it/s]


18 0.323605204847726 0.8707865168539326 0.8333333730697632


100%|██████████| 23/23 [00:03<00:00,  6.26it/s]


19 0.3170310732993213 0.875 0.8166667222976685


In [10]:
"""
入力をそのまま出力するダミーの層を作り、fcを置き換える
"""
class IdentityLayer(nn.Module):
    def forward(self, x):
        return x

net = models.resnet18(pretrained=True)
for p in net.parameters():
    p.requires_grad = False
net.fc = IdentityLayer()

In [13]:
"""
CNNモデル
"""
conv_net = nn.Sequential(
    nn.Conv2d(3, 32, 5),
    nn.MaxPool2d(2),
    nn.ReLU(),
    nn.BatchNorm2d(32),
    nn.Conv2d(32, 64, 5),
    nn.MaxPool2d(2),
    nn.ReLU(),
    nn.BatchNorm2d(64),
    nn.Conv2d(64, 128, 5),
    nn.MaxPool2d(2),
    nn.ReLU(),
    nn.BatchNorm2d(128),
    FlattenLayer()
)

# 畳み込みによって最終的にどのようなサイズになっているかを実際にデータを入れて確認する
test_input = torch.ones(1, 3, 224, 224)
conv_output_size = conv_net(test_input).size()[-1]

# 最終的なCNN
net = nn.Sequential(
    conv_net,
    nn.Linear(conv_output_size, 2)
)

# 訓練を実行
train_net(net, train_loader, test_loader, n_iter=10, only_fc=False)

100%|██████████| 23/23 [01:49<00:00,  3.80s/it]


0 1.9517546919259159 0.5926966292134831 0.5333333611488342


100%|██████████| 23/23 [01:50<00:00,  3.82s/it]


1 2.4599930721927774 0.6530898876404494 0.5166666507720947


100%|██████████| 23/23 [01:49<00:00,  3.81s/it]


2 2.990257512439381 0.6362359550561798 0.4833333194255829


100%|██████████| 23/23 [01:49<00:00,  3.87s/it]


3 2.97407711094076 0.6334269662921348 0.5166666507720947


100%|██████████| 23/23 [01:48<00:00,  3.78s/it]


4 2.8796351118521257 0.651685393258427 0.6666666865348816


100%|██████████| 23/23 [01:46<00:00,  3.76s/it]


5 2.1866612027991903 0.651685393258427 0.6333333253860474


100%|██████████| 23/23 [01:47<00:00,  3.77s/it]


6 2.4933279779824344 0.6474719101123596 0.7166666388511658


100%|██████████| 23/23 [01:47<00:00,  3.78s/it]


7 2.795888976617293 0.648876404494382 0.46666666865348816


100%|██████████| 23/23 [01:50<00:00,  3.84s/it]


8 2.307888327674432 0.6601123595505618 0.6666666865348816


100%|██████████| 23/23 [01:48<00:00,  3.78s/it]


9 2.0579400766979563 0.6460674157303371 0.6000000238418579


# CNNモデルによる画像の高解像度化

データの準備

In [11]:
"""
32x32ピクセルの画像を128x128ピクセルに拡大
"""
class DownSizedPairImageFolder(ImageFolder):
    def __init__(self, root, transform=None, large_size=128, small_size=32, **kwds):
        super().__init__(root, transform=transform, **kwds)
        self.large_resizer = transforms.Resize(large_size)
        self.small_resizer = transforms.Resize(small_size)
    
    def __getitem__(self, index):
        path, _ = self.imgs[index]
        img = self.loader(path)
        
        # 読み取った画像を128x128ピクセルと32x32ピクセルにリサイズする
        large_img = self.large_resizer(img)
        small_img = self.small_resizer(img)
        
        # その他の変換を適用する
        if self.transform is not None:
            large_img = self.transform(large_img)
            small_img = self.transform(small_img)
        
        # 32ピクセルの画像と128ピクセルの画像を返す
        return small_img, large_img

In [12]:
"""
訓練用と検証用のDataLoaderを作成
"""
train_data = DownSizedPairImageFolder(
    "ch4data/lfw-deepfunneled/train/",
    transform = transforms.ToTensor()
)
test_data = DownSizedPairImageFolder(
"ch4data/lfw-deepfunneled/test/",
    transform = transforms.ToTensor()
)

batch_size = 32
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=4)

In [13]:
net = nn.Sequential(
    nn.Conv2d(3, 256, 4, stride=2, padding=1),
    nn.ReLU(),
    nn.BatchNorm2d(256),
    nn.Conv2d(256, 512, 4, stride=2, padding=1),
    nn.ReLU(),
    nn.BatchNorm2d(512),
    nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),
    nn.ReLU(),
    nn.BatchNorm2d(256),
    nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
    nn.ReLU(),
    nn.BatchNorm2d(128),
    nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
    nn.ReLU(),
    nn.BatchNorm2d(64),
    nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1)
)

In [14]:
"""
PSNRの計算
"""
import math
def psnr(mse, max_v=1.0):
    return 10 * math.log10(max_v**2 / mse)

# 評価のヘルパー関数
def eval_net(net, data_loader, device="cpu"):
    # DropoutやBatchNormを無効化
    net.eval()
    ys = []
    ypreds = []
    for x, y in data_loader:
        x = x.to(device)
        y = y.to(device)
        with torch.no_grad():
            y_pred = net(x)
        ys.append(y)
        ypreds.append(y_pred)
    # ミニバッチごとの予測結果などを1つにまとめる
    ys = torch.cat(ys)
    ypreds = torch.cat(ypreds)
    # 予測精度(MSE)を計算
    score = nn.functional.mse_loss(ypreds, ys).item()
    return score

# 訓練のヘルパー関数
def train_net(net, train_loader, test_loader, optimizer_cls=optim.Adam, loss_fn=nn.MSELoss(), n_iter=10, device="cpu"):
    train_losses = []
    train_acc = []
    val_acc = []
    optimizer = optimizer_cls(net.parameters())
    for epoch in range(n_iter):
        running_loss = 0.0
        # ネットワークを訓練モードにする
        net.train()
        n = 0
        score = 0
        # 非常に時間がかかるのでtqdmを使用してプログレスバーをだす
        for i, (xx, yy) in tqdm.tqdm(enumerate(train_loader),
                                    total=len(train_loader)):
            xx = xx.to(device)
            yy = yy.to(device)
            y_pred = net(xx)
            loss = loss_fn(y_pred, yy)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            n += len(xx)
        train_losses.append(running_loss / len(train_loader))
        # 検証データの予測精度
        val_acc.append(eval_net(net, test_loader, device))
        # このepochでの結果を表示
        print(epoch, train_losses[-1], psnr(train_losses[-1]), psnr(val_acc[-1]), flush=True)

In [32]:
"""
複数回の演算(10回)
"""
net.to("cuda:0")
train_net(net, train_loader, test_loader, device="cuda:0")



  0%|          | 0/409 [00:00<?, ?it/s][A[A

  0%|          | 1/409 [00:00<01:36,  4.24it/s][A[A

  0%|          | 2/409 [00:00<01:22,  4.94it/s][A[A

  1%|          | 3/409 [00:00<01:11,  5.67it/s][A[A

  1%|          | 4/409 [00:00<01:03,  6.40it/s][A[A

  1%|          | 5/409 [00:00<00:57,  7.03it/s][A[A

  2%|▏         | 7/409 [00:00<00:51,  7.81it/s][A[A

  2%|▏         | 9/409 [00:01<00:47,  8.51it/s][A[A

  3%|▎         | 11/409 [00:01<00:43,  9.14it/s][A[A

  3%|▎         | 13/409 [00:01<00:41,  9.61it/s][A[A

  4%|▎         | 15/409 [00:01<00:39,  9.97it/s][A[A

  4%|▍         | 17/409 [00:01<00:38, 10.25it/s][A[A

  5%|▍         | 19/409 [00:01<00:37, 10.46it/s][A[A

  5%|▌         | 21/409 [00:02<00:36, 10.55it/s][A[A

  6%|▌         | 23/409 [00:02<00:36, 10.68it/s][A[A

  6%|▌         | 25/409 [00:02<00:35, 10.75it/s][A[A

  7%|▋         | 27/409 [00:02<00:35, 10.77it/s][A[A

  7%|▋         | 29/409 [00:02<00:35, 10.84it/s][A[A

  8%|▊

0 0.002444629193953077 26.117870061652162 27.23364977224534




  0%|          | 0/409 [00:00<?, ?it/s][A[A

  0%|          | 1/409 [00:00<01:49,  3.72it/s][A[A

  1%|          | 3/409 [00:00<01:27,  4.66it/s][A[A

  1%|          | 5/409 [00:00<01:11,  5.62it/s][A[A

  2%|▏         | 7/409 [00:00<01:01,  6.56it/s][A[A

  2%|▏         | 9/409 [00:01<00:54,  7.38it/s][A[A

  3%|▎         | 11/409 [00:01<00:48,  8.15it/s][A[A

  3%|▎         | 13/409 [00:01<00:45,  8.79it/s][A[A

  4%|▎         | 15/409 [00:01<00:42,  9.33it/s][A[A

  4%|▍         | 17/409 [00:01<00:40,  9.73it/s][A[A

  5%|▍         | 19/409 [00:01<00:38, 10.03it/s][A[A

  5%|▌         | 21/409 [00:02<00:38, 10.19it/s][A[A

  6%|▌         | 23/409 [00:02<00:37, 10.32it/s][A[A

  6%|▌         | 25/409 [00:02<00:36, 10.42it/s][A[A

  7%|▋         | 27/409 [00:02<00:36, 10.53it/s][A[A

  7%|▋         | 29/409 [00:02<00:35, 10.61it/s][A[A

  8%|▊         | 31/409 [00:03<00:35, 10.68it/s][A[A

  8%|▊         | 33/409 [00:03<00:35, 10.73it/s][A[A

  9%

1 0.002389269095360504 26.21734934365326 26.344707476535135




  0%|          | 0/409 [00:00<?, ?it/s][A[A

  0%|          | 1/409 [00:00<01:43,  3.95it/s][A[A

  1%|          | 3/409 [00:00<01:22,  4.90it/s][A[A

  1%|          | 5/409 [00:00<01:08,  5.88it/s][A[A

  2%|▏         | 7/409 [00:00<00:58,  6.82it/s][A[A

  2%|▏         | 9/409 [00:00<00:52,  7.66it/s][A[A

  3%|▎         | 11/409 [00:01<00:47,  8.37it/s][A[A

  3%|▎         | 13/409 [00:01<00:44,  8.96it/s][A[A

  4%|▎         | 15/409 [00:01<00:41,  9.40it/s][A[A

  4%|▍         | 17/409 [00:01<00:40,  9.76it/s][A[A

  5%|▍         | 19/409 [00:01<00:38, 10.09it/s][A[A

  5%|▌         | 21/409 [00:02<00:37, 10.28it/s][A[A

  6%|▌         | 23/409 [00:02<00:37, 10.36it/s][A[A

  6%|▌         | 25/409 [00:02<00:36, 10.46it/s][A[A

  7%|▋         | 27/409 [00:02<00:36, 10.50it/s][A[A

  7%|▋         | 29/409 [00:02<00:36, 10.55it/s][A[A

  8%|▊         | 31/409 [00:03<00:35, 10.63it/s][A[A

  8%|▊         | 33/409 [00:03<00:35, 10.70it/s][A[A

  9%

2 0.0023745127112780014 26.244255012038337 26.727901974387734




  0%|          | 0/409 [00:00<?, ?it/s][A[A

  0%|          | 1/409 [00:00<01:45,  3.88it/s][A[A

  1%|          | 3/409 [00:00<01:24,  4.81it/s][A[A

  1%|          | 5/409 [00:00<01:10,  5.75it/s][A[A

  2%|▏         | 7/409 [00:00<01:00,  6.67it/s][A[A

  2%|▏         | 9/409 [00:01<00:53,  7.53it/s][A[A

  3%|▎         | 11/409 [00:01<00:47,  8.29it/s][A[A

  3%|▎         | 13/409 [00:01<00:44,  8.91it/s][A[A

  4%|▎         | 15/409 [00:01<00:42,  9.37it/s][A[A

  4%|▍         | 17/409 [00:01<00:40,  9.72it/s][A[A

  5%|▍         | 19/409 [00:01<00:39,  9.94it/s][A[A

  5%|▌         | 21/409 [00:02<00:38, 10.12it/s][A[A

  6%|▌         | 23/409 [00:02<00:37, 10.30it/s][A[A

  6%|▌         | 25/409 [00:02<00:36, 10.44it/s][A[A

  7%|▋         | 27/409 [00:02<00:36, 10.48it/s][A[A

  7%|▋         | 29/409 [00:02<00:36, 10.53it/s][A[A

  8%|▊         | 31/409 [00:03<00:35, 10.62it/s][A[A

  8%|▊         | 33/409 [00:03<00:35, 10.62it/s][A[A

  9%

3 0.002141902859386485 26.692002293907464 27.573083875457122




  0%|          | 0/409 [00:00<?, ?it/s][A[A

  0%|          | 1/409 [00:00<01:49,  3.72it/s][A[A

  1%|          | 3/409 [00:00<01:27,  4.62it/s][A[A

  1%|          | 5/409 [00:00<01:12,  5.58it/s][A[A

  2%|▏         | 7/409 [00:00<01:01,  6.54it/s][A[A

  2%|▏         | 9/409 [00:01<00:53,  7.43it/s][A[A

  3%|▎         | 11/409 [00:01<00:48,  8.17it/s][A[A

  3%|▎         | 13/409 [00:01<00:45,  8.80it/s][A[A

  4%|▎         | 15/409 [00:01<00:42,  9.30it/s][A[A

  4%|▍         | 17/409 [00:01<00:40,  9.65it/s][A[A

  5%|▍         | 19/409 [00:01<00:39,  9.92it/s][A[A

  5%|▌         | 21/409 [00:02<00:38, 10.16it/s][A[A

  6%|▌         | 23/409 [00:02<00:37, 10.38it/s][A[A

  6%|▌         | 25/409 [00:02<00:36, 10.50it/s][A[A

  7%|▋         | 27/409 [00:02<00:36, 10.60it/s][A[A

  7%|▋         | 29/409 [00:02<00:35, 10.64it/s][A[A

  8%|▊         | 31/409 [00:03<00:35, 10.63it/s][A[A

  8%|▊         | 33/409 [00:03<00:35, 10.60it/s][A[A

  9%

4 0.0022669868725340605 26.445509947400577 26.680456519649447




  0%|          | 0/409 [00:00<?, ?it/s][A[A

  0%|          | 1/409 [00:00<01:18,  5.18it/s][A[A

  1%|          | 3/409 [00:00<01:05,  6.19it/s][A[A

  1%|          | 5/409 [00:00<00:56,  7.13it/s][A[A

  2%|▏         | 7/409 [00:00<00:50,  7.95it/s][A[A

  2%|▏         | 9/409 [00:00<00:46,  8.58it/s][A[A

  3%|▎         | 11/409 [00:01<00:43,  9.11it/s][A[A

  3%|▎         | 13/409 [00:01<00:41,  9.55it/s][A[A

  4%|▎         | 15/409 [00:01<00:39,  9.90it/s][A[A

  4%|▍         | 17/409 [00:01<00:38, 10.14it/s][A[A

  5%|▍         | 19/409 [00:01<00:37, 10.33it/s][A[A

  5%|▌         | 21/409 [00:02<00:37, 10.41it/s][A[A

  6%|▌         | 23/409 [00:02<00:36, 10.46it/s][A[A

  6%|▌         | 25/409 [00:02<00:36, 10.52it/s][A[A

  7%|▋         | 27/409 [00:02<00:36, 10.55it/s][A[A

  7%|▋         | 29/409 [00:02<00:35, 10.61it/s][A[A

  8%|▊         | 31/409 [00:02<00:35, 10.65it/s][A[A

  8%|▊         | 33/409 [00:03<00:35, 10.69it/s][A[A

  9%

5 0.0021719717199192327 26.631458337622803 26.41969875470722




  0%|          | 0/409 [00:00<?, ?it/s][A[A

  0%|          | 1/409 [00:00<01:19,  5.14it/s][A[A

  1%|          | 3/409 [00:00<01:06,  6.12it/s][A[A

  1%|          | 5/409 [00:00<00:57,  7.05it/s][A[A

  2%|▏         | 7/409 [00:00<00:51,  7.87it/s][A[A

  2%|▏         | 9/409 [00:00<00:46,  8.57it/s][A[A

  3%|▎         | 11/409 [00:01<00:43,  9.11it/s][A[A

  3%|▎         | 13/409 [00:01<00:41,  9.53it/s][A[A

  4%|▎         | 15/409 [00:01<00:40,  9.84it/s][A[A

  4%|▍         | 17/409 [00:01<00:38, 10.07it/s][A[A

  5%|▍         | 19/409 [00:01<00:38, 10.22it/s][A[A

  5%|▌         | 21/409 [00:02<00:37, 10.39it/s][A[A

  6%|▌         | 23/409 [00:02<00:36, 10.52it/s][A[A

  6%|▌         | 25/409 [00:02<00:36, 10.61it/s][A[A

  7%|▋         | 27/409 [00:02<00:35, 10.61it/s][A[A

  7%|▋         | 29/409 [00:02<00:35, 10.61it/s][A[A

  8%|▊         | 31/409 [00:02<00:35, 10.61it/s][A[A

  8%|▊         | 33/409 [00:03<00:35, 10.66it/s][A[A

  9%

6 0.0020996073197145334 26.77861921850701 26.932520060416167




  0%|          | 0/409 [00:00<?, ?it/s][A[A

  0%|          | 1/409 [00:00<01:16,  5.33it/s][A[A

  1%|          | 3/409 [00:00<01:04,  6.29it/s][A[A

  1%|          | 5/409 [00:00<00:56,  7.19it/s][A[A

  2%|▏         | 7/409 [00:00<00:50,  7.94it/s][A[A

  2%|▏         | 9/409 [00:00<00:46,  8.60it/s][A[A

  3%|▎         | 11/409 [00:01<00:43,  9.16it/s][A[A

  3%|▎         | 13/409 [00:01<00:41,  9.59it/s][A[A

  4%|▎         | 15/409 [00:01<00:39,  9.94it/s][A[A

  4%|▍         | 17/409 [00:01<00:38, 10.11it/s][A[A

  5%|▍         | 19/409 [00:01<00:37, 10.28it/s][A[A

  5%|▌         | 21/409 [00:02<00:37, 10.39it/s][A[A

  6%|▌         | 23/409 [00:02<00:36, 10.48it/s][A[A

  6%|▌         | 25/409 [00:02<00:36, 10.57it/s][A[A

  7%|▋         | 27/409 [00:02<00:35, 10.65it/s][A[A

  7%|▋         | 29/409 [00:02<00:35, 10.68it/s][A[A

  8%|▊         | 31/409 [00:02<00:35, 10.69it/s][A[A

  8%|▊         | 33/409 [00:03<00:35, 10.68it/s][A[A

  9%

7 0.0019933337469250512 27.004199806205914 26.74772338498005




  0%|          | 0/409 [00:00<?, ?it/s][A[A

  0%|          | 1/409 [00:00<01:17,  5.26it/s][A[A

  1%|          | 3/409 [00:00<01:04,  6.26it/s][A[A

  1%|          | 5/409 [00:00<00:56,  7.18it/s][A[A

  2%|▏         | 7/409 [00:00<00:50,  7.99it/s][A[A

  2%|▏         | 9/409 [00:00<00:46,  8.62it/s][A[A

  3%|▎         | 11/409 [00:01<00:43,  9.18it/s][A[A

  3%|▎         | 13/409 [00:01<00:41,  9.62it/s][A[A

  4%|▎         | 15/409 [00:01<00:39,  9.90it/s][A[A

  4%|▍         | 17/409 [00:01<00:38, 10.10it/s][A[A

  5%|▍         | 19/409 [00:01<00:37, 10.29it/s][A[A

  5%|▌         | 21/409 [00:02<00:37, 10.42it/s][A[A

  6%|▌         | 23/409 [00:02<00:36, 10.52it/s][A[A

  6%|▌         | 25/409 [00:02<00:36, 10.60it/s][A[A

  7%|▋         | 27/409 [00:02<00:36, 10.60it/s][A[A

  7%|▋         | 29/409 [00:02<00:35, 10.62it/s][A[A

  8%|▊         | 31/409 [00:02<00:35, 10.62it/s][A[A

  8%|▊         | 33/409 [00:03<00:35, 10.66it/s][A[A

  9%

8 0.00203014443677447 26.924730626470556 27.09936388568117




  0%|          | 0/409 [00:00<?, ?it/s][A[A

  0%|          | 1/409 [00:00<01:18,  5.21it/s][A[A

  1%|          | 3/409 [00:00<01:05,  6.16it/s][A[A

  1%|          | 5/409 [00:00<00:57,  7.06it/s][A[A

  2%|▏         | 7/409 [00:00<00:51,  7.85it/s][A[A

  2%|▏         | 9/409 [00:00<00:46,  8.53it/s][A[A

  3%|▎         | 11/409 [00:01<00:43,  9.08it/s][A[A

  3%|▎         | 13/409 [00:01<00:41,  9.54it/s][A[A

  4%|▎         | 15/409 [00:01<00:39,  9.88it/s][A[A

  4%|▍         | 17/409 [00:01<00:38, 10.10it/s][A[A

  5%|▍         | 19/409 [00:01<00:37, 10.27it/s][A[A

  5%|▌         | 21/409 [00:02<00:37, 10.37it/s][A[A

  6%|▌         | 23/409 [00:02<00:36, 10.48it/s][A[A

  6%|▌         | 25/409 [00:02<00:36, 10.59it/s][A[A

  7%|▋         | 27/409 [00:02<00:35, 10.63it/s][A[A

  7%|▋         | 29/409 [00:02<00:35, 10.67it/s][A[A

  8%|▊         | 31/409 [00:02<00:35, 10.71it/s][A[A

  8%|▊         | 33/409 [00:03<00:35, 10.64it/s][A[A

  9%

9 0.002035552322122207 26.913177299748725 26.26673221808247


In [14]:
"""
画像を拡大してオリジナルと比較する
"""
from torchvision.utils import save_image

# テストのデータセットからランダムに4つずつ取り出すDataLoader
random_test_loader = DataLoader(test_data, batch_size=4, shuffle=True)
# DataLoaderをPythonのイテレータにhenkannsi,4つ例を取り出す
it = iter(random_test_loader)
x, y = next(it)

# Bilineadで拡大
bl_recon = torch.nn.functional.upsample(x, 128, mode="bilinear", align_corners=True)
# CNNで拡大
yp = net(x.to("cuda:0")).to("cpu")

# torch.catでオリジナル、Bilinear、CNNの画像を結合しsage_imageで画像ファイルに書き出し
save_image(torch.cat([y, bl_recon, yp], 0), "cnn_upscale.jpg", nrow=4)

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

# DCGANによる画像生成

In [15]:
img_data = ImageFolder("ch4data/oxford-102/",
                      transform=transforms.Compose([
                          transforms.Resize(80),
                          transforms.CenterCrop(64),
                          transforms.ToTensor()
                      ]))
batch_size = 64
img_loader = DataLoader(img_data, batch_size=batch_size, shuffle=True)

PyTorchによるDCGAN

In [16]:
"""
画像の生成モデルを組み立てる
"""
nz = 100
ngf = 32

class GNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )
    def forward(self, x):
        out = self.main(x)
        return out

In [17]:
"""
画像の識別モデルを組み立てる
"""
ndf = 32
class DNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False)
        )
        
    def forward(self, x):
        out = self.main(x)
        return out.squeeze()

In [18]:
"""
訓練関数の作成
"""
d = DNet().to("cuda:0")
g = GNet().to("cuda:0")

# Adamのパラメータは元論文(dcgan-paper)の提案値
opt_d = optim.Adam(d.parameters(), lr=0.0002, betas=(0.5, 0.999))
opt_g = optim.Adam(g.parameters(), lr=0.0002, betas=(0.5, 0.999))

# クロスエントロピーを計算するための補助変数など
ones = torch.ones(batch_size).to("cuda:0")
zeros = torch.zeros(batch_size).to("cuda:0")
loss_f = nn.BCEWithLogitsLoss()

# モニタリング用のz
fixed_z = torch.randn(batch_size, nz, 1, 1).to("cuda:0")

In [19]:
"""
訓練関数
"""
from statistics import mean

def train_dcgan(g, d, opt_g, opt_d, loader):
    # 生成モデル、識別モデルの目的関数の追跡用の配列
    log_loss_g = []
    log_loss_d = []
    for real_img, _ in tqdm.tqdm(loader):
        batch_len = len(real_img)
        
        # 実際の画像をGPUにコピー
        real_img = real_img.to("cuda:0")
        
        # 偽画像を乱数と生成モデルから作る
        z = torch.randn(batch_len, nz, 1, 1).to("cuda:0")
        fake_img = g(z)
        
        # あとで使用するので偽画像の値のみ取り出しておく
        fake_img_tensor = fake_img.detach()
        
        # 偽画像に対する生成モデルの評価関数を計算する
        out = d(fake_img)
        loss_g = loss_f(out, ones[: batch_len])
        log_loss_g.append(loss_g.item())
        
        # 計算グラフが生成モデルと識別モデルの両方に依存しているので、両者とも勾配をクリアしてから
        # 微分の計算とパラメータ更新を行う
        d.zero_grad(), g.zero_grad()
        loss_g.backward()
        opt_g.step()
        
        # 実際の画像に対する識別萌えるの評価関数を計算
        real_out = d(real_img)
        loss_d_real = loss_f(real_out, ones[: batch_len])
        
        # PyTorchでは同じTensorを含んだ計算グラフに対して2回backwardを行うことができないので、
        # 保存してあったTensorを使用して無駄な計算を省く
        fake_img = fake_img_tensor
        
        # 偽モデルの対する識別モデルの評価関数の計算
        fake_out = d(fake_img_tensor)
        loss_d_fake = loss_f(fake_out, zeros[: batch_len])
        
        # 実偽の評価関数の合計値
        loss_d = loss_d_real + loss_d_fake
        log_loss_d.append(loss_d.item())
        
        # 識別モデルの微分計算とパラメータ更新
        d.zero_grad(), g.zero_grad()
        loss_d.backward()
        opt_d.step()
    return mean(log_loss_g), mean(log_loss_d)

In [20]:
"""
DCGANの訓練
"""
for epoch in range(300):
    train_dcgan(g, d, opt_g, opt_d, img_loader)
    # 10回の繰り返しごとに学習結果を保存する
    if epoch % 10 == 0:
        # パラメータの保存
        torch.save(
            g.state_dict(),
            "dcgan_out/g_{:03d}.prm".format(epoch),
            pickle_protocol=4
        )
        torch.save(
            d.state_dict(),
            "dcgan_out/d_{:03d}.prm".format(epoch),
            pickle_protocol=4
        )
        # モニタリング用のzから生成した画像を保存
        generated_img = g(fixed_z)
        save_image(generated_img,
                   "dcgan_out/{:03d}.jpg".format(epoch))

100%|██████████| 128/128 [01:02<00:00,  2.09it/s]
100%|██████████| 128/128 [01:01<00:00,  2.07it/s]
100%|██████████| 128/128 [01:02<00:00,  2.04it/s]
100%|██████████| 128/128 [01:01<00:00,  2.14it/s]
100%|██████████| 128/128 [01:01<00:00,  2.12it/s]
100%|██████████| 128/128 [01:01<00:00,  2.11it/s]
100%|██████████| 128/128 [01:01<00:00,  2.14it/s]
100%|██████████| 128/128 [01:01<00:00,  2.01it/s]
100%|██████████| 128/128 [01:02<00:00,  2.10it/s]
100%|██████████| 128/128 [01:03<00:00,  2.01it/s]
100%|██████████| 128/128 [01:03<00:00,  2.04it/s]
100%|██████████| 128/128 [01:03<00:00,  2.02it/s]
100%|██████████| 128/128 [01:04<00:00,  2.05it/s]
100%|██████████| 128/128 [01:03<00:00,  2.04it/s]
100%|██████████| 128/128 [01:03<00:00,  2.03it/s]
100%|██████████| 128/128 [01:03<00:00,  2.05it/s]
100%|██████████| 128/128 [01:03<00:00,  2.08it/s]
100%|██████████| 128/128 [01:03<00:00,  2.06it/s]
100%|██████████| 128/128 [01:02<00:00,  2.11it/s]
100%|██████████| 128/128 [01:02<00:00,  2.04it/s]


100%|██████████| 128/128 [01:01<00:00,  2.09it/s]
100%|██████████| 128/128 [01:02<00:00,  2.09it/s]
100%|██████████| 128/128 [01:02<00:00,  2.07it/s]
100%|██████████| 128/128 [01:02<00:00,  2.01it/s]
100%|██████████| 128/128 [01:02<00:00,  2.09it/s]
100%|██████████| 128/128 [01:01<00:00,  2.06it/s]
100%|██████████| 128/128 [01:02<00:00,  2.09it/s]
100%|██████████| 128/128 [01:01<00:00,  2.10it/s]
100%|██████████| 128/128 [01:01<00:00,  2.13it/s]
100%|██████████| 128/128 [01:02<00:00,  2.10it/s]
100%|██████████| 128/128 [01:01<00:00,  2.07it/s]
100%|██████████| 128/128 [01:02<00:00,  2.12it/s]
100%|██████████| 128/128 [01:01<00:00,  2.13it/s]
100%|██████████| 128/128 [01:02<00:00,  2.09it/s]
100%|██████████| 128/128 [01:02<00:00,  2.10it/s]
100%|██████████| 128/128 [01:02<00:00,  2.05it/s]
100%|██████████| 128/128 [01:02<00:00,  2.10it/s]
100%|██████████| 128/128 [01:02<00:00,  2.05it/s]
100%|██████████| 128/128 [01:01<00:00,  2.09it/s]
100%|██████████| 128/128 [01:02<00:00,  2.09it/s]
