### MNIST数据集

In [None]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm 
import torch.optim as optim
import wandb
import torch.nn.functional as F

# Initialise the wandb project that receives the training metrics.
wandb.init(project="VAE2")
# PyTorch MNIST dataset.
# The decoder ends in a sigmoid and is trained with binary cross-entropy, so
# the reconstruction targets must lie in [0, 1]. ToTensor() already produces
# that range; standardising with the usual MNIST statistics (0.1307 / 0.3081)
# would move pixels to roughly [-0.42, 2.82], which is not a valid BCE target
# and cannot be reached by a sigmoid output.
# `mean` and `std` are kept as the identity mapping because the visualisation
# code further down rescales images with `* std + mean`.
mean = 0.0
std = 1.0
transform = transforms.Compose([
    transforms.ToTensor(),
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)



### 设置主要参数

In [None]:
batch_size  = 512
kernel_size = 3
filters     = 16
epochs      = 30
latent_dim  = 2   # 2 latent dimensions only to make the later plots easy; a larger value (e.g. 8) improves sample quality
# Use the first GPU when CUDA is available; otherwise fall back to the CPU so
# the notebook still runs on machines without a GPU.
device      = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

image_size  = train_dataset[0][0].shape[1]  # side length of one sample (samples are 1 x 28 x 28)
features    = 2 * filters * (image_size // 4) ** 2  # flattened feature length after the two stride-2 convolutions

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

### Encoder, Decoder, 和 loss 定义

In [None]:
## Encoder, Decoder也可以不用卷积，只用全连接层

class Encoder(torch.nn.Module):
    """Convolutional encoder producing the parameters of the latent Gaussian.

    Two stride-2 convolutions (with ReLU) are followed by a linear layer and
    two linear heads of size ``latent_dim``. Layer sizes come from the
    module-level ``filters``, ``kernel_size``, ``features`` and ``latent_dim``.
    """

    def __init__(self):
        super().__init__()
        # Both convolutions share the same kernel/stride/padding settings.
        conv_kwargs = dict(kernel_size=kernel_size, stride=2, padding=1)
        self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=filters, **conv_kwargs)
        self.conv2 = torch.nn.Conv2d(in_channels=filters, out_channels=2 * filters, **conv_kwargs)
        self.fc = torch.nn.Linear(in_features=features, out_features=filters)

        # Separate heads for the mean and the log-variance.
        self.mean = torch.nn.Linear(in_features=filters, out_features=latent_dim)
        self.varlog = torch.nn.Linear(in_features=filters, out_features=latent_dim)

    def forward(self, x):
        """Return ``(mean, varlog)`` for the input batch ``x``."""
        relu = torch.nn.functional.relu
        hidden = relu(self.conv2(relu(self.conv1(x))))
        # Flatten everything except the batch dimension before the linear layer.
        hidden = self.fc(hidden.view(hidden.size(0), -1))
        return self.mean(hidden), self.varlog(hidden)
    
class Decoder(torch.nn.Module):
    """Transposed-convolution decoder mapping a latent vector to an image.

    A linear layer expands the latent vector to ``features`` values, which are
    reshaped to ``(batch, 2 * filters, image_size // 4, image_size // 4)`` and
    passed through two stride-2 transposed convolutions. The final activation
    is a sigmoid, so outputs lie in [0, 1].
    """

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(in_features=latent_dim, out_features=features)
        # Both transposed convolutions share the same geometry settings.
        deconv_kwargs = dict(kernel_size=kernel_size, stride=2, padding=1, output_padding=1)
        self.conv2 = torch.nn.ConvTranspose2d(in_channels=2 * filters, out_channels=filters, **deconv_kwargs)
        self.conv1 = torch.nn.ConvTranspose2d(in_channels=filters, out_channels=1, **deconv_kwargs)

    def forward(self, x):
        """Decode the latent batch ``x`` into single-channel images."""
        side = image_size // 4
        hidden = self.fc(x)
        hidden = hidden.view(hidden.size(0), 2 * filters, side, side)
        hidden = torch.nn.functional.relu(self.conv2(hidden))
        return torch.nn.functional.sigmoid(self.conv1(hidden))
    
class VAE(torch.nn.Module):
    """Variational autoencoder composed of an ``Encoder`` and a ``Decoder``."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + sigma * noise`` with ``sigma = exp(0.5 * logvar)``.

        The noise is standard normal and has the same shape as ``logvar``.
        """
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + sigma * noise

    def forward(self, x):
        """Return ``(x_recon, mu, logvar)`` for the input batch ``x``."""
        mu, logvar = self.encoder(x)
        latent = self.reparameterize(mu, logvar)
        return self.decoder(latent), mu, logvar
    
class VAEloss(torch.nn.Module):
    """Per-sample VAE loss terms: reconstruction (BCE) and KL divergence."""

    def __init__(self):
        super().__init__()

    def forward(self, x, x_recon, mu, logvar):
        """Return ``(recon_loss, kl_div)``, each averaged over the batch.

        Args:
            x: target batch; the first dimension is the batch dimension.
            x_recon: reconstruction with the same shape as ``x`` and values
                in [0, 1], as required by binary cross-entropy.
            mu: latent means.
            logvar: latent log-variances.

        Returns:
            A tuple of two scalar tensors: the summed binary cross-entropy and
            the summed KL divergence to a standard normal, both divided by the
            number of samples in ``x``.
        """
        # Normalise by the actual batch size: the last batch of an epoch can
        # be smaller than the configured one.
        n_samples = x.size(0)
        # BCE reconstruction loss, summed over all elements. Only `reduction`
        # is passed; the deprecated `size_average` argument would override it.
        # (A summed MSE over the flattened images is a possible alternative
        # reconstruction term.)
        recon_loss = torch.nn.functional.binary_cross_entropy(x_recon, x, reduction='sum')
        # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over batch and latent dims.
        kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return recon_loss / n_samples, kl_div / n_samples

### 开始训练

In [None]:
# Build the model, the loss module and the optimiser.
vae = VAE()
vaeloss = VAEloss()
optimizer = optim.Adam(params=vae.parameters(), lr=1e-3)

# Move the model to the target device and switch it to training mode.
vae.to(device)
vae.train()

for epoch in tqdm(range(epochs)):

    # Labels are not needed to train the autoencoder.
    for step, (images, _) in enumerate(train_loader):
        images = images.to(device)

        optimizer.zero_grad()
        reconstruction, mu, logvar = vae(images)
        recon_loss, kl_div = vaeloss(images, reconstruction, mu, logvar)

        # The optimised objective is the sum of both loss terms.
        total_loss = recon_loss + kl_div
        total_loss.backward()
        optimizer.step()

        # Log both terms separately after every optimisation step.
        metrics = {
            "iter": step,
            "reconstruction_loss": recon_loss.item(),
            "kl_divergence": kl_div.item(),
        }
        wandb.log(metrics, commit=True)


### 结果对比

In [None]:
import torchvision
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

vae.eval()
# Only the first test batch is visualised. No gradients are needed for
# inference, so the forward pass runs under no_grad.
with torch.no_grad():
    x_test, y_test = next(iter(test_loader))
    x_test = x_test.to(device)
    x_recon, mu, logvar = vae(x_test)

# The inputs went through the dataset transform, so undo it with the global
# `mean` / `std` to get back to pixel intensities. The reconstructions come out
# of the decoder's sigmoid and are already in [0, 1]; rescaling them as well
# would display the two halves on different scales.
original_images = (x_test * std + mean).clamp(0, 1)
generated_images = x_recon

# Stack each original on top of its reconstruction (dim=2 is the image height).
comparison_grid = torch.cat((original_images, generated_images), dim=2)
grid = make_grid(comparison_grid, nrow=40, padding=2).cpu()

# Show the result: the first row is the originals, the second row the
# reconstructions, and so on.
torchvision.transforms.ToPILImage()(grid).show()


### 展示每个数字类别与latent向量的关系(当latent=2时)

In [None]:
vae.eval()
# Only the first test batch is plotted.
x_test, y_test = next(iter(test_loader))
x_test = x_test.to(device)
mu, logvar = vae.encoder(x_test)
# Sample latent codes and convert them to a NumPy array for matplotlib.
z = vae.reparameterize(mu, logvar).detach().cpu().numpy()

# Scatter the two latent coordinates, coloured by the digit label.
plt.figure(figsize=(6, 6))
plt.scatter(z[:, 0], z[:, 1], c=y_test)
plt.colorbar()
plt.show()