<a href="https://colab.research.google.com/github/chris-william0829/vae-mnist/blob/main/VAE_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
import os.path

In [3]:
# VAE model definition
class VAE(nn.Module):
    """Fully connected variational autoencoder for flattened inputs.

    The encoder maps each input vector to ``2 * latent_dim`` values, split into
    the posterior mean ``mu`` and log-variance ``logvar``. The decoder maps a
    latent vector back to ``input_dim`` values passed through a sigmoid, so
    reconstructions lie in [0, 1].

    Args:
        input_dim: length of each flattened input vector.
        hidden_dim: width of the single hidden layer in encoder and decoder.
        latent_dim: dimensionality of the latent space.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        # Emits mu and logvar concatenated along the last dimension.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim * 2)
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()
        )

    def encode(self, x):
        """Run the encoder and split its output into ``(mu, logvar)``.

        Slices on ``self.latent_dim`` (not a module-level variable), so the
        split always matches how this instance was constructed. The ``...``
        indexing handles both batched ``(batch, input_dim)`` and unbatched
        ``(input_dim,)`` inputs.
        """
        h = self.encoder(x)
        return h[..., :self.latent_dim], h[..., self.latent_dim:]

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + std * eps`` with ``eps ~ N(0, I)``.

        Expressing the sample as a deterministic function of (mu, logvar) plus
        independent noise keeps it differentiable with respect to both.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample a latent vector, and decode it.

        Returns:
            (reconstruction, mu, logvar)
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decoder(z), mu, logvar

In [4]:
# Loss function
def loss_function(recon_x, x, mu, logvar, beta=1.0):
    """VAE training loss (negative ELBO up to constants), summed over the batch.

    Args:
        recon_x: decoder output, same shape as ``x``.
        x: original inputs.
        mu: posterior means from the encoder.
        logvar: posterior log-variances from the encoder.
        beta: weight on the KL term; the default 1.0 gives the standard VAE
            loss, other values give a beta-VAE style trade-off.

    Returns:
        Scalar tensor: sum-of-squared-errors reconstruction term plus
        ``beta`` times KL(N(mu, exp(logvar)) || N(0, I)).
    """
    # Sum-of-squares reconstruction error (binary cross-entropy is a common
    # alternative here since the decoder output is sigmoid-bounded).
    recon_loss = nn.functional.mse_loss(recon_x, x, reduction='sum')
    # Closed-form KL divergence between a diagonal Gaussian and N(0, I),
    # summed over latent dimensions and the batch.
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + beta * kld

In [5]:
# Hyperparameters
batch_size = 128
epochs = 10
learning_rate = 1e-3
input_dim = 784      # 28 x 28 MNIST images, flattened
hidden_dim = 400
latent_dim = 20
# Seed torch's RNG (weight init, DataLoader shuffling, reparameterization noise)
# so re-runs are repeatable; some CUDA kernels may still be nondeterministic.
seed = 0
torch.manual_seed(seed)

In [6]:
# Data loading: MNIST images converted to tensors and flattened to 1-D vectors
transform = transforms.Compose([
    transforms.ToTensor(),                        # PIL image -> float tensor scaled to [0, 1]
    transforms.Lambda(lambda img: img.view(-1)),  # flatten to a single vector
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# Shuffle only the training data; keep test order fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [7]:
# Build the model and optimizer, preferring the GPU when one is available
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = VAE(input_dim=input_dim, hidden_dim=hidden_dim, latent_dim=latent_dim)
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [8]:
# VAE training: reuse a cached checkpoint if present, otherwise train and save one
checkpoint_path = 'vae.pth'
if os.path.exists(checkpoint_path):
    # map_location remaps stored tensors onto the current device, so a
    # checkpoint saved on a GPU still loads on a CPU-only machine.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
else:
    model.train()
    for epoch in range(epochs):
        train_loss = 0.0
        for data, _ in train_loader:
            data = data.to(device)
            optimizer.zero_grad()
            recon_batch, mu, logvar = model(data)
            loss = loss_function(recon_batch, data, mu, logvar)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        # loss_function sums over the batch, so dividing by the dataset size
        # reports the average loss per sample.
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {train_loss / len(train_loader.dataset)}")
    torch.save(model.state_dict(), checkpoint_path)

Epoch 1/10, Loss: 45.714595186360675
Epoch 2/10, Loss: 35.155865209960936
Epoch 3/10, Loss: 33.19220285644531
Epoch 4/10, Loss: 32.29306981608073
Epoch 5/10, Loss: 31.72323212483724
Epoch 6/10, Loss: 31.379750838216147
Epoch 7/10, Loss: 31.073921618652342
Epoch 8/10, Loss: 30.854611385091147
Epoch 9/10, Loss: 30.673968111165365
Epoch 10/10, Loss: 30.51398776855469


In [9]:
# Extract latent features (encoder means) for the training set
def extract_latent_means(model, loader, device):
    """Encode every batch from ``loader`` and collect posterior means and labels.

    The model is switched to eval mode. Only the mean half of the encoder
    output is kept (no sampling), so the features are deterministic for a
    fixed model.

    Args:
        model: a VAE exposing ``encoder`` and ``latent_dim``.
        loader: iterable of ``(inputs, labels)`` batches.
        device: device the inputs are moved to before encoding.

    Returns:
        ``(features, labels)`` numpy arrays concatenated over all batches,
        in loader order.
    """
    model.eval()
    features, labels = [], []
    # A single no_grad context for the whole pass instead of one per batch.
    with torch.no_grad():
        for data, target in loader:
            h = model.encoder(data.to(device))
            # The first latent_dim columns are mu; logvar is not needed here.
            features.append(h[:, :model.latent_dim].cpu().numpy())
            labels.append(target.numpy())
    return np.concatenate(features, axis=0), np.concatenate(labels, axis=0)


train_hidden_vars, train_labels = extract_latent_means(model, train_loader, device)

In [10]:
# Extract latent features for the test set: the encoder's mean vector for each
# test image, paired with its label, in loader order.
test_batches = []
for images, digits in test_loader:
    images = images.to(device)
    with torch.no_grad():
        encoded = model.encoder(images)
    test_batches.append((encoded[:, :latent_dim].cpu().numpy(), digits.numpy()))

test_hidden_vars = np.concatenate([feats for feats, _ in test_batches], axis=0)
test_labels = np.concatenate([labs for _, labs in test_batches], axis=0)

In [11]:
# Classify digits from the latent means with a default (RBF-kernel) SVC;
# fit() returns the estimator itself, so it can be assigned directly.
svc = SVC().fit(X=train_hidden_vars, y=train_labels)

In [12]:
# Training accuracy: how well the SVC fits the features it was trained on
train_preds = svc.predict(X=train_hidden_vars)
train_acc = accuracy_score(y_true=train_labels, y_pred=train_preds)
print(f"Training accuracy: {train_acc}")

Training accuracy: 0.9832833333333333


In [13]:
# Test accuracy: generalization of the SVC to held-out latent features
test_preds = svc.predict(X=test_hidden_vars)
test_acc = accuracy_score(y_true=test_labels, y_pred=test_preds)
print(f"Test accuracy: {test_acc}")

Test accuracy: 0.9771
