# 变分自编码器

In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch import optim
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
import os
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# 1. Hyper-parameters and output location
image_size = 784          # flattened input size fed to the VAE
h_dim, z_dim = 400, 20    # hidden-layer width, latent-space dimensionality
batch_size = 128
num_epochs = 30
learning_rate = 0.01      # step size handed to the Adam optimizer
sample_dir = './data2'    # directory that receives the saved sample images

In [None]:
# 2. Define the VAE model
class VAE(nn.Module):
    """Fully connected variational autoencoder for square, single-channel images.

    The encoder maps a flattened image to the mean and log-variance of a
    diagonal Gaussian over the latent code; the decoder maps a latent code
    back to per-pixel values in [0, 1].

    Args:
        image_size: Number of pixels in a flattened image (a perfect square
            is required by ``forward``, e.g. 784 for 28x28).
        h_dim: Width of the hidden layer in both encoder and decoder.
        z_dim: Dimensionality of the latent space.
    """

    def __init__(self, image_size=784, h_dim=400, z_dim=20):
        super(VAE, self).__init__()
        # Keep the sizes on the instance so forward() does not depend on
        # module-level globals.
        self.image_size = image_size
        self.image_side = int(round(image_size ** 0.5))
        self.fc1 = nn.Linear(image_size, h_dim)
        self.fc21 = nn.Linear(h_dim, z_dim)  # produces mu
        self.fc22 = nn.Linear(h_dim, z_dim)  # produces log-variance
        self.fc3 = nn.Linear(z_dim, h_dim)
        self.fc4 = nn.Linear(h_dim, image_size)

    # Encoder
    def encode(self, x):
        """Return ``(mu, logvar)`` for a batch of flattened images."""
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    # Reparameterization trick
    def reparametrize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        The noise must come from a standard normal distribution
        (``randn_like``); uniform noise (``rand_like``) would bias the
        samples and contradict the Gaussian prior assumed by the KL term.
        """
        std = torch.exp(logvar / 2)
        eps = torch.randn_like(std)
        return mu + eps * std

    # Decoder
    def decode(self, z):
        """Map latent codes to flattened images with values in [0, 1]."""
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Encode, sample and decode a batch.

        Args:
            x: Tensor whose elements can be viewed as ``(-1, image_size)``.

        Returns:
            Tuple ``(y, mu, logvar)`` where ``y`` has shape
            ``(batch, 1, side, side)`` with ``side * side == image_size``.

        Raises:
            ValueError: If ``image_size`` is not a perfect square, because
                the reconstruction cannot be laid out as a square image.
        """
        side = self.image_side
        if side * side != self.image_size:
            raise ValueError(
                "image_size must be a perfect square to reshape the output, "
                "got {}".format(self.image_size)
            )
        x = x.reshape(-1, self.image_size)
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        y = self.decode(z)
        y = y.reshape(-1, 1, side, side)
        return y, mu, logvar

In [None]:
# 3. Pick the compute device, then build the model and its optimizer
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
model = VAE()
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

# 4. Objective function
class myLoss(nn.Module):
    """VAE objective: reconstruction error plus KL divergence.

    Both terms are summed over every element of the batch so that they are
    on the same scale. With the default mean reduction the reconstruction
    term would be averaged over all pixels while the KL term is summed,
    letting the KL term dominate the objective.
    """

    def __init__(self):
        super(myLoss, self).__init__()
        self.BCE = nn.BCELoss(reduction='sum')

    def forward(self, recon_x, x, mu, logvar):
        """Return the scalar loss for one batch.

        Args:
            recon_x: Reconstructed values in [0, 1], same shape as ``x``.
            x: Target values in [0, 1].
            mu: Latent means.
            logvar: Latent log-variances.

        Returns:
            0-dim tensor: summed binary cross-entropy plus summed KL term.
        """
        BCE = self.BCE(recon_x, x)
        # Closed-form KL(N(mu, sigma^2) || N(0, 1)) summed over all entries.
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return BCE + KLD
# Loss module instance called by the training loop, moved to the selected device
criterion = myLoss().to(device)

In [None]:
# Dataset: MNIST training split, each image converted to a tensor.
# download=False means the files are expected to already be under ./data/.
dataset = torchvision.datasets.MNIST(root='./data/', train=True,
                                      transform=transforms.ToTensor(),
                                      download=False)
# Data loader: use the shared batch_size setting instead of a duplicated
# literal so the loader stays in sync with the rest of the notebook.
data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                          batch_size=batch_size,
                                          shuffle=True)

In [13]:
# Train the model
def train(model):
    """Train ``model`` on ``data_loader`` and save sample images each epoch.

    Uses the module-level ``data_loader``, ``criterion``, ``optimizer``,
    ``device``, ``num_epochs``, ``batch_size``, ``z_dim`` and ``sample_dir``.
    After every epoch two images are written to ``sample_dir``: decoded
    random latent codes (``sample-N.png``) and the last training batch next
    to its reconstruction (``reconst-N.png``). The loss curve is plotted
    once training finishes.

    Args:
        model: Module returning ``(reconstruction, mu, logvar)`` when called
            and exposing ``decode(z)``.

    Returns:
        list of float: the loss summed over all batches, one entry per epoch.
    """
    # save_image does not create missing directories.
    os.makedirs(sample_dir, exist_ok=True)

    epoch_losses = []
    for epoch in range(num_epochs):
        train_loss = 0.0
        last_batch = None
        for data, _ in data_loader:
            data = data.to(device)
            recon_x, mu, logvar = model(data)
            loss = criterion(recon_x, data, mu, logvar)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Accumulate a Python float so the autograd graph of each batch
            # is not kept alive for the whole epoch.
            train_loss += loss.item()
            last_batch = data
        print("{} epoch Loss is {:.5f}".format(epoch+1, train_loss))
        epoch_losses.append(train_loss)

        with torch.no_grad():
            # Save images decoded from random latent codes.
            # view() accepts only one inferred (-1) dimension; the channel
            # dimension is 1.
            z = torch.randn(batch_size, z_dim).to(device)
            out = model.decode(z).view(-1, 1, 28, 28)
            save_image(out, os.path.join(sample_dir, 'sample-{}.png'.format(epoch+1)))
            # Save the last training batch side by side with its reconstruction.
            if last_batch is not None:
                out, _, _ = model(last_batch)
                x_concat = torch.cat([last_batch.view(-1, 1, 28, 28),
                                      out.view(-1, 1, 28, 28)], dim=3)
                save_image(x_concat, os.path.join(sample_dir, 'reconst-{}.png'.format(epoch+1)))

    # Plot once, with epochs numbered from 1 to match the printed output.
    plt.plot(np.arange(1, len(epoch_losses) + 1), np.array(epoch_losses), label="loss")
    plt.legend()
    return epoch_losses

In [14]:
# Run the training loop on the VAE instance created earlier in the notebook
train(model=model)

KeyboardInterrupt: 