In [25]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

class Encoder(nn.Module):
    """Fully connected encoder compressing a flattened 784-feature input to a 50-dim code.

    Layer widths: 784 -> 400 -> 200 -> 100 -> 50. ReLU follows every layer
    except the last, so the returned code is unbounded.
    """

    def __init__(self):
        super().__init__()
        widths = (784, 400, 200, 100, 50)
        # Registered as fc1..fc4 in this exact order so state_dict keys and the
        # parameter-initialisation RNG sequence match an explicit definition.
        for index, (n_in, n_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f'fc{index}', nn.Linear(n_in, n_out))

    def forward(self, x):
        """Map x with last dim 784 to a code with last dim 50 (leading dims are preserved)."""
        for hidden in (self.fc1, self.fc2, self.fc3):
            x = torch.relu(hidden(x))
        return self.fc4(x)
    
class Decoder(nn.Module):
    """Fully connected decoder expanding a 50-dim code back to 784 features.

    Layer widths: 50 -> 100 -> 200 -> 400 -> 784. ReLU follows every layer
    except the last; the output layer is linear (no sigmoid), so
    reconstructions are not clamped to any range.
    """

    def __init__(self):
        super().__init__()
        widths = (50, 100, 200, 400, 784)
        # Registered as fc1..fc4 in this exact order so state_dict keys and the
        # parameter-initialisation RNG sequence match an explicit definition.
        for index, (n_in, n_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f'fc{index}', nn.Linear(n_in, n_out))

    def forward(self, x):
        """Map a code with last dim 50 to a reconstruction with last dim 784."""
        for hidden in (self.fc1, self.fc2, self.fc3):
            x = torch.relu(hidden(x))
        return self.fc4(x)

In [26]:
batch_size = 64
learning_rate = 0.001
num_epochs = 10
seed = 0  # fixes weight init and shuffle order so re-runs are repeatable
root = './datasets'

torch.manual_seed(seed)

# Load the MNIST training split; ToTensor() converts each image to a tensor.
train_dataset = datasets.MNIST(root, train=True, transform=transforms.ToTensor(), download=True)

# Create the data loader. drop_last=False keeps a final batch that may be
# smaller than batch_size, which is why the epoch loss below is averaged per
# sample rather than per batch.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=False)

encoder = Encoder()
decoder = Decoder()
criterion = nn.MSELoss()
optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=learning_rate)

# Train the autoencoder to reconstruct its own input.
# Explicit train mode so re-running this cell after an eval() call still trains.
encoder.train()
decoder.train()
for epoch in range(num_epochs):
    total_loss = 0.0
    # Labels are unused (reconstruction needs no targets). Not named `_`,
    # which would shadow IPython's last-output variable.
    for inputs, _labels in train_loader:
        # Flatten everything after the batch dim: one 784-feature vector per
        # image, matching Encoder.fc1's input size.
        inputs = torch.flatten(inputs, start_dim=1)
        optimizer.zero_grad()
        en_outputs = encoder(inputs)
        de_outputs = decoder(en_outputs)
        loss = criterion(de_outputs, inputs)
        loss.backward()
        optimizer.step()
        # MSELoss returns a batch mean; weight it by the batch size so a
        # smaller final batch does not skew the epoch average.
        total_loss += loss.item() * inputs.size(0)

    avg_loss = total_loss / len(train_loader.dataset)
    print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}')




Epoch [1/10], Average Loss: 0.0379
Epoch [2/10], Average Loss: 0.0223
Epoch [3/10], Average Loss: 0.0184
Epoch [4/10], Average Loss: 0.0164
Epoch [5/10], Average Loss: 0.0150
Epoch [6/10], Average Loss: 0.0138
Epoch [7/10], Average Loss: 0.0130
Epoch [8/10], Average Loss: 0.0123
Epoch [9/10], Average Loss: 0.0116
Epoch [10/10], Average Loss: 0.0111


In [27]:
import os
import torch

# Persist the trained weights of the `encoder` and `decoder` instances created
# in the training cell. Only state_dicts are written, so reloading requires
# constructing Encoder() / Decoder() first and calling load_state_dict().
model_dir = './model'
os.makedirs(model_dir, exist_ok=True)  # no error if the directory already exists

for file_name, net in (('encoder.pth', encoder), ('decoder.pth', decoder)):
    torch.save(net.state_dict(), os.path.join(model_dir, file_name))
