In [1]:
from torchvision import transforms
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torchvision
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F

# Device (GPU or CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
from torchvision import datasets
from torch.utils.data import sampler
batch_size = 256
lr = 1e-4
n_epoch = 100
train_loader = DataLoader(
    datasets.ImageFolder('/datasets/CelebA-stargan', transforms.Compose([
        transforms.CenterCrop((128,128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))  # 用均值和方差归一化图片
    ])),
    batch_size=batch_size, shuffle=False,
    num_workers=32,pin_memory=True)

In [3]:
class VAE_autoencoder(nn.Module):
    def __init__(self):
        super(VAE_autoencoder, self).__init__()
        
        #define encoder & decoder
        
        self.encoder = nn.Sequential(
            nn.Linear(16384 * 3, 2048),
            nn.ReLU(),
            nn.Linear(2048, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 20),
            nn.ReLU(),
        )
        
        self.decoder = nn.Sequential(
            nn.Linear(12, 20),
            nn.ReLU(),
            nn.Linear(20, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 2048),
            nn.ReLU(),
            nn.Linear(2048, 16384 * 3),
            nn.Sigmoid()
        )
        
        self.fc_m = nn.Linear(20,12)
        self.fc_sigma = nn.Linear(20,12)
        
    def forward(self, x):
        code = x.view(x.size(0), -1)
        code = self.encoder(code)
        
        # m, sigma
        m = self.fc_m(code)
        sigma = self.fc_sigma(code)
        
        # define e
        e = torch.randn_like(sigma)
        
        #define c
        c = torch.exp(sigma) * e + m
        output = self.decoder(c)
#         output = output.view(x.size(0), 3, 128, 128)
        
        return output, m, sigma

In [4]:
model = VAE_autoencoder().to(device)
criterion = nn.MSELoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr = lr)
model

VAE_autoencoder(
  (encoder): Sequential(
    (0): Linear(in_features=49152, out_features=2048, bias=True)
    (1): ReLU()
    (2): Linear(in_features=2048, out_features=1024, bias=True)
    (3): ReLU()
    (4): Linear(in_features=1024, out_features=512, bias=True)
    (5): ReLU()
    (6): Linear(in_features=512, out_features=256, bias=True)
    (7): ReLU()
    (8): Linear(in_features=256, out_features=128, bias=True)
    (9): ReLU()
    (10): Linear(in_features=128, out_features=20, bias=True)
    (11): ReLU()
  )
  (decoder): Sequential(
    (0): Linear(in_features=12, out_features=20, bias=True)
    (1): ReLU()
    (2): Linear(in_features=20, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=256, bias=True)
    (5): ReLU()
    (6): Linear(in_features=256, out_features=512, bias=True)
    (7): ReLU()
    (8): Linear(in_features=512, out_features=1024, bias=True)
    (9): ReLU()
    (10): Linear(in_features=1024, out_features=2048, bias=True)
 

In [None]:
losses = []
for epoch in range(n_epoch):
     
    #trianing Net
    for batch_idx, (real_image, _) in enumerate(train_loader):
        
        real_image = real_image.view(real_image.shape[0], -1).to(device)
        fake_image, m, sigma = model(real_image)
        
        KLD = 0.5 * torch.sum(
            torch.pow(m, 2) +
            torch.pow(sigma, 2) -
            torch.log(1e-8 + torch.pow(sigma, 2)) - 1
        ) / (real_image.size(0)*128*128)
        KLD = KLD.to(device)
        
        MSE = criterion(fake_image, real_image)
        
        loss = KLD + MSE
        
        # updata parametor
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    print('epoch is: {}, Loss is:{:.4f}'.format(epoch+1, loss.data))

epoch is: 1, Loss is:0.7193
epoch is: 2, Loss is:0.7178
epoch is: 3, Loss is:0.7158
epoch is: 4, Loss is:0.7129
epoch is: 5, Loss is:0.7083
epoch is: 6, Loss is:0.6950
epoch is: 7, Loss is:0.6247
epoch is: 8, Loss is:0.4233
epoch is: 9, Loss is:0.3491
epoch is: 10, Loss is:0.3416
epoch is: 11, Loss is:0.3547
epoch is: 12, Loss is:0.3652
epoch is: 13, Loss is:0.3702
epoch is: 14, Loss is:0.3689
epoch is: 15, Loss is:0.3631
epoch is: 16, Loss is:0.3554
epoch is: 17, Loss is:0.3480
epoch is: 18, Loss is:0.3415
epoch is: 19, Loss is:0.3358
epoch is: 20, Loss is:0.3318
epoch is: 21, Loss is:0.3285
epoch is: 22, Loss is:0.3252
epoch is: 23, Loss is:0.3240
epoch is: 24, Loss is:0.3206
epoch is: 25, Loss is:0.3204
epoch is: 26, Loss is:0.3161
epoch is: 27, Loss is:0.3170
epoch is: 28, Loss is:0.3134
epoch is: 29, Loss is:0.3139
epoch is: 30, Loss is:0.3119
epoch is: 31, Loss is:0.3111
epoch is: 32, Loss is:0.3100
epoch is: 33, Loss is:0.3085
epoch is: 34, Loss is:0.3076
epoch is: 35, Loss is:0

In [None]:
ax = plt.figure().gca()
ax.plot(losses)
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['train'])
plt.title('Loss over training epochs')
plt.show()