<a href="https://colab.research.google.com/github/iiyama-lab/semi_tutorial/blob/main/20220531.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GANのコード（全部入り）
「ドライブにコピー」してから使ってください

# 0. ドライブのマウント
ドライブをマウントします

In [None]:
# Mount Google Drive at /content/drive (requires the Colab-provided google.colab module).
from google.colab import drive
drive.mount("/content/drive")

# 1. いろいろインポート

In [None]:
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision.io import read_image
from torchvision import transforms
import os
import glob

import matplotlib.pyplot as plt
import numpy as np


# 2. データローダの作成

In [None]:
class GANImageDataset(Dataset):
    """Unlabelled image dataset for GAN training.

    Collects every ``*.png`` file located exactly one directory level below
    ``img_dir`` and yields the decoded image for a given index.

    Attributes:
        filenames (list): Sorted paths of the PNG files that were found.
        transform (callable or None): Optional function applied to each image.
    """

    def __init__(self, img_dir, transform=None):
        """
        Args:
            img_dir: Directory whose immediate sub-directories hold the PNGs.
            transform: Optional callable applied to each decoded image.
        """
        self.transform = transform
        # glob returns matches in arbitrary, filesystem-dependent order;
        # sorting makes the index -> file mapping reproducible across runs.
        self.filenames = sorted(glob.glob(os.path.join(img_dir, "*/*.png")))
        print(f"{len(self)} images for training")

    def __len__(self):
        """Return the number of image files found."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Read the image at position ``idx`` and apply ``transform`` if set."""
        image = read_image(self.filenames[idx])
        if self.transform is not None:
            image = self.transform(image)
        return image

In [None]:
class ImageTransform():
    """Preprocessing pipeline applied to each training image.

    The pipeline converts the input to a PIL image, resizes it with
    ``Resize(64)``, center-crops to 64x64, converts back to a tensor and
    normalizes each channel with ``mean`` / ``std``.
    """

    # Tuples instead of lists: default arguments are shared between calls,
    # so they must not be mutable.
    def __init__(self, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
        """
        Args:
            mean: Per-channel mean passed to ``transforms.Normalize``.
            std: Per-channel standard deviation passed to
                ``transforms.Normalize``.
        """
        self.data_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(64),
            transforms.CenterCrop(64),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

    def __call__(self, img):
        """Apply the composed transform to ``img`` and return the result."""
        return self.data_transform(img)

In [None]:
def tensor2image(image, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Convert a normalized channel-first tensor into a displayable array.

    Undoes the ``(x - mean) / std`` normalization, moves the channel axis
    last, and clips the result to ``[0, 1]``.

    Args:
        image: PyTorch tensor whose first axis is the channel axis
            (axes are permuted with ``(1, 2, 0)``).
        mean: Per-channel mean used when the image was normalized.
        std: Per-channel standard deviation used when it was normalized.

    Returns:
        numpy.ndarray with the channel axis last and values in ``[0, 1]``.
    """
    # detach()/cpu() make the conversion safe for tensors that are part of an
    # autograd graph or live on an accelerator; both are no-ops otherwise.
    channel_first = image.detach().cpu().numpy()
    channel_last = channel_first.transpose((1, 2, 0))
    restored = np.array(std) * channel_last + np.array(mean)
    return np.clip(restored, 0, 1)


In [None]:
def show_images(images, filename=None, ncols=8, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Tile a batch of images into one grid and save it as an image file.

    Images are placed left-to-right, top-to-bottom with ``ncols`` images per
    row. A trailing partial row is kept and padded with black cells.

    Args:
        images: Batch indexed as ``(N, C, H, W)``; each ``images[i]`` is
            passed to ``tensor2image``.
        filename: Output path; ``"out.png"`` is used when ``None``.
        ncols: Number of images per grid row.
        mean: Per-channel mean forwarded to ``tensor2image``.
        std: Per-channel standard deviation forwarded to ``tensor2image``.
    """
    n_images = images.shape[0]
    height = images.shape[2]
    width = images.shape[3]
    # Ceiling division keeps images that do not fill a whole row; at least
    # one row so the canvas is never zero-sized.
    nrows = max(1, -(-n_images // ncols))

    # Rows run along axis 0 (height) and columns along axis 1 (width).
    canvas = np.zeros((nrows * height, ncols * width, 3))
    for idx in range(n_images):
        row, col = divmod(idx, ncols)
        canvas[row * height:(row + 1) * height,
               col * width:(col + 1) * width,
               :] = tensor2image(images[idx], mean, std)

    fig, ax = plt.subplots()
    ax.imshow(canvas)
    if filename is None:
        filename = "out.png"
    fig.savefig(filename)
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)

# 3. モデルの作成

In [None]:
class Generator(nn.Module):
    """Transposed-convolution generator.

    Maps a latent batch of shape ``(N, z_dim, 1, 1)`` through five
    upsampling stages (1 -> 4 -> 8 -> 16 -> 32 -> 64 pixels) followed by a
    3x3 convolution to three channels and ``Tanh``.
    """

    def __init__(self, z_dim=100):
        """
        Args:
            z_dim: Number of channels of the latent input.
        """
        super().__init__()

        self.layer1 = nn.Sequential(
            *self._upsample(z_dim, 1024, stride=1, padding=0))
        self.layer2 = nn.Sequential(*self._upsample(1024, 512))
        self.layer3 = nn.Sequential(*self._upsample(512, 256))
        self.layer4 = nn.Sequential(*self._upsample(256, 128))

        # Final stage: one more upsampling step, then project to 3 channels.
        self.last = nn.Sequential(
            *self._upsample(128, 64),
            nn.utils.spectral_norm(nn.Conv2d(64, 3, kernel_size=3, padding=1)),
            nn.Tanh(),
        )

    @staticmethod
    def _upsample(in_channels, out_channels, stride=2, padding=1):
        """Return [spectral-normed 4x4 ConvTranspose2d, BatchNorm2d, ReLU]."""
        deconv = nn.ConvTranspose2d(
            in_channels, out_channels,
            kernel_size=4, stride=stride, padding=padding)
        return [
            nn.utils.spectral_norm(deconv),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]

    def forward(self, z):
        """Run ``z`` through all stages in order and return the result."""
        out = z
        for stage in (self.layer1, self.layer2, self.layer3,
                      self.layer4, self.last):
            out = stage(out)
        return out

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator producing one logit per image.

    Four spectral-normalized stride-2 convolutions (each followed by
    ``LeakyReLU(0.1)``) halve the resolution, and a final unpadded stride-2
    3x3 convolution reduces to a single channel. For 64x64 inputs the output
    has shape ``(N, 1, 1, 1)``.
    """

    def __init__(self):
        super().__init__()

        channels = (3, 32, 64, 128, 256)
        stages = [
            self._downsample(c_in, c_out)
            for c_in, c_out in zip(channels, channels[1:])
        ]
        self.layer1, self.layer2, self.layer3, self.layer4 = stages

        # The output head is a plain convolution (no spectral norm, no activation).
        self.last = nn.Conv2d(256, 1, kernel_size=3, stride=2)

    @staticmethod
    def _downsample(in_channels, out_channels):
        """Return a spectral-normed stride-2 3x3 conv followed by LeakyReLU."""
        conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, padding=1, stride=2)
        return nn.Sequential(
            nn.utils.spectral_norm(conv),
            nn.LeakyReLU(0.1, inplace=True),
        )

    def forward(self, x):
        """Run ``x`` through all stages in order and return the raw logits."""
        out = x
        for stage in (self.layer1, self.layer2, self.layer3,
                      self.layer4, self.last):
            out = stage(out)
        return out

# 4. 訓練

In [None]:
class Train_model:
    """Training driver for a generator / discriminator pair.

    Label convention used throughout: the discriminator target is 0.0 for
    real images and 1.0 for generated ones, and the generator is updated so
    that its outputs are scored as 0.0.

    The discriminator is not trained on freshly generated images but on
    batches taken from a small ring buffer (``old_fake_images``) that the
    generator step refills.
    """

    def __init__(self, device):
        """
        Args:
            device: Device (e.g. ``"cuda:0"`` or ``"cpu"``) that the models
                and batches are moved to.
        """
        self.device = device

    def initialize(self, G, D, z_dim, setting):
        """Attach the models and build optimizers and the loss.

        Args:
            G: Generator module.
            D: Discriminator module.
            z_dim: Number of channels of the latent input fed to ``G``.
            setting: Mapping with keys ``"g_lr"``, ``"d_lr"``, ``"beta1"``
                and ``"beta2"`` (Adam learning rates and betas).
        """
        self.D = D
        self.G = G
        self.z_dim = z_dim

        betas = (setting["beta1"], setting["beta2"])
        # Adam for both networks; torch.optim.SGD(params, lr) can be swapped
        # in here for experiments.
        self.g_optimizer = torch.optim.Adam(
            self.G.parameters(), lr=setting["g_lr"], betas=betas)
        self.d_optimizer = torch.optim.Adam(
            self.D.parameters(), lr=setting["d_lr"], betas=betas)

        self.D.to(self.device)
        self.G.to(self.device)

        # Global switch: lets cuDNN pick algorithms by benchmarking.
        torch.backends.cudnn.benchmark = True

        self.criterion = nn.BCEWithLogitsLoss(reduction="mean")

    def generate(self, batch_size):
        """Return ``batch_size`` generated images, computed without gradients."""
        with torch.no_grad():
            input_z = torch.randn(batch_size, self.z_dim, 1, 1).to(self.device)
            fake_images = self.G(input_z)
        return fake_images

    def init_unroll_GAN(self, batch_size):
        """Fill the ring buffer of generated batches used by ``train_D``.

        Args:
            batch_size: Number of images in each buffered batch.
        """
        self.nRoll = 10
        self.nBuffer = 11
        self.old_fake_images = [
            self.generate(batch_size) for _ in range(self.nBuffer)
        ]
        self.current_idx = 0

    def train(self, dataloader, num_epochs):
        """Run the training loop and save a sample grid after every epoch.

        After each epoch a grid generated from a fixed latent batch is saved
        via ``show_images`` (default filename), and additionally to
        ``out{epoch:04d}.png`` every 10th epoch.

        Args:
            dataloader: Loader yielding image batches; its ``batch_size`` is
                used as the training batch size.
            num_epochs: Number of passes over ``dataloader``.
        """
        num_images = len(dataloader.dataset)
        self.batch_size = dataloader.batch_size
        # Fixed latent batch so the per-epoch sample grids are comparable.
        z = torch.randn(self.batch_size, self.z_dim, 1, 1).to(self.device)

        self.G.train()
        self.D.train()
        self.init_unroll_GAN(self.batch_size)

        pbar_epoch = tqdm(total=num_epochs)
        pbar_epoch.set_description("epoch")
        try:
            for epoch in range(num_epochs):
                losses = self._train_one_epoch(dataloader, num_images)
                pbar_epoch.set_postfix(losses)
                pbar_epoch.update()

                sample, _ = self.generate_fake_images(z)
                show_images(sample)
                if epoch % 10 == 0:
                    show_images(sample, f"out{epoch:04d}.png")
        finally:
            # Always release the progress bar, even if training is interrupted.
            pbar_epoch.close()

    def _train_one_epoch(self, dataloader, num_images):
        """Train D and G once per full batch; return the mean losses as a dict."""
        # generate_fake_images() leaves both models in eval mode, so switch
        # back to train mode at the start of every epoch.
        self.G.train()
        self.D.train()

        d_total = 0.0
        g_total = 0.0
        n_trained = 0
        pbar_batch = tqdm(total=num_images, leave=False)
        try:
            for images in dataloader:
                _batch_size = images.size()[0]
                # Partial batches are skipped by train_D/train_G; leave them
                # out of the running mean so they do not bias it toward 0.
                if _batch_size == self.batch_size:
                    d_total += self.train_D(images, _batch_size)
                    g_total += self.train_G(_batch_size)
                    n_trained += 1
                pbar_batch.set_postfix(
                    self._mean_losses(d_total, g_total, n_trained))
                pbar_batch.update(_batch_size)
        finally:
            pbar_batch.close()
        return self._mean_losses(d_total, g_total, n_trained)

    @staticmethod
    def _mean_losses(d_total, g_total, n_batches):
        """Return per-batch mean losses; safe when no batch was trained yet."""
        denom = max(n_batches, 1)
        return {"dLoss": d_total / denom, "gLoss": g_total / denom}

    def train_D(self, images, batch_size):
        """Run one discriminator update.

        Args:
            images: Batch of real images.
            batch_size: Number of images in ``images``.

        Returns:
            The discriminator loss as a float, or 0 when ``batch_size``
            differs from the training batch size (the batch is skipped).
        """
        if batch_size != self.batch_size:
            return 0
        self.D.zero_grad()

        images = images.to(self.device)
        # Targets: real -> 0, fake -> 1. Created directly on the device.
        label_real = torch.full((batch_size,), 0.0, device=self.device)
        label_fake = torch.full((batch_size,), 1.0, device=self.device)

        d_out_real = self.D(images)

        # Use a previously generated batch from the ring buffer rather than
        # a freshly generated one.
        fake_images = self.old_fake_images[self.current_idx]
        d_out_fake = self.D(fake_images)

        d_loss_real = self.criterion(d_out_real.view(-1), label_real)
        d_loss_fake = self.criterion(d_out_fake.view(-1), label_fake)
        d_loss = d_loss_real + d_loss_fake

        d_loss.backward()
        self.d_optimizer.step()

        return d_loss.item()

    def train_G(self, batch_size):
        """Run one generator update and advance the ring buffer.

        Args:
            batch_size: Number of images to generate for this step.

        Returns:
            The generator loss as a float, or 0 when ``batch_size`` differs
            from the training batch size (the step is skipped).
        """
        if batch_size != self.batch_size:
            return 0
        self.G.zero_grad()

        # The generator wants its images to be scored like real ones (0).
        label_real = torch.full((batch_size,), 0.0, device=self.device)

        # Refill one buffer slot with a new gradient-free batch, then move
        # the read position used by train_D forward by one.
        write_idx = (self.current_idx + self.nRoll) % self.nBuffer
        self.old_fake_images[write_idx] = self.generate(batch_size)
        self.current_idx = (self.current_idx + 1) % self.nBuffer

        input_z = torch.randn(batch_size, self.z_dim, 1, 1).to(self.device)
        fake_images = self.G(input_z)
        d_out_fake = self.D(fake_images)

        g_loss = self.criterion(d_out_fake.view(-1), label_real)

        g_loss.backward()
        self.g_optimizer.step()

        return g_loss.item()

    def generate_fake_images(self, z):
        """Generate images from ``z`` and score them with the discriminator.

        Switches both models to eval mode and leaves them there.

        Args:
            z: Latent batch to feed to the generator.

        Returns:
            Tuple ``(fake_images, d_out)``: the generated images as a CPU
            tensor and the flattened discriminator outputs as a NumPy array.
        """
        self.D.eval()
        self.G.eval()

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            fake_images = self.G(z)
            d_out = self.D(fake_images).view(-1).to("cpu").detach().numpy()
        fake_images = fake_images.to("cpu").detach()
        return fake_images, d_out

# 5. 学習はこちら
データの場所、エポック数、学習率などはここで指定してください。

In [None]:
# Hyperparameters: mini-batch size, latent dimension, and the Adam settings
# (learning rates for G and D, shared betas) consumed by Train_model.initialize.
batch_size = 64
z_dim = 100
setting = {"g_lr": 1.0e-4, "d_lr": 5.0e-4, "beta1": 0.5, "beta2": 0.999}

# Use the first CUDA device when one is available, otherwise the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print (device)

# Training images live on the mounted Google Drive; GANImageDataset looks for
# PNG files one directory level below datadir.
datadir = "/content/drive/MyDrive/iiyama-lab2022/data/face/train"
#datadir = "/root/data/share/face/train"
dataset = GANImageDataset(datadir, transform=ImageTransform())
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=batch_size, shuffle=True, num_workers=2)

G = Generator(z_dim=z_dim)
D = Discriminator()

# Build the trainer, attach models/optimizers, and run 200 epochs.
trainer = Train_model(device)
trainer.initialize(G, D, z_dim, setting)
trainer.train(dataloader, num_epochs=200)