### MNIST数据集

In [None]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm 
import torch.optim as optim
import wandb

from model.vae import VAE, CVAE, CVAE_su
from model.loss import VAE_loss

In [None]:
# Start the Weights & Biases run that the training loop logs to.
wandb.init(project="VAE2")

# Constants handed to transforms.Normalize; kept as module-level names so
# later cells can undo the normalisation with the same values.
mean = 0.1307
std = 0.3081
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((mean,), (std,))]
)

# PyTorch MNIST dataset: the train and test splits share the same root
# directory and preprocessing (train split is constructed first).
train_dataset, test_dataset = (
    datasets.MNIST(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)


### 设置主要参数

In [None]:
# Main hyper-parameters for the experiment.
batch_size  = 512
kernel_size = 3
filters     = 16
epochs      = 30
# Two latent dimensions keep the later plots simple; a larger value
# (e.g. 8) can improve generation quality.
latent_dim  = 2
# Prefer CUDA device index 1, the value that used to be hard-coded here.
# Note that index 1 is the *second* GPU, not the first. Fall back to the
# first GPU, then to the CPU, so the notebook still runs on machines with
# fewer devices instead of failing on the first `.to(device)` call.
preferred_gpu = 1
if torch.cuda.device_count() > preferred_gpu:
    device = torch.device(f"cuda:{preferred_gpu}")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
num_classes = 10
# Size of dimension 1 of the first training sample; the original note
# gives the sample shape as 1 * 28 * 28.
image_size  = train_dataset[0][0].shape[1]

# Only the training data is shuffled; the test order stays fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

### 开始训练

In [None]:
# Build the model and its Adam optimiser, then put the model on the
# target device in training mode.
vae = CVAE_su(filters, kernel_size, latent_dim, image_size)
optimizer = optim.Adam(vae.parameters(), lr=1e-3)

vae.to(device)
vae.train()

for epoch in tqdm(range(epochs)):
    for i, (x, y) in enumerate(train_loader):
        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        # The forward call returns four values; the last one is unpacked
        # as a (reconstruction loss, KL divergence) pair.
        x_recon, mu, logvar, (recon_loss, kl_div) = vae(x, y)
        loss = recon_loss + kl_div
        loss.backward()
        optimizer.step()

        # `i` is the batch index inside the current epoch, so the logged
        # "iter" value restarts at 0 every epoch.
        metrics = {
            "iter": i,
            "reconstruction_loss": recon_loss.item(),
            "kl_divergence": kl_div.item(),
        }
        wandb.log(metrics, commit=True)


### 结果对比

In [None]:
import torchvision
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

# Reconstruct the first test batch and show it next to the originals.
vae.eval()
x_test, y_test = next(iter(test_loader))  # only the first batch is visualised
x_test, y_test = x_test.to(device), y_test.to(device)
# Inference only: skip building the autograd graph for the forward pass.
with torch.no_grad():
    x_recon, mu, logvar, loss = vae(x_test, y_test)

# Undo the Normalize transform so the originals return to their
# pre-normalisation value range.
original_images = x_test * std + mean
# NOTE(review): reconstructions are used exactly as the model returns
# them; this assumes the decoder output is already in display range.
generated_images = x_recon

# Concatenate along dim=2 so each original sits directly above its
# reconstruction (assuming N x C x H x W batches, dim=2 is image height).
comparison_grid = torch.cat((original_images, generated_images), dim=2)
grid = make_grid(comparison_grid, nrow=40, padding=2).cpu()

# Show the result: within each grid row the upper image is the original
# and the lower image is the reconstruction.
torchvision.transforms.ToPILImage()(grid).show()


### 展示每个数字类别与latent向量的关系(当latent=2时)

In [None]:
# Scatter-plot the sampled latent codes of the first test batch, coloured
# by digit label (uses the first two latent dimensions only).
vae.eval()
x_test, y_test = next(iter(test_loader))  # only the first batch is plotted
x_test = x_test.to(device)
# Inference only: no autograd graph is needed for encoding and sampling.
with torch.no_grad():
    mu, logvar = vae.encoder(x_test)
    z = vae.reparameterize(mu, logvar).cpu().numpy()
plt.figure(figsize=(6, 6))
# y_test is left on the CPU so matplotlib can consume it directly.
plt.scatter(z[:, 0], z[:, 1], c=y_test)
plt.colorbar()
plt.show()

### vanilla VAE show

In [None]:
import numpy as np
from scipy.stats import norm


# Visualise how varying the two latent dimensions changes the decoded image.
n = 15  # the figure is an n x n grid of decoded digits
digit_size = 28
figure = np.zeros((digit_size * n, digit_size * n))

# Latent coordinates are placed at evenly spaced quantiles of the
# standard normal distribution.
grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))
with torch.no_grad():
    for i, yi in enumerate(grid_x):
        rows = slice(i * digit_size, (i + 1) * digit_size)
        for j, xi in enumerate(grid_y):
            cols = slice(j * digit_size, (j + 1) * digit_size)
            z_sample = np.array([[xi, yi]])
            latent = torch.from_numpy(z_sample).to(device).float()
            x_decoded = vae.decoder(latent)
            # Paste the first decoded sample into its tile of the canvas.
            digit = x_decoded[0].reshape(digit_size, digit_size).cpu()
            figure[rows, cols] = digit

plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
plt.show()

### CVAE show

In [None]:
import numpy as np
from scipy.stats import norm
import torch.nn.functional as F

# Visualise how varying the two latent dimensions changes the decoded
# image when the decoder is conditioned on a fixed class label.
n = 15  # the figure is an n x n grid of decoded digits
digit_size = 28
figure = np.zeros((digit_size * n, digit_size * n))
target = 9  # generate images of the digit 9
# Latent coordinates are placed at evenly spaced quantiles of the
# standard normal distribution.
grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))

# The class condition is identical for every tile, so build the one-hot
# vector (1 x num_classes) once instead of once per loop iteration.
y = torch.tensor([target]).to(device)
y = F.one_hot(y, num_classes=num_classes).float()

with torch.no_grad():
    # Rows follow grid_y and columns follow grid_x; both grids hold the
    # same values here, so the picture is unchanged by this pairing.
    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z_sample = torch.from_numpy(np.array([[xi, yi]])).to(device).float()  # 1 x 2
            # Decoder input is the latent sample followed by the class one-hot.
            z = torch.cat([z_sample, y], dim=1)
            x_decoded = vae.decoder(z)
            digit = x_decoded[0].reshape(digit_size, digit_size).cpu()
            figure[i * digit_size: (i + 1) * digit_size,
                j * digit_size: (j + 1) * digit_size] = digit

plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
plt.show()

### CVAE_su show

In [None]:

import numpy as np
from scipy.stats import norm
import torch.nn.functional as F

# Visualise how varying the two latent dimensions changes the decoded
# image, with the latent grid centred on the encoding of one class.
n = 15  # the figure is an n x n grid of decoded digits
digit_size = 28
figure = np.zeros((digit_size * n, digit_size * n))
target = 9  # generate images of the digit 9
# Encode every one-hot class vector; row `target` supplies the per-dimension
# offsets that centre the grid below.
class_mu = vae.encoder_class(torch.eye(num_classes).to(device).float())
class_mu = class_mu.cpu().detach().numpy()
# Latent coordinates: standard-normal quantiles shifted by the class
# offset. grid_x belongs to latent dimension 0, grid_y to dimension 1.
grid_x = norm.ppf(np.linspace(0.05, 0.95, n)) + class_mu[target][0]
grid_y = norm.ppf(np.linspace(0.05, 0.95, n)) + class_mu[target][1]
with torch.no_grad():
    # Rows walk grid_y (dimension 1) and columns walk grid_x (dimension 0).
    # The previous pairing took `yi` from grid_x and `xi` from grid_y, which
    # applied each dimension's class offset to the other dimension.
    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z = torch.from_numpy(np.array([[xi, yi]])).to(device).float()  # 1 x 2
            x_decoded = vae.decoder(z)
            digit = x_decoded[0].reshape(digit_size, digit_size).cpu()
            figure[i * digit_size: (i + 1) * digit_size,
                j * digit_size: (j + 1) * digit_size] = digit

plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
plt.show()