In [17]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D

# Render figures in a separate Tk window (lets the 3-D scatter below be rotated).
# NOTE(review): TkAgg needs a display server; figures will not appear inline in
# Jupyter and this fails on headless machines — use `%matplotlib widget` there.
matplotlib.use('TkAgg')

# Seed torch so weight initialisation and DataLoader shuffling repeat across runs.
SEED = 0
torch.manual_seed(SEED)

# Training data; the autoencoder does not use the labels.
# ToTensor() turns each uint8 image into a float tensor scaled to [0, 1].
training_data = torchvision.datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)

# Training device: CUDA when available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

Using cuda device


In [18]:
# AE模型定义，将(28*28)的图像编码为3个特征
class AE(nn.Module):
    def __init__(self):
        super(AE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.Tanh(),
            nn.Linear(128, 64),
            nn.Tanh(),
            nn.Linear(64, 12),
            nn.Tanh(),
            nn.Linear(12, 3),  # 编码为3个特征
        )
        self.decoder = nn.Sequential(
            nn.Linear(3, 12),
            nn.Tanh(),
            nn.Linear(12, 64),
            nn.Tanh(),
            nn.Linear(64, 128),
            nn.Tanh(),
            nn.Linear(128, 28 * 28),
            nn.Sigmoid(),  # 输出值为(0, 1)
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded

In [19]:
# Instantiate the model on the training device, the reconstruction loss and the optimizer.
learning_rate = 1e-3
ae = AE().to(device)
loss_func = nn.MSELoss()
optimizer = torch.optim.Adam(params=ae.parameters(), lr=learning_rate)

In [20]:
# Train the autoencoder to reconstruct its own input (labels are ignored).
epochs = 20
log_every = 100  # print the loss every `log_every` batches
dataset_size = len(train_dataloader.dataset)
ae.train()

for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}")
    for batch, (images, _) in enumerate(train_dataloader):
        # Flatten each image to a 28*28 vector. The same tensor is both the input
        # and the reconstruction target, so it is copied to the device only once.
        images = images.view(-1, 28 * 28).to(device)

        encoded, decoded = ae(images)
        loss = loss_func(decoded, images)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % log_every == 0:
            current = batch * len(images)
            print(f"loss: {loss.item():>7f}  [{current:>5d}/{dataset_size:>5d}]")

loss: 0.231657  [    0/60000]
loss: 0.066269  [ 6400/60000]
loss: 0.064593  [12800/60000]
loss: 0.068239  [19200/60000]
loss: 0.062330  [25600/60000]
loss: 0.071733  [32000/60000]
loss: 0.069789  [38400/60000]
loss: 0.068686  [44800/60000]
loss: 0.072319  [51200/60000]
loss: 0.066503  [57600/60000]
loss: 0.059758  [    0/60000]
loss: 0.062318  [ 6400/60000]
loss: 0.059963  [12800/60000]
loss: 0.058667  [19200/60000]
loss: 0.056237  [25600/60000]
loss: 0.055942  [32000/60000]
loss: 0.057651  [38400/60000]
loss: 0.056127  [44800/60000]
loss: 0.056130  [51200/60000]
loss: 0.055185  [57600/60000]
loss: 0.056638  [    0/60000]
loss: 0.058691  [ 6400/60000]
loss: 0.056877  [12800/60000]
loss: 0.052161  [19200/60000]
loss: 0.052411  [25600/60000]
loss: 0.047495  [32000/60000]
loss: 0.054669  [38400/60000]
loss: 0.049947  [44800/60000]
loss: 0.048476  [51200/60000]
loss: 0.056276  [57600/60000]
loss: 0.051349  [    0/60000]
loss: 0.049167  [ 6400/60000]
loss: 0.050245  [12800/60000]
loss: 0.04

In [23]:
print(repr(ae))  # layer-by-layer structure of the autoencoder

AE(
  (encoder): Sequential(
    (0): Linear(in_features=784, out_features=128, bias=True)
    (1): Tanh()
    (2): Linear(in_features=128, out_features=64, bias=True)
    (3): Tanh()
    (4): Linear(in_features=64, out_features=12, bias=True)
    (5): Tanh()
    (6): Linear(in_features=12, out_features=3, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=3, out_features=12, bias=True)
    (1): Tanh()
    (2): Linear(in_features=12, out_features=64, bias=True)
    (3): Tanh()
    (4): Linear(in_features=64, out_features=128, bias=True)
    (5): Tanh()
    (6): Linear(in_features=128, out_features=784, bias=True)
    (7): Sigmoid()
  )
)


In [24]:
# Visual comparison: top row = original images, bottom row = AE reconstructions.
show_n = 10
fig, axes = plt.subplots(2, show_n, figsize=(show_n, 2))

# `training_data.data` holds the raw uint8 pixels (0-255), bypassing the dataset's
# ToTensor() transform; rescale to [0, 1] so the model sees inputs like in training.
originals = training_data.data[:show_n].float() / 255.0
ae.eval()
with torch.no_grad():  # inference only: no autograd graph needed
    encoded, decoded = ae(originals.view(-1, 28 * 28).to(device))
reconstructions = decoded.cpu().numpy().reshape(-1, 28, 28)

for row, images in enumerate((originals.numpy(), reconstructions)):
    for i in range(show_n):
        axes[row][i].imshow(images[i], cmap='gray')
        axes[row][i].set_xticks(())
        axes[row][i].set_yticks(())
axes[0][0].set_ylabel("input")
axes[1][0].set_ylabel("output")

plt.show()

In [27]:
# Encode the first n training images into the AE's 3-D latent space and plot
# the selected digit classes as text labels at their encoded positions.
n = 2000
show_labels = [0, 5, 7]

# Raw uint8 pixels (0-255) -> [0, 1], matching the ToTensor() inputs used in training.
view_data = (training_data.data[:n].view(-1, 28 * 28).float() / 255.0).to(device)
ae.eval()
with torch.no_grad():  # inference only: no autograd graph needed
    encoded_data, _ = ae(view_data)
encoded_data = encoded_data.cpu()

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
X, Y, Z = encoded_data[:, 0].numpy(), encoded_data[:, 1].numpy(), encoded_data[:, 2].numpy()
values = training_data.targets[:n].numpy()

# Each selected point is drawn as its digit, coloured by class on the rainbow colormap.
for px, py, pz, label in zip(X, Y, Z, values):
    if label in show_labels:
        color = matplotlib.cm.rainbow(int(255 * label / 9))
        ax.text(px, py, pz, label, backgroundcolor=color)
# ax.text does not autoscale the axes, so span the limits over all n encodings.
ax.set_xlim(X.min(), X.max())
ax.set_ylim(Y.min(), Y.max())
ax.set_zlim(Z.min(), Z.max())
plt.show()