# 定義變分自編碼器（VAE）

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os

# 定義變分自編碼器（VAE）
class VAE(nn.Module):
    """Convolutional variational autoencoder for single-channel 28x28 images.

    The encoder applies two stride-2 convolutions (28 -> 14 -> 7 pixels per
    side) and two linear heads map the flattened 64x7x7 feature map to the
    mean and log-variance of a diagonal Gaussian over the latent space.
    The decoder mirrors this with a linear layer followed by two stride-2
    transposed convolutions and a final sigmoid, so reconstructions lie
    in [0, 1].

    Args:
        latent_dim: Size of the latent vector. Defaults to 256.

    Raises:
        ValueError: If ``latent_dim`` is smaller than 1.
    """

    # Shape of the encoder output (channels, height, width) for 28x28 inputs.
    _FEATURE_SHAPE = (64, 7, 7)

    def __init__(self, latent_dim=256):
        super().__init__()
        if latent_dim < 1:
            raise ValueError(f"latent_dim must be >= 1, got {latent_dim}")

        self.latent_dim = latent_dim
        channels, height, width = self._FEATURE_SHAPE
        self._flat_features = channels * height * width

        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, channels, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
        )
        self.fc_mu = nn.Linear(self._flat_features, latent_dim)
        self.fc_logvar = nn.Linear(self._flat_features, latent_dim)
        self.fc_decode = nn.Linear(latent_dim, self._flat_features)

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(channels, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        )

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for ``x``.

        Both outputs have shape ``(N, latent_dim)``; the leading dimension is
        inferred when the encoder features are flattened.
        """
        h = self.encoder(x).view(-1, self._flat_features)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Writing the sample this way keeps it differentiable with respect to
        ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors ``z`` of shape ``(N, latent_dim)`` to images."""
        h = self.fc_decode(z).view(-1, *self._FEATURE_SHAPE)
        return self.decoder(h)

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Returns:
            A tuple ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# 訓練變分自編碼器（VAE）

In [None]:
# 訓練 VAE
def train_vae(vae, dataloader, num_epochs=20, learning_rate=1e-3,
              save_path='vae_mnist.pth'):
    """Train a VAE with Adam and save its weights.

    The per-batch loss is the summed binary cross-entropy between the
    reconstruction and the input plus the KL divergence between the
    approximate posterior N(mu, exp(logvar)) and a standard normal prior.

    Args:
        vae: Module whose forward pass returns ``(reconstruction, mu, logvar)``.
        dataloader: Iterable yielding ``(images, _)`` batches; it must expose
            a ``dataset`` attribute supporting ``len()``, which is used to
            report the average loss per sample.
        num_epochs: Number of passes over ``dataloader``.
        learning_rate: Learning rate passed to Adam.
        save_path: Where the ``state_dict`` is written after training.
            Pass ``None`` to skip saving.
    """
    reconstruction_loss_fn = nn.BCELoss(reduction='sum')
    optimizer = optim.Adam(vae.parameters(), lr=learning_rate)

    # Use the model's own device rather than a module-level global, so the
    # batches always end up where the weights are.
    first_param = next(vae.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    vae.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        for images, _ in dataloader:
            images = images.to(device)

            # Forward pass
            reconstructed, mu, logvar = vae(images)
            reconstruction_loss = reconstruction_loss_fn(reconstructed, images)
            kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            loss = reconstruction_loss + kld_loss

            # Backward pass and parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Losses are summed (not averaged) per batch, so dividing by the
        # dataset size gives the mean loss per sample.
        avg_loss = running_loss / len(dataloader.dataset)
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}")

    # Save the trained weights
    if save_path is not None:
        torch.save(vae.state_dict(), save_path)
        print(f"Model saved as '{save_path}'")

# Create the model on the GPU when CUDA is available, otherwise on the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
vae = VAE()
vae = vae.to(device)

# Training data: the MNIST training split, converted to tensors
transform = transforms.Compose([transforms.ToTensor()])
dataset = datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
dataloader = DataLoader(dataset, shuffle=True, batch_size=64)

# Train the VAE
train_vae(vae, dataloader, num_epochs=10)

# Switch to evaluation mode for the cells that follow
vae.eval()

# 觀察資料集中的圖片


In [None]:
# Pick two images from the dataset and show them side by side.
# Indices are zero-based, so these are the 4th and 5th samples.
image1 = dataset[3][0]
image2 = dataset[4][0]
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for ax, img in zip(axes, (image1, image2)):
    ax.imshow(img.squeeze(), cmap='gray')
    ax.axis('off')
plt.show()

# 進行image1到image2的圖片漸變

In [None]:
with torch.no_grad():
    # Encode both images to their latent means. The dataset samples are CPU
    # tensors without a batch dimension, so add one and move them to the
    # model's device; otherwise encoding fails when the model is on the GPU.
    mu1, _ = vae.encode(image1.unsqueeze(0).to(device))
    mu2, _ = vae.encode(image2.unsqueeze(0).to(device))

    # Linearly interpolate between the two latent means and decode 10 frames
    fig, axes = plt.subplots(1, 10, figsize=(20, 2))
    for i, alpha in enumerate(torch.linspace(0, 1, steps=10).tolist()):
        z = mu1 * (1 - alpha) + mu2 * alpha
        generated_image = vae.decode(z).cpu().squeeze(0)
        # Drop the channel dimension so imshow receives a 2-D array
        axes[i].imshow(generated_image.squeeze(0), cmap='gray')
        axes[i].axis('off')

    plt.show()