In [1]:
import numpy as np
import torch
import torch.nn.functional as F
from torch import optim
from torch import autograd
from torch.utils.data import DataLoader
from torch import nn

import matplotlib.pyplot as plt
%matplotlib inline
from torchvision.datasets import MNIST
from torchvision import transforms as tfs
from torchvision.utils import save_image

In [2]:
EPOCH = 100
BATCH_SIZE = 128
LEARNING_RATE = 1e-3
# True表示需下载，False表示已下载
DOWNLOAD_MNIST = False

In [3]:
# 训练时才会标准化
im_tfs = tfs.Compose([
    # 先将输入归一化到(0,1)，再使用公式”(x-mean)/std”，将每个元素分布到(-1,1) 
    tfs.ToTensor(),
    tfs.Normalize([0.5], [0.5])
])

train_set = MNIST('./mnist', train=True, transform=im_tfs, download=DOWNLOAD_MNIST)
train_data = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, drop_last=True)
print(len(train_data))

for i, batch in enumerate(train_data):
    #print(i)
    # batch[0]为数据,batch[1]为标签
    print(batch[0].shape)
    break
    

test_set = MNIST('./mnist', train=False)

468
torch.Size([128, 1, 28, 28])


In [4]:
#print(train_set.data.size())
#print(train_set.data[0])
#print(train_set.targets.size())
#print(train_set.targets[0])
print(len(train_data))
# plt.figure(figsize=(20, 10))
# plt.imshow(train_set.data[0].numpy(), cmap='gray')
# plt.title('%i' % train_set.targets[0])
# plt.show()

468


In [5]:
class VAE(nn.Module):
    def __init__(self, latent_num=2):
        super(VAE, self).__init__()
        
        self.fc1 = nn.Linear(28*28, 400)
        self.fc21 = nn.Linear(400, 20) # mean
        self.fc22 = nn.Linear(400, 20) # var
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 28*28)

    # q(z|x)
    def encode_q(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)
    
    # 重参数化，使网络可以反向传播
    def reparametrize(self, mu, logvar):
        # mul逐乘
        # exp逐指数
        # exp_是exp的in-place形式
        std = logvar.mul(0.5).exp_()
        eps = torch.FloatTensor(std.size()).normal_()
        if torch.cuda.is_available():
            eps = eps.cuda()
        # z = mu + sigma * eps
        return eps.mul(std).add_(mu)
    
    # p(x|z)
    def decoder_p(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.tanh(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode_q(x)
        z = self.reparametrize(mu, logvar)
        
        # 解码，同时输出均值和方差
        return self.decoder_p(z), mu, logvar 


In [6]:
net = VAE()
if torch.cuda.is_available():
    net = net.cuda()
    
print(net)

VAE(
  (fc1): Linear(in_features=784, out_features=400, bias=True)
  (fc21): Linear(in_features=400, out_features=20, bias=True)
  (fc22): Linear(in_features=400, out_features=20, bias=True)
  (fc3): Linear(in_features=20, out_features=400, bias=True)
  (fc4): Linear(in_features=400, out_features=784, bias=True)
)


In [7]:
train_set

Dataset MNIST
    Number of datapoints: 60000
    Split: train
    Root Location: ./mnist
    Transforms (if any): Compose(
                             ToTensor()
                             Normalize(mean=[0.5], std=[0.5])
                         )
    Target Transforms (if any): None

In [8]:
x = train_set.data[0].view(-1, 784)
x = x.to(dtype=torch.float32)

if torch.cuda.is_available():
    x = x.cuda()
_, mu, var = net(x)


In [9]:
mu

tensor([[ 14.6022, -45.9968,   8.9362, -34.1990, -22.0062,  11.0758,  -2.4679,
          -0.1938,  17.5876,  26.7471,   0.7580,   0.2342,  -3.3591,  -7.3759,
         -32.6608, -19.0109,  27.5281,   2.6232,  23.3838, -13.5074]],
       grad_fn=<AddmmBackward>)

In [10]:
var

tensor([[-18.7527,   1.0859, -20.1880,   4.6355,  13.0386,  12.9033, -16.1935,
         -18.2203, -38.4091,   0.5825,  -8.3420,  -5.8560,   2.1735,  14.7907,
          -8.0186,   8.9145, -28.9785, -17.5274,  10.7804,  -9.6349]],
       grad_fn=<AddmmBackward>)

In [11]:
reconstruction_function = nn.MSELoss(reduction='sum')

def loss_function(recon_x, x, mu, logvar):
    MSE = reconstruction_function(recon_x, x)
    # loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)
    # KL divergence
    return MSE + KLD

In [12]:
optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)

In [13]:
def to_img(x):
    x = 0.5 * (x + 1.)
    x = x.clamp(0, 1)
    x = x.view(x.shape[0],1,28,28)
    return x

In [14]:
import os
for e in range(EPOCH):
    for i, batch in enumerate(train_data):
        # print(i)
        # batch[0]为数据,batch[1]为标签
        # print(batch[0], batch[1])
        
#         if torch.cuda.is_available():
#             batch = batch.cuda()
        img = batch[0].view(BATCH_SIZE, -1)
        #print(img.size())
        recon_img, mu, logvar = net(img)
        loss = loss_function(recon_img, img, mu, logvar) / BATCH_SIZE
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    if (e + 1) % 5 == 0:
        print('epoch: {}, Loss: {:.4f}'.format(e + 1, loss.data))
        save = to_img(recon_img.cpu().data)
        if not os.path.exists('./vae_img'):
            os.mkdir('./vae_img')
        save_image(save, './vae_img/image_{}.png'.format(e + 1))

epoch: 5, Loss: 74.7497
epoch: 10, Loss: 68.9868
epoch: 15, Loss: 68.1801
epoch: 20, Loss: 66.1932
epoch: 25, Loss: 64.7258
epoch: 30, Loss: 64.2296
epoch: 35, Loss: 63.5790
epoch: 40, Loss: 63.0088
epoch: 45, Loss: 63.9606
epoch: 50, Loss: 61.5279
epoch: 55, Loss: 60.9257
epoch: 60, Loss: 58.9894
epoch: 65, Loss: 61.9385
epoch: 70, Loss: 62.1132
epoch: 75, Loss: 62.2782
epoch: 80, Loss: 63.5342
epoch: 85, Loss: 63.6882
epoch: 90, Loss: 60.2739
epoch: 95, Loss: 61.5337
epoch: 100, Loss: 61.5340


In [15]:
def save_model():
    # entire net
    torch.save(net, 'VAE_net1.pkl')
    # parameters
    torch.save(net.state_dict(), 'VAE_net1_params.pkl')
    
def restore_net():
    net2 = torch.load('VAE_net1.pkl')
    net2.eval()
    
def restore_params():
    # 要与原net的结构一样
    net3 = VAE(*args, **kwargs)
    net3.load_state_dict(torch.load('VAE_net1_params.pkl'))
    net3.eval()
    
save_model()

  "type " + obj.__name__ + ". It won't be checked "
