In [5]:
import math
import torch
import torch.nn as nn
import torch.optim as optim


In [None]:
class Simple_VAE(nn.Module):
    """Some Information about Simple_VAE"""
    def __init__(self, input_dim, latentdim):
        super().__init__()
        self.input_dim = input_dim
        hiddendim = 64
        # * 输入是默认将图片展成1维的向量
        # * 这里面是用简单的DNN来构建encoder和decoder, 实际上可以采用更复杂的网络结构，resnet或者其他blocks
        self.encoder_l1 = nn.Linear(input_dim, hiddendim)
        self.encoder_lodvar = nn.Linear(hiddendim, latentdim)
        self.encoder_mu = nn.Linear(hiddendim, latentdim)
        self.activation = nn.ReLU()
        
        self.decoder_l1 = nn.Linear(latentdim, hiddendim)
        self.decoder_l2 = nn.Linear(hiddendim, input_dim)
        self.decoder_activation = nn.Sigmoid()
        
    def encoder(self, x):
        hidden_output = self.encoder_l1(x)
        hidden_output = self.activation(hidden_output)
        
        return self.encoder_mu(hidden_output), self.encoder_lodvar(hidden_output)
    
    # * 重参数化技巧
    def sample_z(self, mu, logvar):
        # * 利用log(variance)计算标准差
        std_var = torch.exp(0.5 * logvar)
        eps = torch.rand_like(std_var)
        
        return mu + eps * std_var
    
    
    def decoder(self, z):
        
        decoder_hidden = self.decoder_l1(z)
        decoder_hidden = self.activation(decoder_hidden)
        decoder_output = self.decoder_l2(decoder_hidden)
        decoder_output = self.decoder_activation(decoder_output)
        
        return decoder_output
        

    def forward(self, x):
        
        mu, logvar = self.encoder(x.view(-1, self.input_dim))
        
        z = self.sample_z(mu, logvar)
        
        output = self.decoder(z)
        return output, mu, logvar



def loss_function(recon_x, x, mu, logvar, input_dim):
    mse_loss = nn.MSELoss()
    reconstruction_loss = mse_loss(recon_x, x.view(-1, input_dim))
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
    return reconstruction_loss + KLD


# * 随机创建3个batch的数据，batch_size为5
train_X = torch.randint(0, 1, (3, 5, 10, 10)).float()





model = Simple_VAE(100, 128)
optimizer_adam = optim.Adam(model.parameters(), lr=0.001)
model.train()
train_loss = 0
num_epochs = 20
for epoch in range(num_epochs):
    for batch_idx, data in enumerate(train_X):
        x = data
        optimizer_adam.zero_grad()
        recon_x, mu, logvar = model(x)
        loss = loss_function(recon_x, x, mu, logvar, 100)
        loss.backward()
        train_loss += loss.item()
        optimizer_adam.step()
        
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")



Epoch [1/20], Loss: 2.9130
Epoch [2/20], Loss: 2.3976
Epoch [3/20], Loss: 1.9531
Epoch [4/20], Loss: 1.5636
Epoch [5/20], Loss: 1.2291
Epoch [6/20], Loss: 0.9396
Epoch [7/20], Loss: 0.7021
Epoch [8/20], Loss: 0.5050
Epoch [9/20], Loss: 0.3579
Epoch [10/20], Loss: 0.2493
Epoch [11/20], Loss: 0.1741
Epoch [12/20], Loss: 0.1210
Epoch [13/20], Loss: 0.0838
Epoch [14/20], Loss: 0.0576
Epoch [15/20], Loss: 0.0391
Epoch [16/20], Loss: 0.0265
Epoch [17/20], Loss: 0.0177
Epoch [18/20], Loss: 0.0120
Epoch [19/20], Loss: 0.0083
Epoch [20/20], Loss: 0.0060
