In [None]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
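# Variable is PyTorch's legacy autograd wrapper; since PyTorch 0.4 plain tensors track gradients themselves.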
from torch.autograd import Variable
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
import matplotlib.gridspec as gridspec

In [2]:
# Define the hyperparameters and the training data pipeline.
mb_size = 64      # mini-batch size passed to the DataLoader
Z_dim = 100       # size of the latent vector z
X_dim = 784       # size of one flattened input image (28 * 28)
y_dim = 10        # not referenced by any of the cells shown here
h_dim = 128       # width of the hidden layer in both networks
c = 0             # running index used to name the sample grids saved during training
lr = 1e-3         # learning rate handed to the optimizer

# Seed PyTorch's global RNG so weight initialization and shuffling are repeatable.
torch.manual_seed(42)
# ToTensor yields float tensors scaled to [0, 1], which is what the
# binary cross-entropy reconstruction loss expects as targets.
transform = transforms.Compose([transforms.ToTensor()])
# drop_last=True discards the final short batch so every batch holds exactly mb_size images.
train_loader = torch.utils.data.DataLoader(
    datasets.FashionMNIST('../data', train=True, download=True, transform=transform),
    batch_size=mb_size, shuffle=True, drop_last=True)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ../data\FashionMNIST\raw\train-images-idx3-ubyte.gz


  0%|          | 0/26421880 [00:00<?, ?it/s]

Extracting ../data\FashionMNIST\raw\train-images-idx3-ubyte.gz to ../data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ../data\FashionMNIST\raw\train-labels-idx1-ubyte.gz


  0%|          | 0/29515 [00:00<?, ?it/s]

Extracting ../data\FashionMNIST\raw\train-labels-idx1-ubyte.gz to ../data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ../data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz


  0%|          | 0/4422102 [00:00<?, ?it/s]

Extracting ../data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz to ../data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ../data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz


  0%|          | 0/5148 [00:00<?, ?it/s]

Extracting ../data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz to ../data\FashionMNIST\raw


In [3]:
def xavier_init(size):
    """Create a randomly initialized, trainable weight tensor.

    Entries are drawn from a zero-mean normal distribution whose standard
    deviation is ``1 / sqrt(in_dim / 2)``, i.e. ``sqrt(2 / in_dim)``, where
    ``in_dim`` is the first entry of ``size``. Despite the function's name,
    the scale depends on the fan-in only.

    Args:
        size: sequence of ints giving the tensor shape; ``size[0]`` is the
            fan-in used to scale the noise.

    Returns:
        A leaf tensor of shape ``size`` with ``requires_grad=True``.
    """
    in_dim = size[0]
    xavier_stddev = 1. / np.sqrt(in_dim / 2.)
    weights = torch.randn(*size) * xavier_stddev
    # The product above is not part of any autograd graph, so flagging it in
    # place yields a leaf tensor an optimizer can update. This replaces the
    # deprecated Variable(..., requires_grad=True) wrapper.
    return weights.requires_grad_()

In [4]:
# Parameters consumed by the encoder Q(x): one X_dim -> h_dim hidden layer
# followed by two separate h_dim -> Z_dim heads, one producing the latent mean
# and one producing the latent log-variance.
# Biases start at zero and are created directly as trainable leaf tensors
# (no deprecated Variable wrapper needed).
Wxh = xavier_init(size=[X_dim, h_dim])
bxh = torch.zeros(h_dim, requires_grad=True)

Whz_mu = xavier_init(size=[h_dim, Z_dim])
bhz_mu = torch.zeros(Z_dim, requires_grad=True)

Whz_var = xavier_init(size=[h_dim, Z_dim])
bhz_var = torch.zeros(Z_dim, requires_grad=True)

In [5]:
def Q(x):
    """Encode a batch of inputs into the parameters of a diagonal Gaussian over z.

    Args:
        x: tensor whose last dimension matches the number of rows of ``Wxh``
            (one flattened input per row).

    Returns:
        A pair ``(z_mu, z_var)``: the output of the mean head and the output
        of the second head. ``sample_z`` and the KL term in the training loop
        treat the second value as a log-variance.
    """
    # The 1-D biases broadcast across the batch dimension, so there is no need
    # to materialize a (batch, dim) copy with .repeat().
    h = F.relu(x @ Wxh + bxh)
    z_mu = h @ Whz_mu + bhz_mu
    z_var = h @ Whz_var + bhz_var
    return z_mu, z_var

def sample_z(mu, log_var):
    """Draw ``z = mu + exp(log_var / 2) * eps`` with ``eps ~ N(0, I)``.

    Writing the sample as a deterministic function of ``mu`` and ``log_var``
    plus independent noise keeps it differentiable with respect to both.

    Args:
        mu: tensor of means.
        log_var: tensor of log-variances, broadcastable against ``mu``.

    Returns:
        A tensor of samples with the same shape as ``mu``.
    """
    # Take shape, dtype and device from mu instead of the global Z_dim so the
    # noise always matches the encoder output it is combined with.
    eps = torch.randn_like(mu)
    return mu + torch.exp(log_var / 2) * eps

# Parameters consumed by the decoder P(z): a Z_dim -> h_dim hidden layer and an
# h_dim -> X_dim output layer. Biases start at zero and are created directly as
# trainable leaf tensors (no deprecated Variable wrapper needed).
Wzh = xavier_init(size=[Z_dim, h_dim])
bzh = torch.zeros(h_dim, requires_grad=True)

Whx = xavier_init(size=[h_dim, X_dim])
bhx = torch.zeros(X_dim, requires_grad=True)

def P(z):
    """Decode latent vectors into per-element values in (0, 1).

    Args:
        z: tensor whose last dimension matches the number of rows of ``Wzh``
            (one latent vector per row).

    Returns:
        Tensor with one row per input row and one column per column of
        ``Whx``; every entry is a sigmoid output.
    """
    # 1-D biases broadcast across the batch dimension; no .repeat() needed.
    h = F.relu(z @ Wzh + bzh)
    # torch.sigmoid replaces the deprecated F.sigmoid.
    X = torch.sigmoid(h @ Whx + bhx)
    return X

In [7]:
# Every trainable tensor of the encoder Q and the decoder P; a single Adam
# instance optimizes them jointly.
params = [Wxh, bxh, Whz_mu, bhz_mu, Whz_var, bhz_var,
          Wzh, bzh, Whx, bhx]

solver = optim.Adam(params, lr=lr)

max_iters = 20000   # total number of optimizer steps to take
log_every = 1000    # print the loss and save a reconstruction grid this often
out_dir = 'out/'    # directory that receives the saved PNG grids

# Create the output directory once, up front, instead of checking on every log step.
os.makedirs(out_dir, exist_ok=True)

it = 0
while it < max_iters:
    for data, _ in tqdm(train_loader, desc="Training"):
        it += 1
        # Flatten each 28x28 image into one row; plain tensors work with
        # autograd directly, so no Variable wrapper is needed.
        X = data.view(-1, 28 * 28)

        # Forward: encode, sample a latent with the reparameterization trick, decode.
        z_mu, z_var = Q(X)
        z = sample_z(z_mu, z_var)
        X_sample = P(z)

        # Loss: summed binary cross-entropy averaged over the actual batch size,
        # plus the batch-mean KL divergence between N(z_mu, exp(z_var)) and N(0, I).
        recon_loss = F.binary_cross_entropy(X_sample, X, reduction='sum') / X.size(0)
        kl_loss = torch.mean(0.5 * torch.sum(torch.exp(z_var) + z_mu**2 - 1. - z_var, 1))
        loss = recon_loss + kl_loss

        # Clear gradients *before* backward so a step never accumulates stale
        # gradients (e.g. after the cell was interrupted and re-run).
        solver.zero_grad()
        loss.backward()
        solver.step()

        # Print and plot every now and then
        if it % log_every == 0:
            print('Iter-{}; Loss: {:.4}'.format(it, loss.item()))

            # Decode the current batch's latents again (with the freshly updated
            # weights) purely for visualisation; no autograd graph is needed.
            with torch.no_grad():
                samples = P(z[:16]).numpy()

            fig = plt.figure(figsize=(4, 4))
            gs = gridspec.GridSpec(4, 4)
            gs.update(wspace=0.05, hspace=0.05)

            for i, sample in enumerate(samples):
                ax = fig.add_subplot(gs[i])
                ax.axis('off')
                ax.set_xticklabels([])
                ax.set_yticklabels([])
                ax.set_aspect('equal')
                ax.imshow(sample.reshape(28, 28), cmap='Greys_r')

            # c is the running image index defined alongside the hyperparameters.
            fig.savefig(os.path.join(out_dir, '{}.png'.format(str(c).zfill(3))),
                        bbox_inches='tight')
            c += 1
            # Close the figure so repeated logging does not accumulate open figures.
            plt.close(fig)

        # Stop exactly at max_iters instead of finishing the current epoch.
        if it >= max_iters:
            break

Training: 100%|██████████| 937/937 [00:12<00:00, 74.32it/s]
Training:   6%|▌         | 56/937 [00:00<00:12, 70.09it/s]

Iter-1000; Loss: 263.8


Training: 100%|██████████| 937/937 [00:13<00:00, 70.65it/s]
Training:  13%|█▎        | 124/937 [00:01<00:10, 76.47it/s]

Iter-2000; Loss: 246.3


Training: 100%|██████████| 937/937 [00:13<00:00, 68.83it/s]
Training:  20%|█▉        | 187/937 [00:02<00:10, 72.36it/s]

Iter-3000; Loss: 244.8


Training: 100%|██████████| 937/937 [00:13<00:00, 69.88it/s]
Training:  27%|██▋       | 249/937 [00:03<00:09, 72.86it/s]

Iter-4000; Loss: 255.0


Training: 100%|██████████| 937/937 [00:13<00:00, 70.93it/s]
Training:  33%|███▎      | 308/937 [00:04<00:08, 73.30it/s]

Iter-5000; Loss: 245.0


Training: 100%|██████████| 937/937 [00:13<00:00, 69.53it/s]
Training:  40%|███▉      | 371/937 [00:04<00:07, 73.89it/s]

Iter-6000; Loss: 242.0


Training: 100%|██████████| 937/937 [00:12<00:00, 72.49it/s]
Training:  47%|████▋     | 438/937 [00:06<00:06, 74.32it/s]

Iter-7000; Loss: 248.6


Training: 100%|██████████| 937/937 [00:13<00:00, 70.42it/s]
Training:  53%|█████▎    | 495/937 [00:06<00:06, 71.21it/s]

Iter-8000; Loss: 239.4


Training: 100%|██████████| 937/937 [00:13<00:00, 71.89it/s]
Training:  60%|██████    | 564/937 [00:07<00:04, 80.97it/s]

Iter-9000; Loss: 247.1


Training: 100%|██████████| 937/937 [00:13<00:00, 69.57it/s]
Training:  67%|██████▋   | 629/937 [00:08<00:03, 80.04it/s]

Iter-10000; Loss: 251.6


Training: 100%|██████████| 937/937 [00:13<00:00, 71.19it/s]
Training:  73%|███████▎  | 685/937 [00:09<00:03, 71.70it/s]

Iter-11000; Loss: 234.5


Training: 100%|██████████| 937/937 [00:13<00:00, 70.18it/s]
Training:  80%|████████  | 753/937 [00:10<00:02, 75.24it/s]

Iter-12000; Loss: 240.5


Training: 100%|██████████| 937/937 [00:13<00:00, 69.54it/s]
Training:  87%|████████▋ | 814/937 [00:10<00:01, 73.12it/s]

Iter-13000; Loss: 224.4


Training: 100%|██████████| 937/937 [00:13<00:00, 71.80it/s]
Training:  94%|█████████▍| 879/937 [00:11<00:00, 74.93it/s]

Iter-14000; Loss: 241.9


Training: 100%|██████████| 937/937 [00:13<00:00, 71.87it/s]
Training: 100%|██████████| 937/937 [00:14<00:00, 63.00it/s]
Training:   1%|          | 7/937 [00:00<00:13, 67.10it/s]

Iter-15000; Loss: 246.8


Training: 100%|██████████| 937/937 [00:14<00:00, 64.44it/s]
Training:   7%|▋         | 67/937 [00:01<00:12, 68.40it/s]

Iter-16000; Loss: 241.3


Training: 100%|██████████| 937/937 [00:14<00:00, 64.46it/s]
Training:  14%|█▍        | 132/937 [00:01<00:11, 68.70it/s]

Iter-17000; Loss: 237.3


Training: 100%|██████████| 937/937 [00:14<00:00, 65.28it/s]
Training:  21%|██        | 193/937 [00:02<00:10, 69.67it/s]

Iter-18000; Loss: 230.7


Training: 100%|██████████| 937/937 [00:14<00:00, 63.48it/s]
Training:  27%|██▋       | 254/937 [00:03<00:09, 69.50it/s]

Iter-19000; Loss: 251.5


Training: 100%|██████████| 937/937 [00:14<00:00, 64.68it/s]
Training:  34%|███▍      | 322/937 [00:04<00:09, 67.00it/s]

Iter-20000; Loss: 238.8


Training: 100%|██████████| 937/937 [00:15<00:00, 61.21it/s]
