In [1]:
import os
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torchvision.utils import save_image

In [2]:
# Define hyperparameters
image_size = 28 * 28  # one flattened 28x28 input image -> 784 features
hidden_dim, latent_dim = 400, 20  # widths of the hidden layer and of the latent code
batch_size = 128
epochs = 10

# Run on the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# MNIST dataset
# download=True fetches the files into ./data only when they are not already
# present, so the notebook also runs from a fresh kernel on a clean machine.
train_dataset = torchvision.datasets.MNIST(root='./data',
                                           train=True,
                                           transform=transforms.ToTensor(),
                                           download=True)

test_dataset = torchvision.datasets.MNIST(root='./data',
                                          train=False,
                                          transform=transforms.ToTensor(),
                                          download=True)

# Training batches are reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

# Evaluation does not need shuffling; a fixed order keeps the first batch
# (and anything visualised from it) the same from one epoch to the next.
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)

# Create directory to save the reconstructed and sampled images (if directory not present)
sample_dir = 'results'
os.makedirs(sample_dir, exist_ok=True)

In [3]:
# VAE model
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder maps a flattened image of ``image_size`` features through one
    hidden layer of ``hidden_dim`` units to the mean and log-variance of a
    ``latent_dim``-dimensional Gaussian; the decoder mirrors that path and
    ends in a sigmoid so every output lies in [0, 1].
    """

    def __init__(self):
        super().__init__()

        # Encoder: shared hidden layer followed by two parallel heads
        self.fc1 = nn.Linear(image_size, hidden_dim)
        self.fc2_mean = nn.Linear(hidden_dim, latent_dim)
        self.fc2_logvar = nn.Linear(hidden_dim, latent_dim)
        # Decoder
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, image_size)
        # Activation shared by the encoder and decoder hidden layers
        self.relu = nn.ReLU()

    def encode(self, x):
        """Return ``(mu, log_var)`` of the approximate posterior for flattened input ``x``."""
        hidden = self.relu(self.fc1(x))
        return self.fc2_mean(hidden), self.fc2_logvar(hidden)

    def reparameterize(self, mu, log_var):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(log_var / 2)``."""
        std = (log_var / 2).exp()
        noise = torch.randn_like(std)
        return mu + noise * std

    def decode(self, z):
        """Map latent codes ``z`` to flattened reconstructions with values in [0, 1]."""
        hidden = self.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(hidden))

    def forward(self, x):
        """Encode ``x``, sample a latent code, decode it.

        Returns the tuple ``(reconstructed, mu, log_var)``.
        """
        # Flatten e.g. (batch_size, 1, 28, 28) into (batch_size, 784) before encoding
        flat = x.view(-1, image_size)
        mu, log_var = self.encode(flat)
        latent = self.reparameterize(mu, log_var)
        return self.decode(latent), mu, log_var

In [4]:
# Define model
# Build the VAE first, then move its parameters onto the selected device
model = VAE()
model = model.to(device)

# Define Loss
def loss_function(reconstructed_image, original_image, mu, logvar):
    """Compute the VAE training loss: reconstruction error plus KL divergence.

    Both terms are summed (not averaged) over every element of the batch, so
    callers divide by the number of samples to obtain a per-sample value.

    Args:
        reconstructed_image: decoder output with values in [0, 1].
        original_image: target images with the same number of elements as
            ``reconstructed_image``; it is viewed in that tensor's shape.
        mu: mean of the approximate posterior.
        logvar: log-variance of the approximate posterior.

    Returns:
        A scalar tensor ``bce + kld``.
    """
    # Match the target to the reconstruction's own layout instead of a
    # hard-coded 784, so the loss keeps working when the image size changes.
    target = original_image.view(reconstructed_image.shape)
    bce = nn.functional.binary_cross_entropy(reconstructed_image, target, reduction='sum')
    # Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over batch and latent dimensions
    kld = 0.5 * torch.sum(logvar.exp() + mu.pow(2) - 1 - logvar)
    return bce + kld

# Define Optimizer
# Adam over every parameter of the model, with a learning rate of 1e-3
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [5]:
# Train function
def train(epoch):
    """Run one optimisation pass over ``train_loader``.

    Switches the model to training mode, takes one optimizer step per batch,
    prints the per-sample loss of every 100th batch and finally prints the
    summed loss divided by the number of samples in the training dataset.

    Args:
        epoch: epoch number, used only in the printed messages.
    """
    model.train()
    num_batches = len(train_loader)
    running_loss = 0

    for batch_idx, (images, _labels) in enumerate(train_loader):
        images = images.to(device)

        # Forward pass and loss for this batch
        reconstructed, mu, log_var = model(images)
        loss = loss_function(reconstructed, images, mu, log_var)

        # Clear stale gradients, backpropagate, update the parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss

        if batch_idx % 100 == 0:
            print(f"Train Epoch {epoch} [Batch {batch_idx}/{num_batches}]\tLoss: {batch_loss/len(images):.3f}")

    print(f"=====> Epoch {epoch}, Average Loss: {running_loss/len(train_loader.dataset):.3f}")

In [6]:
# Test Function
def test(epoch):

    model.eval()
    test_loss = 0

    with torch.no_grad():
        for i, (images, _) in enumerate(train_loader):
            images = images.to(device)
            reconstructed, mu, log_var = model(images)
            test_loss += loss_function(reconstructed, images, mu, log_var).item()

            if i == 0:
                comparison = torch.cat([images[:5], reconstructed.view(batch_size, 1, 28, 28)[:5]])
                save_image(comparison.cpu(), f'results/reconstruction_{str(epoch)}.png', nrow=5)

    print(f"=====> Average Test Loss: {test_loss/len(train_loader.dataset):.3f}")

In [7]:
# Main loop: train, evaluate, then generate new digits after every epoch
num_samples = 64  # number of images written to each sample grid
for epoch in range(1, epochs + 1):
    train(epoch)
    test(epoch)
    with torch.no_grad():
        # Get rid of the encoder and sample z from the gaussian distribution
        # and feed it into the decoder to generate samples.
        # The code width comes from latent_dim so it always matches the decoder input.
        sample = torch.randn(num_samples, latent_dim).to(device)
        generated = model.decode(sample).cpu()
        # Write into the directory created for results rather than a hard-coded path
        save_image(generated.view(num_samples, 1, 28, 28),
                   os.path.join(sample_dir, f'sample_{epoch}.png'))

RuntimeError: Boolean value of Tensor with more than one value is ambiguous