In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import torch.optim as optim
from torchvision.utils import save_image
from matplotlib import pyplot as plt

if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [2]:
bs = 32
# MNIST Dataset
train_dataset = datasets.MNIST(root='./mnist_data/',
                               train=True,
                               transform=transforms.ToTensor(),
                               download=True)
test_dataset = datasets.MNIST(root='./mnist_data/',
                              train=False,
                              transform=transforms.ToTensor(),
                              download=False)

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=bs, shuffle=False)

In [3]:
class Encoder(nn.Module):
    def __init__(self, x_dim, h_dim1, h_dim2, z_dim):
        super(Encoder, self).__init__()
        self.lin1 = nn.Linear(x_dim, h_dim1)
        self.lin2 = nn.Linear(h_dim1, h_dim2)
        self.lin31 = nn.Linear(h_dim2, z_dim)
        self.lin32 = nn.Linear(h_dim2, z_dim)
        
    def forward(self, x):
        x = F.relu(self.lin1(x))
        x = F.relu(self.lin2(x))
        return self.lin31(x), self.lin32(x) # mu, log_var

In [4]:
class Decoder(nn.Module):
    def __init__(self, x_dim, h_dim1, h_dim2, z_dim):
        super(Decoder, self).__init__()
        self.lin4 = nn.Linear(z_dim, h_dim2)
        self.lin5 = nn.Linear(h_dim2, h_dim1)
        self.lin6 = nn.Linear(h_dim1, x_dim)

    def forward(self, z):
        h = F.relu(self.lin4(z))
        h = F.relu(self.lin5(h))
        return torch.sigmoid(self.lin6(h)) 

In [5]:
class VAE(nn.Module):
    def __init__(self, x_dim, h_dim1, h_dim2, z_dim):
        super(VAE, self).__init__()
        self.encoder = Encoder(x_dim, h_dim1, h_dim2, z_dim)
        self.decoder = Decoder(x_dim, h_dim1, h_dim2, z_dim)
        
    def sampling(self, mu, log_var): 
        std = torch.sqrt(torch.exp(log_var))
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, x):
        mu, log_var = self.encoder.forward(x)
        z = self.sampling(mu, log_var)
        mu_d = self.decoder(z)
        return mu_d, mu, log_var
    
    def generate(self, z):
        return self.decoder(z)

In [6]:
vae = VAE(x_dim=784, h_dim1= 512, h_dim2=256, z_dim=2)
vae = vae.to(device)
optimizer = optim.Adam(vae.parameters())

In [7]:
def kl_mvn(mu, sigma, device):
    N = mu.shape[1]
    mu = mu.reshape(mu.shape[0], mu.shape[1], 1)
    tr_term   = torch.sum(sigma, axis = 1)
    det_term  = -torch.log(torch.prod(sigma, axis = 1))
    quad_term = torch.transpose(mu, 1, 2) @ mu
    return torch.sum(0.5 * (tr_term + det_term + quad_term - N))

In [8]:
def loss_function(recon_x, x, mu, log_var):
    BCE =  F.binary_cross_entropy(recon_x, x, reduction='sum')
    #print(f'BCE {BCE}')
    KLD = kl_mvn(mu, torch.exp(log_var), device)
    #print(f'KLD {KLD}')
    return BCE + 0.1 * KLD

In [17]:
def test():
    vae.eval()
    test_loss = 0
    with torch.no_grad():
        for batch_idx, (data, _) in enumerate(test_loader):
            data = data.reshape([-1, 784])
            recon, mu, log_var = vae(data)
            test_loss += loss_function(recon, data, mu, log_var)
        
    test_loss /= len(test_loader.dataset)
    print(f'====> Test set loss: {test_loss}')

In [12]:
def train(epoch):
    vae.train()
    train_loss = 0
    batch_idx = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        data = data.reshape([-1, 784])
        optimizer.zero_grad()
        
        recon_batch, mu, log_var = vae.forward(data)
        loss = loss_function(recon_batch, data, mu, log_var)
        
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        batch_idx += 1

    print(f'====> Epoch: {epoch} Average loss: {train_loss / len(train_loader.dataset)}')

In [15]:
for epoch in range(1, 3):
    train(epoch)
    #test()

====> Epoch: 1 Average loss: 161.43565283203125
====> Epoch: 2 Average loss: 159.4990600423177


In [18]:
vae = vae.cpu()
test()
vae = vae.to(device)

dupa1
dupa2
dupa3
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa7
dupa4
dupa5
dupa6
dupa

In [60]:
with torch.no_grad():
    z = torch.randn(64, 2).to(device)
    sample = vae.generate(z).cpu()
    save_image(sample.view(64, 1, 28, 28), './samples/sample' + '.png')