In [15]:
import pandas as pd 
import numpy as np 
import os 
import time 
import gc 
import random 
import warnings 
from pprint import pprint 
from PIL import Image 
import cv2 
from tqdm import tqdm 

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from torch.utils.data import DataLoader, Dataset 

# アニメ画像の生成タスク

ランダムなノイズベクトル（潜在変数 $z$）からアニメキャラクターの顔画像を生成するモデルを学習させる。  
モデルには `DCGAN`（Deep Convolutional GAN）を使用する。


In [16]:
# Show GPU status before training; `config["device"]` below falls back to CPU when CUDA is unavailable.
!nvidia-smi

In [17]:
def random_seed(SEED):
    """Seed every RNG used in this notebook so runs are reproducible.

    Args:
        SEED: integer seed applied to Python's ``random``, NumPy and PyTorch
            (CPU generator plus the current and all CUDA devices).
    """
    random.seed(SEED)
    # NOTE: PYTHONHASHSEED is read at interpreter start-up, so this does not change
    # hashing in the already-running kernel; it only propagates to new interpreters.
    os.environ['PYTHONHASHSEED'] = str(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick convolution algorithms
    # non-deterministically, which defeats `deterministic = True`; keep it off.
    torch.backends.cudnn.benchmark = False
    
    
# Hyper-parameters and paths for the whole notebook.
config = {
    "model_name": "dcgan",
    "device": "cuda:0" if torch.cuda.is_available() else "cpu",
    "batch_size": 128,
    "img_size": 64,     # images are resized + center-cropped to img_size x img_size
    "z_fill": 100,      # length of the latent vector fed to the generator
    "n_channel": 3,     # image channels (RGB)
    "mid_size": 64,     # base feature-map width of G and D (ngf / ndf)
    "epochs": 50,
    "lr": 0.0002,
    "beta1": 0.5,       # Adam beta1
    "debug": False,     # True -> train for only 2 epochs
    "seed": 42,         # global RNG seed passed to random_seed()
}


ROOT_PATH = "../input/animefacedataset/images"
# NOTE(review): this silences *all* warnings for the session, which can hide real problems.
warnings.simplefilter("ignore")
# Seed every RNG before any weights or noise are drawn so the notebook is reproducible.
random_seed(config["seed"])
pprint(config)

In [18]:
# Preview up to 64 dataset images (in directory-listing order) as an 8x8 grid.
fig, axes = plt.subplots(8, 8, figsize=(8, 8))
ax = axes.ravel()

# zip stops at whichever runs out first: 64 axes or the file list.
for cell_ax, path in zip(ax, os.listdir(ROOT_PATH)):
    img_file_path = os.path.join(ROOT_PATH, path)
    # imshow copies the pixels into an array, so the file can be closed right away.
    with Image.open(img_file_path) as img:
        cell_ax.imshow(img)

# Strip ticks from every axis, including any left empty when there are < 64 files.
for cell_ax in ax:
    cell_ax.set_xticks([])
    cell_ax.set_yticks([])

plt.tight_layout()

In [19]:
class AnimeDataset(Dataset):
    """Unlabelled image dataset over every file in ``root_path``.

    Each item is one image tensor: resized and center-cropped to ``img_size``
    (``config["img_size"]``), converted to ``[0, 1]`` by ``ToTensor`` and then
    normalised with mean 0.5 / std 0.5 per channel, i.e. mapped to ``[-1, 1]``
    (the same range as the generator's Tanh output).
    """

    # Used when no config is supplied; matches the notebook's default img_size.
    DEFAULT_IMG_SIZE = 64

    def __init__(self, root_path: str=ROOT_PATH, config=None):
        """
        Args:
            root_path: directory containing the image files.
            config: mapping providing ``"img_size"``; ``None`` falls back to
                ``DEFAULT_IMG_SIZE`` instead of crashing.
        """
        self.root_path = root_path
        # Sorted so the index -> file mapping does not depend on filesystem order.
        self.img_files = sorted(os.listdir(self.root_path))
        img_size = (config or {}).get("img_size", self.DEFAULT_IMG_SIZE)
        self.transform = transforms.Compose([
            transforms.Resize(img_size),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def decode(self, img_file):
        """Load one image file and return it as a transformed tensor."""
        with Image.open(img_file) as img:
            # convert("RGB") guarantees the 3 channels Normalize expects (grayscale /
            # RGBA / palette files would otherwise fail) and loads the pixels
            # before the file is closed.
            return self.transform(img.convert("RGB"))

    def __getitem__(self, idx):
        """Return the transformed image at position ``idx`` (no label: GAN training is unsupervised)."""
        img_file = self.img_files[idx]
        return self.decode(os.path.join(self.root_path, img_file))

    def __len__(self):
        """Number of files found in ``root_path``."""
        return len(self.img_files)
    

In [20]:
def weights_init(m):
    """DCGAN weight initialisation, meant to be passed to ``nn.Module.apply``.

    - Any module whose class name contains ``Conv`` (Conv2d, ConvTranspose2d, ...)
      gets weights drawn from N(0, 0.02).
    - Any module whose class name contains ``BatchNorm`` gets weights from
      N(1, 0.02) and a zero bias.
    - Every other module is left untouched.
    """
    layer_kind = m.__class__.__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
        

In [21]:
class Generator(nn.Module):
    """DCGAN generator: upsamples a latent tensor to an image with transposed convolutions.

    Intended input: ``(batch, config["z_fill"], 1, 1)``; the stack then grows the
    spatial size 1 -> 4 -> 8 -> 16 -> 32 -> 64 and ends in Tanh, so the output is
    ``(batch, config["n_channel"], 64, 64)`` with values in [-1, 1].
    ``config["mid_size"]`` sets the base feature-map width (the classic ``ngf``).
    """

    def __init__(self, config):
        super().__init__()
        width = config["mid_size"]

        def up_block(c_in, c_out, stride, padding):
            # kernel-4 transposed conv -> BatchNorm -> ReLU (no conv bias; BN adds the shift)
            return [
                nn.ConvTranspose2d(c_in, c_out, 4, stride, padding, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]

        layers = (
            up_block(config["z_fill"], width * 8, 1, 0)   # -> (width*8) x 4 x 4
            + up_block(width * 8, width * 4, 2, 1)        # -> (width*4) x 8 x 8
            + up_block(width * 4, width * 2, 2, 1)        # -> (width*2) x 16 x 16
            + up_block(width * 2, width, 2, 1)            # -> width x 32 x 32
            + [
                nn.ConvTranspose2d(width, config["n_channel"], 4, 2, 1, bias=False),
                nn.Tanh(),                                # -> n_channel x 64 x 64
            ]
        )
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)

In [22]:
# Smoke-test the generator: sample two latent vectors and show the first generated image.
G = Generator(config)
G.apply(weights_init)

# Latent vectors need config["z_fill"] channels and should follow N(0, 1), the same
# distribution the training cells use for `noise` / `fixed_noise` (torch.randn).
a = torch.randn(2, config["z_fill"], 1, 1)
y = G(a)

# The Tanh output lies in [-1, 1]; rescale to [0, 1] so imshow does not clip it.
first_img = (y[0].detach().cpu().permute(1, 2, 0) + 1) / 2
plt.imshow(first_img)
plt.axis("off")
plt.show()

In [23]:
class Discriminator(nn.Module):
    """DCGAN discriminator: scores images with a stack of strided convolutions.

    Intended input: ``(batch, config["n_channel"], 64, 64)``. Four stride-2 convs
    shrink the spatial size 64 -> 32 -> 16 -> 8 -> 4, and a final 4x4 conv plus
    Sigmoid gives ``(batch, 1, 1, 1)`` probabilities that the input is real.
    ``config["mid_size"]`` sets the base feature-map width (the classic ``ndf``).
    """

    def __init__(self, config):
        super().__init__()
        width = config["mid_size"]

        def down_block(c_in, c_out, norm=True):
            # kernel-4 / stride-2 / pad-1 conv halves H and W; optional BN; LeakyReLU(0.2)
            block = [nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False)]
            if norm:
                block.append(nn.BatchNorm2d(c_out))
            block.append(nn.LeakyReLU(0.2, inplace=True))
            return block

        layers = [
            *down_block(config["n_channel"], width, norm=False),  # -> width x 32 x 32
            *down_block(width, width * 2),                        # -> (width*2) x 16 x 16
            *down_block(width * 2, width * 4),                    # -> (width*4) x 8 x 8
            *down_block(width * 4, width * 8),                    # -> (width*8) x 4 x 4
            nn.Conv2d(width * 8, 1, 4, 1, 0, bias=False),         # -> 1 x 1 x 1
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)

In [24]:
# Build and initialise the discriminator, then check which shape it returns for the
# generator preview batch `y`.
D = Discriminator(config)
D.apply(weights_init)

print("input shape: {}".format(y.shape))
with torch.no_grad():
    yy = D(y)
print("output shape: {}".format(yy.shape))

### 訓練する
GAN の学習には正解ラベルが不要なため、検証用データは分けずにすべての画像を学習に使用する。

In [25]:
# Loss, labels, optimisers and data pipeline for the standard DCGAN objective.
criterion = nn.BCELoss()

real_label = 1
fake_label = 0

# A fixed latent batch, reused throughout training to visualise G's progress on the same inputs.
fixed_noise = torch.randn(64, config["z_fill"], 1, 1, device=config["device"])

# Move the models to the target device *before* constructing their optimisers, as the
# torch.optim documentation recommends; the training cell's own `.to()` calls are then harmless.
G.to(config["device"])
D.to(config["device"])

optimG = optim.Adam(G.parameters(), lr=config["lr"], betas=(config["beta1"], 0.999))
optimD = optim.Adam(D.parameters(), lr=config["lr"], betas=(config["beta1"], 0.999))

dataset = AnimeDataset(ROOT_PATH, config=config)
dataloader = DataLoader(
    dataset,
    batch_size=config["batch_size"],
    shuffle=True,
    num_workers=2,
    # Page-locked host memory speeds up host -> GPU copies; it is useless on CPU.
    pin_memory=config["device"].startswith("cuda"),
)


In [26]:
# Sanity-check one sample: a (n_channel, img_size, img_size) tensor normalised to [-1, 1].
sample_img = dataset[9]
assert tuple(sample_img.shape) == (config["n_channel"], config["img_size"], config["img_size"])
assert -1.0 <= sample_img.min().item() and sample_img.max().item() <= 1.0
tuple(sample_img.shape), sample_img.min().item(), sample_img.max().item()

In [27]:
# DCGAN training: each batch does one discriminator step, then one generator step.
img_list = []   # image grids generated from `fixed_noise`, used for the progress animation
G_losses = []   # generator loss, one entry per iteration
D_losses = []   # discriminator loss (real + fake), one entry per iteration
iters = 0

# Epochs actually run (debug mode shortens training); used for logging and the final snapshot.
n_epochs = 2 if config["debug"] else config["epochs"]
n_batches = len(dataloader)

G.to(config["device"])
D.to(config["device"])

G.train()
D.train()


for e in range(n_epochs):

    # Wrapping the dataloader itself lets tqdm know the total number of batches.
    for i, img in enumerate(tqdm(dataloader, desc=f"epoch {e + 1}/{n_epochs}")):
        real_img = img.to(config["device"])
        bs = real_img.size(0)

        # --- Discriminator step: classify real images as 1 and fake images as 0 ---
        D.zero_grad()

        label = torch.full((bs, ), real_label, device=config["device"], dtype=torch.float)
        output = D(real_img).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()

        noise = torch.randn(bs, config["z_fill"], 1, 1, device=config["device"])
        fake_img = G(noise)
        label.fill_(fake_label)

        # detach(): the discriminator step must not propagate gradients into G.
        output = D(fake_img.detach()).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()

        errD = errD_real + errD_fake
        optimD.step()

        # --- Generator step: push D to classify the fakes as real (1) ---
        G.zero_grad()
        label.fill_(real_label)

        # Re-score the fakes with the just-updated D; gradients now flow back into G.
        output = D(fake_img).view(-1)
        errG = criterion(output, label)
        errG.backward()
        optimG.step()

        # Record losses on every iteration so the loss plot really is per-iteration.
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        if i % 1000 == 0:
            # tqdm.write keeps the progress bar intact while printing.
            tqdm.write('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f'
                       % (e, n_epochs, i, n_batches, D_losses[-1], G_losses[-1]))

        # Snapshot G on the fixed noise every 250 iterations and on the very last one.
        if (iters % 250 == 0) or ((e == n_epochs - 1) and (i == n_batches - 1)):
            with torch.no_grad():
                fixed_fake = G(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fixed_fake, padding=2, normalize=True))

        iters += 1

### Save model weights


In [30]:
# Persist both networks' weights. The file name records the number of epochs actually
# trained (debug mode runs only 2, mirroring the training loop's epoch count).
os.makedirs("models", exist_ok=True)

trained_epochs = 2 if config["debug"] else config["epochs"]

torch.save(G.state_dict(), f"models/anime_G_{trained_epochs}.pth")

torch.save(D.state_dict(), f"models/anime_D_{trained_epochs}.pth")

In [31]:
# Plot the recorded generator and discriminator losses on one set of axes.
loss_fig, loss_ax = plt.subplots(figsize=(10, 5))
loss_ax.set_title("Generator and Discriminator Loss During Training")
loss_ax.plot(G_losses, label="G")
loss_ax.plot(D_losses, label="D")
loss_ax.set_xlabel("iterations")
loss_ax.set_ylabel("Loss")
loss_ax.legend()
plt.show()

In [32]:
# Animate the fixed-noise grids collected during training (one frame per snapshot).
fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(grid, (1, 2, 0)), animated=True)] for grid in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
# Close the figure so the inline backend does not also render a static copy of the last frame.
plt.close(fig)

HTML(ani.to_jshtml())

### Final Results 

In [33]:
# Side-by-side comparison: a batch of real training images vs. the generator's output on
# `fixed_noise` from the last snapshot taken during training.
real_batch = next(iter(dataloader))

fig, (ax_real, ax_fake) = plt.subplots(1, 2, figsize=(15, 15))

# Real images: make_grid works on CPU tensors, so no host -> GPU -> host round trip is needed.
real_grid = vutils.make_grid(real_batch[:64], padding=5, normalize=True)
ax_real.axis("off")
ax_real.set_title("Real Images")
ax_real.imshow(np.transpose(real_grid, (1, 2, 0)))

# Fake images from the final snapshot
ax_fake.axis("off")
ax_fake.set_title("Fake Images")
ax_fake.imshow(np.transpose(img_list[-1], (1, 2, 0)))
plt.show()