In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# 检查是否有可用的GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

# 定义自动编码器网络
class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        # 编码器部分
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.ReLU(True),
            nn.Linear(128, 64),
            nn.ReLU(True),
            nn.Linear(64, 12),
            nn.ReLU(True),
            nn.Linear(12, 3)  # 将数据压缩到3维
        )
        # 解码器部分
        self.decoder = nn.Sequential(
            nn.Linear(3, 12),
            nn.ReLU(True),
            nn.Linear(12, 64),
            nn.ReLU(True),
            nn.Linear(64, 128),
            nn.ReLU(True),
            nn.Linear(128, 28 * 28),
            nn.Sigmoid()  # 输出范围在 [0,1]
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

# 超参数设置
batch_size = 128
learning_rate = 1e-3
num_epochs = 20

# 数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),  # 转换为张量
    transforms.Normalize((0.5,), (0.5,))  # 归一化到 [-1,1]
])

# 加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# 实例化模型、定义损失函数和优化器
model = Autoencoder().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# 训练自动编码器
for epoch in range(num_epochs):
    for data in train_loader:
        img, _ = data
        img = img.view(img.size(0), -1).to(device)  # 展平图像
        # 前向传播
        output = model(img)
        loss = criterion(output, img)
        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

# 测试自动编码器并可视化重构结果
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=10, shuffle=True)
dataiter = iter(test_loader)
images, _ = dataiter.next()
images_flat = images.view(images.size(0), -1).to(device)
with torch.no_grad():
    reconstructed = model(images_flat)

# 将重构的图像转换为CPU并重新调整形状
reconstructed = reconstructed.view(-1, 1, 28, 28).cpu()

# 可视化原始图像和重构图像
fig, axes = plt.subplots(nrows=2, ncols=10, figsize=(15, 4))
for i in range(10):
    # 原始图像
    axes[0, i].imshow(images[i].squeeze(), cmap='gray')
    axes[0, i].axis('off')
    # 重构图像
    axes[1, i].imshow(reconstructed[i].squeeze(), cmap='gray')
    axes[1, i].axis('off')
plt.suptitle("上排: 原始图像 | 下排: 重构图像", fontsize=16)
plt.show()

# 可视化编码器的3维潜在空间
model.eval()
encoded_imgs = []
labels = []
with torch.no_grad():
    for data in test_loader:
        img, label = data
        img = img.view(img.size(0), -1).to(device)
        encoded = model.encoder(img)
        encoded_imgs.append(encoded.cpu())
        labels.append(label)
        if len(encoded_imgs) * batch_size >= 1000:
            break

encoded_imgs = torch.cat(encoded_imgs)[:1000]
labels = torch.cat(labels)[:1000]

from mpl_toolkits.mplot3d import Axes3D

fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(111, projection='3d')
scatter = ax.scatter(encoded_imgs[:,0], encoded_imgs[:,1], encoded_imgs[:,2], c=labels, cmap='tab10', alpha=0.7)
legend = ax.legend(*scatter.legend_elements(), title="数字")
ax.add_artist(legend)
ax.set_xlabel("维度1")
ax.set_ylabel("维度2")
ax.set_zlabel("维度3")
plt.title("编码器的3维潜在空间表示")
plt.show()


使用设备: cpu
Epoch [1/20], Loss: 0.9260
Epoch [2/20], Loss: 0.9237
Epoch [3/20], Loss: 0.9263
Epoch [4/20], Loss: 0.9244
Epoch [5/20], Loss: 0.9218
Epoch [6/20], Loss: 0.9261
Epoch [7/20], Loss: 0.9278
Epoch [8/20], Loss: 0.9256
Epoch [9/20], Loss: 0.9216
Epoch [10/20], Loss: 0.9262
Epoch [11/20], Loss: 0.9224
Epoch [12/20], Loss: 0.9263
Epoch [13/20], Loss: 0.9234
Epoch [14/20], Loss: 0.9291
Epoch [15/20], Loss: 0.9258
Epoch [16/20], Loss: 0.9255
Epoch [17/20], Loss: 0.9244
Epoch [18/20], Loss: 0.9240
Epoch [19/20], Loss: 0.9253
Epoch [20/20], Loss: 0.9232


AttributeError: '_SingleProcessDataLoaderIter' object has no attribute 'next'