# DCGAN

*Deep Convolutional GAN*

畳み込み層を用いたGAN。DCGANではなく普通にGANと呼ばれることも多い。

In [None]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import torchvision
from torchvision.datasets import MNIST
from torchvision import transforms
from IPython.display import display
from tqdm import tqdm


# Hyperparameters shared by every cell below.
batch_size = 64  # images per minibatch
nz = 100  # channel count of the latent noise fed to the generator
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## MNIST

In [None]:
# MNIST training split, converted to float tensors by ToTensor.
dataset = MNIST(root="datasets/", train=True, download=True, transform=transforms.ToTensor())

# drop_last=True keeps every batch at exactly `batch_size`; the fixed-size
# label tensors used during training depend on that.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

# Peek at one batch to confirm the tensor layout.
sample_x, _ = next(iter(dataloader))
print("batch shape: ", sample_x.shape)

## Discriminator



In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator for 1-channel MNIST images.

    For a (batch, 1, 28, 28) input the spatial size shrinks 28 -> 14 -> 7 -> 3 -> 1,
    and the result is a (batch, 1) score squashed into [0, 1] by a Sigmoid.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, kernel_size, stride, padding) per conv block.
        block_specs = [
            (1, 16, 4, 2, 1),
            (16, 32, 4, 2, 1),
            (32, 64, 3, 2, 0),
        ]
        feature_blocks = [self._conv(*spec) for spec in block_specs]
        self.net = nn.Sequential(
            *feature_blocks,
            nn.Conv2d(64, 128, 3, 1, 0),  # 3x3 -> 1x1
            nn.Flatten(),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def _conv(self, in_channels, out_channels, kernel_size, stride, padding):
        """Build a Conv2d -> BatchNorm2d -> LeakyReLU(0.2) stage."""
        stage = [
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        ]
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the per-image score for the batch ``x``."""
        return self.net(x)

## Generator

In [None]:
class Generator(nn.Module):
    """Transposed-convolution generator for MNIST.

    A (batch, nz, 1, 1) latent (as produced by ``make_noise``) is upsampled
    1 -> 3 -> 7 -> 14 -> 28, giving a (batch, 1, 28, 28) image in [0, 1] (Sigmoid).
    """

    def __init__(self, nz):
        super().__init__()
        # (in_channels, out_channels, kernel_size, stride, padding) per hidden stage.
        stage_specs = (
            (nz, 128, 3, 1, 0),
            (128, 64, 3, 2, 0),
            (64, 32, 4, 2, 1),
        )
        hidden = [self._convT(*spec) for spec in stage_specs]
        self.net = nn.Sequential(
            *hidden,
            nn.ConvTranspose2d(32, 1, 4, 2, 1),  # 14 -> 28, single output channel
            nn.Sigmoid(),
        )

    def _convT(self, in_channels, out_channels, kernel_size, stride, padding):
        """Build a ConvTranspose2d -> BatchNorm2d -> ReLU stage."""
        stage = [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        return nn.Sequential(*stage)

    def forward(self, x):
        """Map the latent batch ``x`` to generated images."""
        return self.net(x)

## 学習

In [None]:
def make_noise(batch_size):
    """Draw ``batch_size`` standard-normal latent vectors shaped (batch_size, nz, 1, 1) on ``device``."""
    noise_shape = (batch_size, nz, 1, 1)
    return torch.randn(*noise_shape, device=device)

def write(netG, n_rows=1, n_cols=8, size=64):
    """Sample ``n_rows * n_cols`` images from ``netG`` and display them as one grid.

    Args:
        netG: generator to sample from. Its outputs are shown as-is, so they are
            expected to lie in [0, 1] (the MNIST Generator ends in a Sigmoid).
        n_rows: number of grid rows.
        n_cols: number of images per grid row.
        size: each image is resized (shorter side) to this many pixels before tiling.
    """
    # Sampling only: no_grad avoids building an autograd graph through netG.
    with torch.no_grad():
        z = make_noise(n_rows * n_cols)
        images = netG(z)
    images = transforms.Resize(size)(images)
    img = torchvision.utils.make_grid(images, n_cols)
    img = transforms.functional.to_pil_image(img)
    display(img)

In [None]:
# BCE targets: 0 for generated images, 1 for real ones. Sized to `batch_size`,
# which every batch matches because the dataloader uses drop_last=True.
fake_labels = torch.zeros(batch_size, 1).to(device)
real_labels = torch.ones(batch_size, 1).to(device)
criterion = nn.BCELoss()

def train(netD, netG, optimD, optimG, n_epochs, write_interval=1, progress=True, print_interval=1):
    """Train the MNIST GAN for ``n_epochs`` passes over the global ``dataloader``.

    Each batch does one discriminator update, then one generator update that
    reuses the same generated images.

    Args:
        netD: discriminator returning a (batch, 1) probability of "real".
        netG: generator fed with ``make_noise`` output.
        optimD: optimizer over ``netD``'s parameters.
        optimG: optimizer over ``netG``'s parameters.
        n_epochs: number of epochs to run.
        write_interval: display samples via ``write`` every this many epochs
            (0/None disables).
        progress: show a tqdm progress bar over batches.
        print_interval: print the epoch-mean losses every this many epochs
            (0/None disables).
    """
    netD.train()
    netG.train()
    for epoch in range(1, n_epochs + 1):
        lossD_list = []
        lossG_list = []
        for X, _ in tqdm(dataloader, disable=not progress):
            X = X.to(device)

            # Discriminator step ---------------------------------------------
            optimD.zero_grad()
            z = make_noise(batch_size)
            fake = netG(z)
            # detach() keeps lossD.backward() from writing gradients into netG;
            # those would otherwise be added to the generator's own update.
            pred_fake = netD(fake.detach())
            pred_real = netD(X)
            loss_fake = criterion(pred_fake, fake_labels)
            loss_real = criterion(pred_real, real_labels)
            lossD = loss_fake + loss_real
            lossD.backward()
            optimD.step()

            # Generator step -------------------------------------------------
            optimG.zero_grad()
            # netG has not changed since `fake` was produced, so reuse it and
            # score it with the freshly updated discriminator.
            pred = netD(fake)
            lossG = criterion(pred, real_labels)
            lossG.backward()
            optimG.step()

            lossD_list.append(lossD.item())
            lossG_list.append(lossG.item())

        if print_interval and epoch % print_interval == 0 and lossD_list:
            meanD = sum(lossD_list) / len(lossD_list)
            meanG = sum(lossG_list) / len(lossG_list)
            print(f'{epoch:>3}epoch | lossD: {meanD:.4f}, lossG: {meanG:.4f}')
        if write_interval and epoch % write_interval == 0:
            write(netG)

In [None]:
# Build both networks (discriminator first) and their Adam optimizers.
netD = Discriminator().to(device)
netG = Generator(nz).to(device)
optimD = optim.Adam(netD.parameters(), lr=2e-4)
optimG = optim.Adam(netG.parameters(), lr=2e-4)
n_epochs = 30

# Show samples every 5 epochs; no per-batch progress bar.
train(netD, netG, optimD, optimG, n_epochs, write_interval=5, progress=False)

In [None]:
write(netG, n_rows=4)  # 4 x 8 grid of samples from the trained generator


---

## ポケモンを作ってみる

### データセット

[Pokemon Images Dataset | Kaggle](https://www.kaggle.com/kvpratama/pokemon-images-dataset)

In [None]:
from torchvision.datasets import ImageFolder

# Pokemon images resized to 64x64 (the Generator's output resolution) plus
# light augmentation.
dataset = ImageFolder(
    "datasets/pokemon",
    transform=transforms.Compose([
        transforms.Resize(64),
        # Resize(int) only fixes the shorter side; crop so non-square images still
        # come out 64x64 and can be stacked into a batch. No-op for square inputs.
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        # NOTE(review): ColorJitter() with default arguments applies no jitter
        # (all factors are 0); pass e.g. brightness=/contrast= to enable it.
        transforms.ColorJitter(),
    ])
)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True
)

sample_x, _ = next(iter(dataloader))
# Batches are laid out (N, C, H, W): the trailing two dims are height, then width.
h, w = sample_x.shape[2:]
image_size = w * h  # pixels per image
print("batch shape:", sample_x.shape)
print("width:", w)
print("height:", h)
print("image size:", image_size)

### Discriminator

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator for 3-channel images.

    Four stride-2 conv stages halve the spatial size (64 -> 32 -> 16 -> 8 -> 4
    for a 64x64 input). Global average pooling then reduces it to a 256-vector,
    which a Linear + Sigmoid turns into a (batch, 1) score in [0, 1].
    """

    def __init__(self):
        super().__init__()
        widths = [3, 32, 64, 128, 256]
        downsample = [
            self._conv(c_in, c_out, 4, 2, 1)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ]
        self.net = nn.Sequential(
            *downsample,
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(widths[-1], 1),
            nn.Sigmoid(),
        )

    def _conv(self, in_channels, out_channels, kernel_size, stride, padding):
        """Build a Conv2d -> BatchNorm2d -> LeakyReLU(0.2) stage."""
        stage = [
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        ]
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the per-image score for the batch ``x``."""
        return self.net(x)

### Generator

In [None]:
class Generator(nn.Module):
    """Transposed-convolution generator producing 3-channel images in [-1, 1] (Tanh).

    Reads the module-level ``nz`` for its input channel count. For a
    (batch, nz, 1, 1) latent (as from ``make_noise``) the spatial size grows
    1 -> 4 -> 8 -> 16 -> 32 -> 64.
    """

    def __init__(self):
        super().__init__()
        widths = [nz, 256, 128, 64, 32]
        # First stage projects 1x1 -> 4x4; each following stage doubles the size.
        upsample = [self._convT(widths[0], widths[1], 4, 1, 0)]
        upsample += [
            self._convT(c_in, c_out, 4, 2, 1)
            for c_in, c_out in zip(widths[1:-1], widths[2:])
        ]
        self.net = nn.Sequential(
            *upsample,
            nn.ConvTranspose2d(widths[-1], 3, 4, 2, 1),  # 32 -> 64, RGB output
            nn.Tanh(),
        )

    def _convT(self, in_channels, out_channels, kernel_size, stride, padding):
        """Build a ConvTranspose2d -> InstanceNorm2d -> ReLU stage."""
        stage = [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(),
        ]
        return nn.Sequential(*stage)

    def forward(self, x):
        """Map the latent batch ``x`` to generated images."""
        return self.net(x)

### 学習

In [None]:
# Generate latent noise.
def make_noise(batch_size=batch_size):
    """Draw standard-normal latent vectors shaped (batch_size, nz, 1, 1) on ``device``.

    The default is the global ``batch_size`` as it was when this function was defined.
    """
    noise_shape = (batch_size, nz, 1, 1)
    return torch.randn(noise_shape, device=device)

# Generate images and display them.
def write(netG, n_rows=1, n_cols=8, size=64):
    """Sample ``n_rows * n_cols`` images from ``netG`` and display them as one grid.

    Args:
        netG: generator to sample from. Its outputs are assumed to lie in [-1, 1]
            (the Pokemon Generator ends in Tanh) and are rescaled to [0, 1].
        n_rows: number of grid rows.
        n_cols: number of images per grid row.
        size: each image is resized (shorter side) to this many pixels before tiling.
    """
    # Sampling only: no_grad avoids building an autograd graph through netG.
    with torch.no_grad():
        z = make_noise(n_rows*n_cols)
        images = netG(z)
    images = (images + 1) / 2  # [-1, 1] -> [0, 1] for display
    images = transforms.Resize(size)(images)
    img = torchvision.utils.make_grid(images, n_cols)
    img = transforms.functional.to_pil_image(img)
    display(img)

In [None]:
# BCE targets. NOTE: inverted relative to the MNIST section. Real images are
# labelled 0 and generated ones 1, so netD here learns the probability that an
# image is *fake*. The generator step also targets `real_labels`, so the
# convention is applied consistently.
real_labels = torch.zeros(batch_size, 1).to(device)
fake_labels = torch.ones(batch_size, 1).to(device)
criterion = nn.BCELoss()


def train(netD, netG, optimD, optimG, n_epochs, draw_freq=10):
    """Train the Pokemon GAN for ``n_epochs`` passes over the global ``dataloader``.

    Each batch does one discriminator update, then one generator update that
    reuses the same generated images.

    Args:
        netD: discriminator returning a (batch, 1) probability.
        netG: generator fed with ``make_noise()`` output.
        optimD: optimizer over ``netD``'s parameters.
        optimG: optimizer over ``netG``'s parameters.
        n_epochs: number of epochs to run.
        draw_freq: every ``draw_freq`` epochs, print the epoch-mean losses and
            display samples via ``write``.
    """
    for epoch in range(1, n_epochs+1):
        lossD_total = 0.0
        lossG_total = 0.0
        n_batches = 0
        for x, _ in dataloader:
            # Map ToTensor's [0, 1] range onto [-1, 1], the generator's Tanh range.
            x = x * 2 - 1
            x = x.to(device)

            # Discriminator step ---------------------------------------------
            optimD.zero_grad()
            z = make_noise()
            fake = netG(z)
            # detach() keeps lossD.backward() from writing gradients into netG;
            # those would otherwise be added to the generator's own update.
            pred_fake = netD(fake.detach())
            pred_real = netD(x)
            loss_fake = criterion(pred_fake, fake_labels)
            loss_real = criterion(pred_real, real_labels)
            lossD = loss_fake + loss_real
            lossD.backward()
            optimD.step()

            # Generator step -------------------------------------------------
            optimG.zero_grad()
            # netG has not changed since `fake` was produced, so reuse it and
            # score it with the freshly updated discriminator.
            pred = netD(fake)
            lossG = criterion(pred, real_labels)
            lossG.backward()
            optimG.step()

            lossD_total += lossD.item()
            lossG_total += lossG.item()
            n_batches += 1

        if epoch % draw_freq == 0:
            if n_batches:
                meanD = lossD_total / n_batches
                meanG = lossG_total / n_batches
                print(f'{epoch:>3}epoch | lossD: {meanD:.4f}, lossG: {meanG:.4f}')
            write(netG)

In [None]:
# Build both networks (discriminator first) and their Adam optimizers.
netD = Discriminator().to(device)
netG = Generator().to(device)
optimD = optim.Adam(netD.parameters(), lr=2e-4)
optimG = optim.Adam(netG.parameters(), lr=2e-4)

# Samples from the untrained generator, as a baseline.
write(netG)

In [None]:
train(netD, netG, optimD, optimG, n_epochs=50, draw_freq=5)  # report and draw every 5 epochs