In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as functional

import matplotlib.pyplot as plt
import numpy as np

import warnings
warnings.filterwarnings("ignore")

In [2]:
# Build the VAE
class VAE(nn.Module):
    """Convolutional variational autoencoder for single-channel 28x28 images.

    The encoder reduces a batch of shape (B, 1, 28, 28) to (B, 64) features
    (28 -> 14 -> 7 -> 1 spatially, then flattened by ``Compress``). Two linear
    heads turn those features into the mean and log-variance of the latent
    distribution. A latent vector is drawn with the reparameterization trick
    and the decoder maps it back to (B, 1, 28, 28) with values in (0, 1).
    """

    # Dimensionality of the latent vector z.
    LATENT_DIM = 64

    def __init__(self):
        super(VAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 7),
            nn.ReLU(),
            Compress())

        # Mean of the latent distribution
        self.fc_mu = nn.Linear(64, self.LATENT_DIM)
        # Log-variance of the latent distribution
        self.fc_logvar = nn.Linear(64, self.LATENT_DIM)

        # Decoder: mirrors the encoder with transposed convolutions
        self.decoder = nn.Sequential(
            Decompress(),
            nn.ConvTranspose2d(self.LATENT_DIM, 32, 7),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid())  # squash pixel values into (0, 1)

    def sample(self, mu, logvar):
        """Draw ``z = mu + std * eps`` with ``eps ~ N(0, I)``.

        This is the reparameterization trick: the randomness lives in ``eps``,
        so gradients can flow back into ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + (std * eps)

    def generateExample(self, num_examples):
        """Decode ``num_examples`` latent vectors drawn from the N(0, I) prior.

        Returns the decoder output for a (num_examples, LATENT_DIM) batch of
        latent vectors. No autograd graph is recorded.
        """
        # Match the model's device/dtype so this also works after .to(...).
        reference = next(self.parameters())
        mu = torch.zeros(
            num_examples,
            self.LATENT_DIM,
            device=reference.device,
            dtype=reference.dtype,
        )
        # The prior has unit variance, i.e. log-variance 0 (not 1).
        logvar = torch.zeros_like(mu)
        with torch.no_grad():
            z = self.sample(mu, logvar)
            return self.decoder(z)

    def forward(self, x):
        """Encode a batch of images, sample a latent code, and decode it.

        Args:
            x: batch of images accepted by the encoder.

        Returns:
            Tuple ``(recon_x, mu, logvar)``: the reconstruction and the
            predicted mean and log-variance of the latent distribution, each
            with one row per input sample.
        """
        features = self.encoder(x)

        # The statistics must come from the encoder; replacing them with
        # constants would cut the encoder out of the computation graph.
        mu = self.fc_mu(features)
        logvar = self.fc_logvar(features)

        z = self.sample(mu, logvar)
        recon_x = self.decoder(z)
        return recon_x, mu, logvar

    
class Compress(nn.Module):
    """Flatten each sample of a batch into a single feature vector."""

    def forward(self, input):
        """Return ``input`` viewed as ``(batch, -1)``; no data is copied."""
        batch = input.shape[0]
        flat = input.view(batch, -1)  # collapse all non-batch dimensions
        return flat
    
    
class Decompress(nn.Module):
    """Reshape a batch of flat vectors into 1x1 feature maps."""

    def forward(self, input, size=28*28):
        """Return ``input`` of shape (B, C) viewed as (B, C, 1, 1).

        ``size`` is not used; it is kept only so existing callers that pass it
        keep working.
        """
        # No debug printing here: this runs on every forward pass.
        return input.view(input.size(0), input.size(1), 1, 1)

def loss_fn(recon_x, x, mu, logvar):
    """Return the VAE objective: summed squared error plus KL divergence.

    Args:
        recon_x: reconstruction produced by the model.
        x: target the reconstruction is compared against.
        mu: mean of the latent distribution.
        logvar: log-variance of the latent distribution.

    Returns:
        Scalar tensor ``sum((recon_x - x)^2) + KL``, where ``KL`` is the
        divergence of N(mu, exp(logvar)) from N(0, I), summed over all
        elements.
    """
    reconstruction_error = functional.mse_loss(recon_x, x, reduction="sum")

    # Closed-form KL term for a diagonal Gaussian against a unit Gaussian.
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_divergence = -0.5 * kl_terms.sum()

    return reconstruction_error + kl_divergence

In [None]:
# Number of images per mini-batch
batch_size = 10

# Fetch MNIST (downloaded into ./data if missing); images are converted to tensors
transform = transforms.ToTensor()
dataset = datasets.MNIST(root="data", download=True, transform=transform)

# Build the data loader: reshuffled on every pass, loaded in the main process
train_dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
)

In [4]:
# Image display helpers
def showImage(img, img_recon, epoch):
    """Show an original image next to its reconstruction.

    Args:
        img: 3-D tensor holding the original image (or image grid).
        img_recon: 3-D tensor holding the reconstruction.
        epoch: value appended to both panel titles.

    Both tensors have their axes reordered from (0, 1, 2) to (1, 2, 0) before
    plotting, i.e. a channels-first image becomes channels-last for ``imshow``.
    """
    _, (ax_original, ax_recon) = plt.subplots(1, 2)

    panels = (
        (ax_original, img, "Original Epoch: #" + str(epoch)),
        (ax_recon, img_recon, "Reconstruction Epoch: #" + str(epoch)),
    )
    for ax, tensor, title in panels:
        # detach()/cpu() make this safe for tensors that still track gradients
        # or live on an accelerator; .numpy() alone raises for those.
        pixels = tensor.detach().cpu().numpy()
        ax.set_title(title)
        ax.imshow(np.transpose(pixels, (1, 2, 0)))

    plt.show(block=True)
    
def showExample(img):
    """Show a single generated image (or image grid).

    Args:
        img: 3-D tensor; its axes are reordered from (0, 1, 2) to (1, 2, 0)
            before plotting, i.e. channels-first becomes channels-last.
    """
    # One full-width axes: there is only one image to show, so a 1x2 layout
    # would leave half of the figure empty.
    _, ax = plt.subplots()

    # detach()/cpu() make this safe for tensors that still track gradients
    # or live on an accelerator; .numpy() alone raises for those.
    pixels = img.detach().cpu().numpy()
    ax.set_title("Generated Example")
    ax.imshow(np.transpose(pixels, (1, 2, 0)))
    plt.show(block=True)

In [None]:
model = VAE()

learning_rate = .001  # optimizer step size
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
epochs = 10
log_every = 2000  # plot and report the mean loss once per this many mini-batches
running_loss = 0.0


for t in range(epochs):
    print(f"Epoch {t+1}\n-----------------------------------")
    for idx, (images, _) in enumerate(train_dataloader):

        recon_x, mu, logvar = model(images)
        loss = loss_fn(recon_x, images, mu, logvar)

        optimizer.zero_grad()   # clear gradients from the previous step
        loss.backward()         # compute gradients
        optimizer.step()        # update the parameters

        running_loss += loss.item()
        if (idx + 1) % log_every == 0:

            # Show the original batch next to its reconstruction. The
            # reconstruction is detached so plotting never touches autograd.
            showImage(
                torchvision.utils.make_grid(images),
                torchvision.utils.make_grid(recon_x.detach().to("cpu")),
                t + 1,
            )

            print('loss: %.3f' % (running_loss / log_every))
            running_loss = 0.0

print("Generating a random example...")
num_examples = 10

# Generation is inference only, so no gradient graph is needed.
with torch.no_grad():
    example = model.generateExample(num_examples)
showExample(torchvision.utils.make_grid(example))