In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils

device = 'cpu'

In [2]:
tfm = transforms.Compose([transforms.ToTensor()])

train = datasets.FashionMNIST(root='./download_data', train=True,  download=True, transform=tfm)
test  = datasets.FashionMNIST(root='./download_data', train=False, download=True, transform=tfm)
train_loader = DataLoader(train, batch_size=128, shuffle=True,  num_workers=2)
test_loader  = DataLoader(test,  batch_size=128, shuffle=False, num_workers=2)

In [7]:
class Vae(nn.Module):
    def __init__(self, latentDim=16):
        super().__init__()
        self.enc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU()
        )
        self.mean = nn.Linear(128, latentDim)
        self.logVar = nn.Linear(128, latentDim)

        # decoder
        self.decodeLayers = nn.Sequential(
            nn.Linear(latentDim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 28 * 28),
            nn.Sigmoid()
        )

    def encode(self, input):
        latent = self.enc(input)
        mean = self.mean(latent)
        logvar = self.logVar(latent)

        return mean, logvar

    def decode(self, latent):
        result = self.decodeLayers(latent)
        return result.view(-1, 1, 28, 28)

    def reparameterize(self, mean, logvar):
        std = torch.exp(0.5 * logvar)
        rnd = torch.randn_like(std)

        return mean + rnd * std

    def forward(self, input_tensor):
        mean, logvar = self.encode(input_tensor)
        result = self.reparameterize(mean, logvar)

        img = self.decode(result)
        return img, mean, logvar

In [8]:
latent_dim, beta = 16, 4.0

vae = Vae(latent_dim).to(device)
opt = torch.optim.Adam(vae.parameters(), lr=1e-3)

In [9]:
def elbo_loss(x, xhat, mu, logv, beta=1.0):
    recon = F.binary_cross_entropy(xhat, x, reduction='sum') / x.size(0)
    kld = -0.5 * torch.sum(1 + logv - mu.pow(2) - logv.exp()) / x.size(0)
    return recon + beta * kld

In [10]:
## training
for epoch in range(10):
    vae.train()
    total_loss = 0

    for x, _ in train_loader:
        x = x.to(device)
        xhat, mean, logVar = vae(x)

        loss = elbo_loss(x, xhat, mean, logVar, beta)

        opt.zero_grad()
        loss.backward()
        opt.step()
        total_loss  += loss.item()

    print(f"epoch {epoch}: loss={total_loss / len(train_loader):.3f}")

epoch 0: loss=322.702
epoch 1: loss=282.884
epoch 2: loss=276.666
epoch 3: loss=273.950
epoch 4: loss=272.150
epoch 5: loss=271.211
epoch 6: loss=270.336
epoch 7: loss=269.781
epoch 8: loss=269.319
epoch 9: loss=268.775
