# In[1]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torchvision.utils import make_grid

# Hyperparameters
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
img_size = 28    # input image height/width
channels = 1     # number of image channels
latent_dim = 20  # size of the VAE latent vector
batch_size = 64

# FashionMNIST training split. ToTensor() yields values in [0, 1], and
# Normalize(mean=0.5, std=0.5) rescales them to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)
dataset = datasets.FashionMNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Variational Autoencoder (VAE) model definition.
class VAE(nn.Module):
    """Convolutional variational autoencoder.

    The encoder applies two stride-2 convolutions (channels -> 32 -> 64) and maps
    the flattened features to the mean and log-variance of a diagonal Gaussian
    over a ``latent_dim``-dimensional latent space. The decoder mirrors this with
    two stride-2 transposed convolutions and ends in ``Tanh``, so reconstructions
    lie in [-1, 1].

    Args:
        img_size: Height and width of the square input images. Must be divisible
            by 4 so that the encoder's bottleneck matches ``img_size // 4`` and the
            decoder output matches the input size.
        channels: Number of image channels.
        latent_dim: Size of the latent vector.
    """

    def __init__(self, img_size, channels, latent_dim):
        super(VAE, self).__init__()
        self.img_size = img_size
        self.channels = channels
        self.latent_dim = latent_dim
        # Side length of the bottleneck feature map after two stride-2 convs.
        # Stored on the instance so decode() does not depend on any global.
        self.feature_size = img_size // 4
        flat_dim = 64 * self.feature_size * self.feature_size

        self.encoder = nn.Sequential(
            nn.Conv2d(channels, 32, kernel_size=3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), nn.ReLU()
        )
        self.fc_mu = nn.Linear(flat_dim, latent_dim)
        self.fc_logvar = nn.Linear(flat_dim, latent_dim)
        self.fc_decode = nn.Linear(latent_dim, flat_dim)
        # kernel_size=4, stride=2, padding=1 exactly doubles the spatial size.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(32, channels, kernel_size=4, stride=2, padding=1), nn.Tanh()
        )

    def encode(self, x):
        """Encode images ``x`` of shape (N, channels, img_size, img_size).

        Returns:
            Tuple ``(mu, logvar)``, each of shape (N, latent_dim).
        """
        h = self.encoder(x).flatten(start_dim=1)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Writing the sample this way keeps it differentiable w.r.t. ``mu`` and
        ``logvar``. Note that it always draws fresh noise, also in eval mode.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Decode latent vectors ``z`` of shape (N, latent_dim) into images.

        Returns:
            Tensor of shape (N, channels, 4 * feature_size, 4 * feature_size)
            with values in [-1, 1].
        """
        h = self.fc_decode(z).view(z.size(0), 64, self.feature_size, self.feature_size)
        return self.decoder(h)

    def forward(self, x):
        """Run encode -> sample -> decode.

        Returns:
            Tuple ``(x_recon, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decode(z)
        return x_recon, mu, logvar

# Loss function definition.
def vae_loss(recon_x, x, mu, logvar):
    """Return the summed VAE objective for a batch.

    Args:
        recon_x: Reconstructed inputs, same shape as ``x``.
        x: Original inputs.
        mu: Mean of the approximate posterior q(z|x).
        logvar: Log-variance of the approximate posterior q(z|x).

    Returns:
        Scalar tensor: summed squared reconstruction error plus the closed-form
        KL divergence KL(N(mu, exp(logvar)) || N(0, I)), summed over all elements.
    """
    reconstruction = nn.functional.mse_loss(recon_x, x, reduction="sum")
    # 0.5 * sum(mu^2 + sigma^2 - log sigma^2 - 1), the Gaussian-vs-standard-normal KL.
    kl = 0.5 * torch.sum(mu.pow(2) + logvar.exp() - logvar - 1)
    return reconstruction + kl

# Build the model on the selected device and attach an Adam optimizer.
vae = VAE(img_size=img_size, channels=channels, latent_dim=latent_dim).to(device)
optimizer = optim.Adam(params=vae.parameters(), lr=0.001)

# Training: minimize the summed VAE loss; report the mean loss per sample each epoch.
epochs = 20
num_samples = len(dataloader.dataset)
vae.train()
for epoch in range(epochs):
    train_loss = 0
    for imgs, _ in dataloader:
        imgs = imgs.to(device)
        optimizer.zero_grad()
        recon_imgs, mu, logvar = vae(imgs)
        loss = vae_loss(recon_imgs, imgs, mu, logvar)
        loss.backward()
        optimizer.step()
        # loss is summed over the batch, so dividing by num_samples below
        # gives a per-sample average.
        train_loss += loss.item()
    mean_loss = train_loss / num_samples
    print(f"Epoch {epoch + 1}/{epochs}, Loss: {mean_loss:.4f}")

# Simulate generation from noise to a clear image: interpolate in latent space
# from a random vector to the latent code of a T-shirt image.
vae.eval()
with torch.no_grad():
    # Use the first T-shirt image in the dataset as the target
    # (label 0 is the T-shirt class). Fail loudly instead of leaving
    # target_img unbound if no such image exists.
    target_img = next((img for img, label in dataset if label == 0), None)
    if target_img is None:
        raise RuntimeError("No T-shirt (label 0) image found in the dataset")
    target_img = target_img.unsqueeze(0).to(device)

    mu, _ = vae.encode(target_img)  # encode the target image
    z_start = torch.randn_like(mu)  # random starting latent vector
    # Target the posterior mean: reparameterize() would add fresh noise to the
    # endpoint, so the final "clear" frame would vary from run to run.
    z_target = mu

    # Interpolation weights shaped (steps, 1) so they broadcast over latent_dim.
    alphas = torch.linspace(0, 1, steps=20, device=device).unsqueeze(1)
    z_path = (1 - alphas) * z_start + alphas * z_target
    # Decode all steps in one batch, then split back into a list of
    # (1, C, H, W) tensors, one per interpolation step.
    generated_imgs = list(vae.decode(z_path).cpu().split(1))

# Stack the generated frames into a single grid image and display it.
frames = torch.cat(generated_imgs, dim=0)
grid = make_grid(frames, nrow=20, normalize=True)
grid = grid.permute(1, 2, 0)  # (C, H, W) -> (H, W, C) for matplotlib
plt.figure(figsize=(20, 5))
plt.imshow(grid.numpy())
plt.axis("off")
plt.show()


# [cell output] KeyboardInterrupt: