In [1]:
import torch
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision.transforms import v2
import torch.optim as optim

In [2]:
transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True)
])

data = MNIST('./mnist_data', transform=transform, download=True)

In [3]:
loader = DataLoader(dataset=data, batch_size=100, shuffle=True)

In [4]:
class VAE(nn.Module):

    def __init__(self, hidden_dim=400, latent_dim=200):
        super(VAE, self).__init__()

        # encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=32,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(
                in_channels=32,
                out_channels=64,
                kernel_size=5,
                stride=1,
                padding=1,
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            nn.Linear(2304, latent_dim)
            )
        
        # latent mean and variance 
        self.mean_layer = nn.Linear(latent_dim, 2)
        self.logvar_layer = nn.Linear(latent_dim, 2)
        
        # decoder
        self.decoder = nn.Sequential(
            nn.Linear(2, latent_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(latent_dim, 2304),
            nn.LeakyReLU(0.2),
            nn.Unflatten(1,(64,6,6)),
            nn.ConvTranspose2d(64, 32, 3, 1, 1),
            nn.LeakyReLU(0.2),
            nn.Upsample(scale_factor=2),
            nn.Flatten(),
            nn.Linear(4608, 6272),
            nn.LeakyReLU(0.2),
            nn.Unflatten(1,(32, 14, 14)),
            nn.Upsample(scale_factor=2),
            nn.ConvTranspose2d(32, 1, 3, 1, 1),
            nn.Sigmoid(),
            )
     
    def encode(self, x):
        x = self.encoder(x)
        mean, logvar = self.mean_layer(x), self.logvar_layer(x)
        return mean, logvar

    def reparameterization(self, mean, var):
        epsilon = torch.randn_like(var).to('cuda')      
        z = mean + var*epsilon
        return z

    def decode(self, x):
        return self.decoder(x)

    def forward(self, x):
        mean, logvar = self.encode(x)
        z = self.reparameterization(mean, logvar)
        x_hat = self.decode(z)
        return x_hat, mean, logvar

In [5]:
'''
v=VAE().to('cuda')
x=data[2][0].reshape((1,1,28,28)).to('cuda')
#print(x.shape)
x1=v.encoder(x)
x1m=v.mean_layer(x1)
x1v=v.mean_layer(x1)
r=v.reparameterization(x1m,x1v)
v.decoder(r).shape
'''

"\nv=VAE().to('cuda')\nx=data[2][0].reshape((1,1,28,28)).to('cuda')\n#print(x.shape)\nx1=v.encoder(x)\nx1m=v.mean_layer(x1)\nx1v=v.mean_layer(x1)\nr=v.reparameterization(x1m,x1v)\nv.decoder(r).shape\n"

In [6]:
model = VAE().to('cuda')
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [7]:
def loss_function(x, x_hat, mean, log_var):
    reproduction_loss = nn.functional.binary_cross_entropy(x_hat, x, reduction='sum')
    KLD = - 0.5 * torch.sum(1+ log_var - mean.pow(2) - log_var.exp())

    return reproduction_loss + KLD

In [None]:
def train(model, optimizer, epochs, x_dim=784):
    model.train()
    for epoch in range(epochs):
        overall_loss = 0
        for batch_idx, (x, _) in enumerate(loader):
            x = x.to('cuda')

            optimizer.zero_grad()

            x_hat, mean, log_var = model(x)
            loss = loss_function(x, x_hat, mean, log_var)
            
            overall_loss += loss.item()
            
            loss.backward()
            optimizer.step()

        print("\tEpoch", epoch + 1, "\tAverage Loss: ", overall_loss/(batch_idx*100))
    return overall_loss

train(model, optimizer, epochs=50)

	Epoch 1 	Average Loss:  206.4504948351419
	Epoch 2 	Average Loss:  206.45405923935726
	Epoch 3 	Average Loss:  206.44766593410893
	Epoch 4 	Average Loss:  206.4532203738001
	Epoch 5 	Average Loss:  206.4554186339211
	Epoch 6 	Average Loss:  206.4667662901711
	Epoch 7 	Average Loss:  206.44853473236645
	Epoch 8 	Average Loss:  206.45148108175084
	Epoch 9 	Average Loss:  206.44602850453882
	Epoch 10 	Average Loss:  206.46086801570326
	Epoch 11 	Average Loss:  206.44810067560516
	Epoch 12 	Average Loss:  206.44399686326167
	Epoch 13 	Average Loss:  206.44982793588272
	Epoch 14 	Average Loss:  206.44605690473705
	Epoch 15 	Average Loss:  206.44927724593072
	Epoch 16 	Average Loss:  206.44696391772746
	Epoch 17 	Average Loss:  206.45752797631468
	Epoch 18 	Average Loss:  206.44868517842238
	Epoch 19 	Average Loss:  206.44247326272955
	Epoch 20 	Average Loss:  206.4359237074812
	Epoch 21 	Average Loss:  206.4458152911102
	Epoch 22 	Average Loss:  206.45407684682806
	Epoch 23 	Average Loss: 

In [None]:
z_sample = torch.tensor([[0, 0.0]],dtype=torch.float32).to('cuda')
x_decoded = model.decode(z_sample)
digit = x_decoded.detach().cpu().reshape(28,28)
plt.imshow(digit, cmap='gray')
plt.axis('off')
plt.show()

In [None]:
def plot_latent_space(model, scale=1.0, n=25, digit_size=28, figsize=15):
    # display a n*n 2D manifold of digits
    figure = np.zeros((digit_size * n, digit_size * n))

    # construct a grid 
    grid_x = np.linspace(-scale, scale, n)
    grid_y = np.linspace(-scale, scale, n)[::-1]

    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z_sample = torch.tensor([[xi, yi]], dtype=torch.float).to('cuda')
            x_decoded = model.decode(z_sample)
            digit = x_decoded[0].detach().cpu().reshape(digit_size, digit_size)
            figure[i * digit_size : (i + 1) * digit_size, j * digit_size : (j + 1) * digit_size,] = digit

    plt.figure(figsize=(figsize, figsize))
    plt.title('VAE Latent Space Visualization')
    start_range = digit_size // 2
    end_range = n * digit_size + start_range
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("mean, z [0]")
    plt.ylabel("var, z [1]")
    plt.imshow(figure, cmap="Greys_r")
    plt.show()


plot_latent_space(model)