### Module23 DNN AutoEncoder

In [None]:
# permute(): reorder the dimensions of a tensor ===============================
# coding: utf-8
import torch

# A (2, 2, 3) tensor: dim 0 has size 2, dim 1 has size 2, dim 2 has size 3.
inputs = torch.tensor(
    [
        [[1, 2, 3], [4, 5, 6]],
        [[7, 8, 9], [10, 11, 12]],
    ]
)
print('inputs=\n', inputs)
print('Inputs.shape:', inputs.shape)  # torch.Size([2, 2, 3])

# permute() receives the existing dim indices in the order they should appear
# in the result: (0, 2, 1) keeps dim 0 in place and swaps dims 1 and 2.
outputs = inputs.permute(0, 2, 1)
print('outputs = \n', outputs)
print('Outputs.shape:', outputs.shape)  # torch.Size([2, 3, 2])

In [None]:
# view(): look at the same elements under a different shape ==============================
import torch

inputs = torch.tensor(
    [
        [[1, 2, 3], [4, 5, 6]],
        [[7, 8, 9], [10, 11, 12]],
    ]
)
print('inputs=\n', inputs)
print('Inputs.shape:', inputs.shape)

# Same 12 elements in the same order, regrouped as (2, 3, 2).
outputs = inputs.view(2, 3, 2)
print('outputs =\n', outputs)
print('Outputs.shape:', outputs.shape)

# A single -1 lets view() work out the size itself, so every element ends up
# in one flat dimension.
outputs = inputs.view(-1)
print(outputs)
print('Outputs:', outputs.shape)

In [None]:
# torch.cat(): join tensors along an existing dimension =======================
import torch

# 1-D case: there is only one dimension, so dim=0 and dim=-1 (the last
# dimension) name the same axis and give the same result.
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
ab = torch.cat((a, b), 0)
ba = torch.cat((b, a), 0)
print('ab:', ab)
print('ba:', ba)
#=========================
# 2-D case with two (1, 3) tensors: dim 0 stacks them as rows -> (2, 3),
# dim 1 joins them side by side -> (1, 6).
a = torch.tensor([[1, 2, 3]])
b = torch.tensor([[4, 5, 6]])
print('0:', torch.cat((a, b), 0))
print('1:', torch.cat((a, b), 1))

In [None]:
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision
# Training hyperparameters.
epochs = 10
batch_size = 128
lr = 0.008  # learning rate

# DataLoader
# MNIST training split stored under ./dataset/mnist (download=True is passed),
# with ToTensor() applied to every image.
train_set = torchvision.datasets.MNIST(
    root='./dataset/mnist',
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)
# Shuffled mini-batches of `batch_size` samples.
train_loader = data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
class AutoEncoder(nn.Module):
    """Fully connected autoencoder that squeezes 784 inputs into a 2-value code.

    The encoder narrows 784 -> 128 -> 64 -> 16 -> 2 and the decoder widens
    2 -> 16 -> 64 -> 128 -> 784. Tanh follows every Linear layer except the
    last one of each half; the decoder ends with a Sigmoid.
    """

    def __init__(self):
        super().__init__()
        # Encoder: the final Linear emits the code without an activation.
        self.encoder = nn.Sequential(
            *self._dense_stack([784, 128, 64, 16, 2])
        )
        # Decoder: mirror of the encoder, closed by a Sigmoid.
        self.decoder = nn.Sequential(
            *self._dense_stack([2, 16, 64, 128, 784]),
            nn.Sigmoid(),
        )

    @staticmethod
    def _dense_stack(widths):
        """Return Linear layers for consecutive `widths`, with Tanh in between."""
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.Tanh())
        # Drop the trailing Tanh so the last Linear stays un-activated.
        return layers[:-1]

    def forward(self, inputs):
        """Encode `inputs`, decode the result, and return ``(codes, decoded)``."""
        codes = self.encoder(inputs)
        return codes, self.decoder(codes)

In [None]:
# Model, loss function and optimizer
# Freshly initialised autoencoder to be trained.
model = AutoEncoder()
# Mean-squared-error criterion.
loss_function = nn.MSELoss()
# Adam over every parameter of the model, using the learning rate `lr`.
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

In [None]:
# Train: `epochs` passes over train_loader, scoring each reconstruction
# against its own input.
for epoch in range(epochs):
    for images, _labels in train_loader:
        # Flatten each image into a 784-long row; the labels are not used.
        inputs = images.view(-1, 784)
        # Clear the gradients accumulated by the previous step.
        optimizer.zero_grad()
        # Forward: only the reconstruction enters the loss.
        codes, decoded = model(inputs)
        loss = loss_function(decoded, inputs)
        # Backward and parameter update.
        loss.backward()
        optimizer.step()
    # Show progress (this is the loss of the epoch's last mini-batch).
    print('[{}/{}] Loss:'.format(epoch+1, epochs), loss.item())
# Save the whole model object (not only its state_dict).
torch.save(model, 'autoencoder.pth')

In [None]:
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision
import numpy as np
import matplotlib.pyplot as plt
# Settings: matplotlib defaults used by the plots below.
plt.rcParams.update({
    'figure.figsize': (10.0, 8.0),      # default figure size
    'image.interpolation': 'nearest',   # interpolation mode for imshow
    'image.cmap': 'gray',               # default colormap for images
})
# Show images
def show_images(images):
    """Draw every item of `images` as a 28x28 picture on a square subplot grid.

    Args:
        images: Iterable batch exposing ``shape[0]`` (the number of items);
            each item must support ``reshape(28, 28)``.

    The function only draws on the current matplotlib figure; it does not
    call ``plt.show()`` itself.
    """
    # Side length of the smallest square grid that holds every image.
    grid_side = int(np.ceil(np.sqrt(images.shape[0])))
    for position, flat_image in enumerate(images, start=1):
        plt.subplot(grid_side, grid_side, position)
        plt.imshow(flat_image.reshape(28, 28))
        plt.axis('off')
# Model structure
class AutoEncoder(nn.Module):
    """Fully connected autoencoder that squeezes 784 inputs into a 2-value code.

    The encoder narrows 784 -> 128 -> 64 -> 16 -> 2 and the decoder widens
    2 -> 16 -> 64 -> 128 -> 784. Tanh follows every Linear layer except the
    last one of each half; the decoder ends with a Sigmoid.
    """

    def __init__(self):
        super().__init__()
        # Encoder: the final Linear emits the code without an activation.
        self.encoder = nn.Sequential(
            *self._dense_stack([784, 128, 64, 16, 2])
        )
        # Decoder: mirror of the encoder, closed by a Sigmoid.
        self.decoder = nn.Sequential(
            *self._dense_stack([2, 16, 64, 128, 784]),
            nn.Sigmoid(),
        )

    @staticmethod
    def _dense_stack(widths):
        """Return Linear layers for consecutive `widths`, with Tanh in between."""
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.Tanh())
        # Drop the trailing Tanh so the last Linear stays un-activated.
        return layers[:-1]

    def forward(self, inputs):
        """Encode `inputs`, decode the result, and return ``(codes, decoded)``."""
        codes = self.encoder(inputs)
        return codes, self.decoder(codes)
# Load model
# NOTE(review): this unpickles a whole nn.Module object. Recent PyTorch
# releases are reported to default torch.load to weights_only=True, which may
# reject such files — confirm against the installed version.
model = torch.load('autoencoder.pth')
model.eval()
print(model)

# DataLoader
test_set = torchvision.datasets.MNIST(
    root='./dataset/mnist',
    train=False,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)
test_loader = data.DataLoader(test_set, batch_size=16, shuffle=False)

# Plot only the first batch: the originals first, then their reconstructions.
with torch.no_grad():
    for images, _labels in test_loader:
        inputs = images.view(-1, 28*28)
        show_images(inputs)
        plt.show()
        code, outputs = model(inputs)
        show_images(outputs)
        plt.show()
        break

### VAE : 受歡迎的 autoencoder

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
# To pin a specific GPU, uncomment: torch.cuda.set_device(1)
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
sample_dir = 'samples'        # directory that receives the generated images
# exist_ok=True creates the directory only when it is missing, without the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(sample_dir, exist_ok=True)

In [None]:
# Layer sizes: input vector length, hidden width, latent dimension.
image_size = 784
h_dim = 400
z_dim = 20
# Optimisation settings.
num_epochs = 15
batch_size = 128
learning_rate = 1e-3

# MNIST training split stored under ./dataset/mnist, converted with ToTensor().
dataset = torchvision.datasets.MNIST(
    root='./dataset/mnist',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
# Shuffled mini-batches of `batch_size` samples.
data_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
class VAE(nn.Module):   ## VAE model
    """Variational autoencoder with one hidden layer in each half.

    Args:
        image_size: Length of each flattened input vector.
        h_dim: Width of the hidden layer of the encoder and of the decoder.
        z_dim: Dimension of the latent vector.
    """

    def __init__(self, image_size=784, h_dim=400, z_dim=20):
        super().__init__()
        # Encoder ==============================
        self.fc1 = nn.Linear(image_size, h_dim)
        # fc2 and fc3 have the same shape: two parallel heads on the hidden
        # activations, one for the mean and one for the log-variance.
        self.fc2 = nn.Linear(h_dim, z_dim)
        self.fc3 = nn.Linear(h_dim, z_dim)
        # Decoder =============================
        self.fc4 = nn.Linear(z_dim, h_dim)
        self.fc5 = nn.Linear(h_dim, image_size)

    def encode(self, x):
        """Return ``(mu, log_var)`` computed from `x`."""
        h = F.relu(self.fc1(x))
        # Split into two branches: fc2 gives the mean, fc3 the log-variance
        # (reparameterize() turns the latter into a standard deviation).
        return self.fc2(h), self.fc3(h)

    def reparameterize(self, mu, log_var):
        """Sample ``mu + eps * std`` with ``eps`` drawn from a standard normal."""
        std = torch.exp(log_var / 2)  # exp(log_var / 2) == sqrt(exp(log_var))
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors `z` to outputs in the range (0, 1)."""
        h = F.relu(self.fc4(z))
        # torch.sigmoid replaces the deprecated F.sigmoid alias; same values.
        return torch.sigmoid(self.fc5(h))

    def forward(self, x):
        """Return ``(x_reconst, mu, log_var)`` for the input batch `x`."""
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_reconst = self.decode(z)
        return x_reconst, mu, log_var
# Build the VAE with its default sizes and move it onto `device`
# (nn.Module.to() converts the module in place and returns it).
model = VAE()
model.to(device)

In [None]:
# Train the VAE; after every epoch, write one image of random samples and one
# image of reconstructions into sample_dir.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for epoch in range(num_epochs):
    for i, (x, _) in enumerate(data_loader):
        # Move the batch to the device and flatten each image to one row.
        x = x.to(device).view(-1, image_size)
        x_reconst, mu, log_var = model(x)
        # page 23-19 ========================
        # 1 reconstruction loss: binary cross-entropy summed over the batch.
        # reduction='sum' is the current spelling of the deprecated
        # size_average=False and yields the same value.
        reconst_loss = F.binary_cross_entropy(x_reconst, x, reduction='sum')
        # 2 KL divergence loss, computed from mu and log_var.
        kl_div = - 0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        loss = reconst_loss + kl_div  # optimise the sum of 1 and 2
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if (i+1) % 100 == 0:
            print ("Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}" 
                   .format(epoch+1, num_epochs, i+1, len(data_loader), reconst_loss.item(), kl_div.item()))
    with torch.no_grad():
        # Decode random latent vectors into new images.
        z = torch.randn(batch_size, z_dim).to(device)
        out = model.decode(z).view(-1, 1, 28, 28)
        save_image(out, os.path.join(sample_dir, 'sampled-{}.png'.format(epoch+1)))
        # Reconstruct the last training batch of the epoch and place each
        # reconstruction to the right of its original (concatenation on dim 3).
        out, _, _ = model(x)
        x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)
        save_image(x_concat, os.path.join(sample_dir, 'reconst-{}.png'.format(epoch+1)))

In [None]:
import matplotlib.pyplot as plt 
import matplotlib.image as mpimg 
import numpy as np
# Display the two images saved for epoch 15: the reconstructions first, then
# the randomly generated samples.
reconsPath = './samples/reconst-15.png'
genPath = './samples/sampled-15.png'
for png_path in (reconsPath, genPath):
    Image = mpimg.imread(png_path)
    plt.imshow(Image)
    plt.axis('off')
    plt.show()