# VAE 
- ### VAE(Variational Autoencoder)とは?
####  ディープラーニングを使用した生成モデルの一つです。具体的には、学習したデータに**近似**したデータを生成できます。


### 1. 必要なライブラリのインストール&インポート




In [None]:
!pip install torchinfo

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchinfo
  Downloading torchinfo-1.7.0-py3-none-any.whl (22 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.7.0


In [None]:
import os
import cv2
import torch
import datetime
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import torch.optim as optim
from torch import nn
import matplotlib.pyplot as plt
from torchinfo import summary

from torch.utils.data import Dataset
from pathlib import Path
from typing import List, Tuple

from tqdm.notebook import tqdm

# Run on the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

cuda:0


### 2. 学習、テストデータの読み込み

In [None]:
from pathlib import Path

from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image

class ImageFolder(Dataset):
    """Unlabelled dataset over the image files found directly inside a directory.

    Only the top level of the directory is scanned (no recursion). Indexing
    returns just the image (after the optional transform); there are no labels.
    """

    # Recognised image suffixes, written in lower case. Matching against file
    # names is case-insensitive, so "photo.JPG" is picked up as well.
    IMG_EXTENSIONS = [".jpg", ".jpeg", ".png", ".bmp"]

    def __init__(self, img_dir, transform=None):
        """Collect the image paths under ``img_dir``.

        Args:
            img_dir: Directory (string or path-like) holding the image files.
            transform: Optional callable applied to each image returned by
                ``Image.open`` before it is handed back to the caller.
        """
        # Gather the list of image file paths once, up front.
        self.img_paths = self._get_img_paths(img_dir)
        self.transform = transform

    def __getitem__(self, index):
        """Open the image at ``index`` and return it, transformed if configured."""
        path = self.img_paths[index]

        # Load the image.
        img = Image.open(path)

        if self.transform is not None:
            # Apply the preprocessing when one was supplied.
            img = self.transform(img)

        return img

    def _get_img_paths(self, img_dir):
        """Return the sorted list of image file paths directly inside ``img_dir``.

        A path qualifies when it is a regular file and its suffix, compared
        case-insensitively, is one of ``IMG_EXTENSIONS``. Directories are
        skipped even if their name ends in an image suffix, because they
        cannot be opened as images.
        """
        img_dir = Path(img_dir)
        return sorted(
            p
            for p in img_dir.iterdir()
            if p.is_file() and p.suffix.lower() in ImageFolder.IMG_EXTENSIONS
        )

    def __len__(self):
        """Return the number of image files that were found."""
        return len(self.img_paths)

In [None]:
# Mini-batch size shared by the training and test loaders.
batch_size = 1000

# Preprocessing pipeline: ToTensor first, then Grayscale on the resulting tensor.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Grayscale(),
])

# Both datasets read image files from directories relative to the working directory.
train = ImageFolder("train", transform)
test = ImageFolder("test", transform)

# Only the training loader shuffles; the test loader keeps the sorted file order.
train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test, batch_size=batch_size, shuffle=False)

### 3. VAEの学習モデル

In [None]:
class VAE(nn.Module):
    """Convolutional variational autoencoder for single-channel 48x48 images.

    Encoder: three unpadded 3x3 convolutions (48 -> 46 -> 44 -> 42 pixels),
    a 2x2 max-pool (42 -> 21), then two linear heads producing the mean and
    log-variance of the latent Gaussian.

    Decoder: a linear layer back to a 16x21x21 feature map, nearest-neighbour
    upsampling (21 -> 42), then three 3x3 transposed convolutions
    (42 -> 44 -> 46 -> 48) ending in a sigmoid, so outputs lie in [0, 1].

    The layer attribute names (``conv5``, ``conv6``, ...) are kept as they are
    because they are the keys of the saved ``state_dict`` checkpoints.

    Args:
        latent_dim: Size of the latent vector. The default of 20 matches the
            original fixed architecture, so existing checkpoints still load.
    """

    # (channels, height, width) of the encoder feature map after pooling a
    # 48x48 input; the linear layers are sized for exactly this shape.
    FEATURE_SHAPE = (16, 21, 21)

    def __init__(self, latent_dim: int = 20):
        super().__init__()

        channels, height, width = self.FEATURE_SHAPE
        self._flat_features = channels * height * width
        self.latent_dim = latent_dim

        self.conv1 = nn.Conv2d(1, 48, 3)
        self.conv2 = nn.Conv2d(48, 32, 3)
        self.conv5 = nn.Conv2d(32, 16, 3)
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1_mu = nn.Linear(self._flat_features, latent_dim)
        self.fc1_sig = nn.Linear(self._flat_features, latent_dim)
        self.fc2 = nn.Linear(latent_dim, self._flat_features)
        self.up_sample = nn.UpsamplingNearest2d(scale_factor=2)
        self.conv6 = nn.ConvTranspose2d(16, 32, 3)
        self.conv3 = nn.ConvTranspose2d(32, 48, 3)
        self.conv4 = nn.ConvTranspose2d(48, 1, 3)

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Map images to the parameters of the approximate posterior.

        Args:
            x: Image tensor whose trailing dimensions are (1, 48, 48).

        Returns:
            ``(mu, logvar)``, each with ``latent_dim`` columns.

        Raises:
            RuntimeError: If the pooled feature map is not 16x21x21, i.e. the
                input images are not 48x48.
        """
        a1 = F.relu(self.conv1(x))
        a2 = F.relu(self.conv2(a1))
        a3 = F.relu(self.conv5(a2))
        pooled = self.max_pool(a3)

        # Without this check, reshape(-1, N) would silently regroup pixels of a
        # wrongly-sized image into several fake samples whenever the element
        # count happens to be a multiple of N (e.g. a single 90x90 input).
        if tuple(pooled.shape[-3:]) != self.FEATURE_SHAPE:
            raise RuntimeError(
                "VAE.encode expects 48x48 inputs producing a feature map of "
                "shape {}, got {}".format(
                    self.FEATURE_SHAPE, tuple(pooled.shape[-3:])
                )
            )

        flat = pooled.reshape(-1, self._flat_features)
        a_mu = self.fc1_mu(flat)
        a_logvar = self.fc1_sig(flat)
        return a_mu, a_logvar

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Map latent vectors to reconstructed images with values in [0, 1].

        Args:
            z: Latent tensor with ``latent_dim`` columns.

        Returns:
            Tensor of shape (batch, 1, 48, 48).
        """
        a3 = F.relu(self.fc2(z))
        a3 = a3.reshape(-1, *self.FEATURE_SHAPE)
        a3_upsample = self.up_sample(a3)
        a4 = F.relu(self.conv6(a3_upsample))
        a5 = F.relu(self.conv3(a4))
        return torch.sigmoid(self.conv4(a5))

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Sample ``z = mu + eps * std`` with ``eps`` drawn from a standard normal.

        ``std`` is ``exp(0.5 * logvar)``. Writing the sample this way keeps it
        differentiable with respect to ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode, sample a latent vector, decode.

        Returns:
            ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
    

In [None]:
# Build the VAE on the selected device and attach an Adam optimizer to its parameters.
model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Per-layer output sizes and parameter counts for an input of size (10, 1, 48, 48).
summary(
    model,
    input_size=(10, 1, 48, 48),
    col_names=["output_size", "num_params"],
)

### 4. 損失関数

In [None]:
def loss_function(recon_x, x, mu, logvar):
    """Return the VAE training loss: reconstruction error plus KL divergence.

    Args:
        recon_x: Reconstructed images, values in [0, 1].
        x: Target images, same shape as ``recon_x``.
        mu: Mean of the latent Gaussian.
        logvar: Log-variance of the latent Gaussian.

    Returns:
        A scalar tensor. Both terms are summed over every element, so the
        value grows with the batch size.
    """
    # Reconstruction term: binary cross-entropy summed over all pixels and samples.
    reconstruction_term = F.binary_cross_entropy(recon_x, x, reduction='sum')

    # Closed-form KL divergence between N(mu, exp(logvar)) and N(0, I).
    kl_inner = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * torch.sum(kl_inner)

    return reconstruction_term + kl_term

### 5. VAE学習

In [None]:
num_epochs = 2000  # number of training epochs
print_per = 100  # report the running loss every `print_per` mini-batches
model.train()

# Checkpoints are written under "model/". torch.save does not create missing
# parent directories, so make sure the directory exists before training starts
# instead of failing at the first save.
os.makedirs("model", exist_ok=True)

# Japan Standard Time (UTC+9), used to timestamp the checkpoint file names.
t_delta = datetime.timedelta(hours=9)
JST = datetime.timezone(t_delta, 'JST')

# Running-loss values reported during training. Kept outside the epoch loop so
# the history covers the whole run instead of being reset every epoch.
loss_record = []

for epoch in range(num_epochs):
    train_loss = 0  # summed loss over the whole epoch
    print_loss = 0  # summed loss since the last progress report

    # Timestamp taken at the start of each epoch; part of the checkpoint name.
    now = datetime.datetime.now(JST)
    d = now.strftime('%Y%m%d%H%M%S')

    for i, images in enumerate(train_loader):

        optimizer.zero_grad()
        images = images.to(device)

        recon_batch, mu, logvar = model(images)
        loss = loss_function(recon_batch, images, mu, logvar)
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for both accumulators.
        batch_loss = loss.item()
        train_loss += batch_loss
        print_loss += batch_loss

        if i % print_per == 0:
            print("Epoch : {} , Minibatch : {} Loss = {:.4f}".format(epoch+1, i, print_loss))
            loss_record.append(print_loss)
            print_loss = 0

    # Save a checkpoint every 10th epoch, starting with the first one.
    if epoch % 10 == 0:
        torch.save(model.state_dict(), "model/{}_{}_.pth".format(d, epoch))

    print("Epoch {} : Loss = ({:.4f}) ".format(epoch+1, train_loss))

### 6. 学習、テストデータから画像生成

In [None]:
# Walk the whole test loader, moving each batch to the device; once the loop
# ends, `img` holds the last test batch.
for _, img in tqdm(enumerate(test_loader)):
    img = img.to(device)

# Take batches from the training loader and stop right after the second one
# (index 1), so `images` ends up holding that batch on the device.
for batch_idx, images in tqdm(enumerate(train_loader)):
    images = images.to(device)
    if batch_idx == 1:
        break

In [None]:
fig = plt.figure(figsize=(12, 12))

columns = 1
rows = 3

# Create the three panels once and redraw them for each sample. Calling
# add_subplot inside the loop would stack a new set of axes on the figure at
# every iteration.
ax_input, ax_recon, ax_diff = fig.subplots(rows, columns)

print(images[0].shape)

# Visualise at most the first 10 samples (fewer if the batch is smaller).
samples = images[:10]

# The VAE draws a fresh random latent sample on every forward pass, so run the
# model exactly once: the reconstruction that is displayed is then the same one
# used for the difference image. no_grad avoids building an autograd graph.
with torch.no_grad():
    reconstructions = model(samples)[0]

originals_np = samples.cpu().numpy()
reconstructions_np = reconstructions.cpu().numpy()

for i in range(len(originals_np)):
    original = originals_np[i].reshape(48, 48)
    reconstruction = reconstructions_np[i].reshape(48, 48)
    difference = original - reconstruction

    # Top: input image, middle: reconstruction, bottom: input minus reconstruction.
    for ax, panel in zip((ax_input, ax_recon, ax_diff),
                         (original, reconstruction, difference)):
        ax.clear()
        ax.imshow(panel, cmap='gray')

    fig.savefig("nomal_{}".format(i))

plt.show(block=True)
    