In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils

# All model tensors live on the CPU. Switching this to 'cuda' would also
# require moving every batch with .to(device) wherever the training loop is.
device = 'cpu'

# Seed torch's global RNG so weight initialisation, DataLoader shuffling and
# the VAE's reparameterization noise repeat across Restart & Run All.
SEED = 0
torch.manual_seed(SEED);

In [None]:
# ToTensor turns each PIL image into a float tensor scaled to [0, 1], the same
# range the VAE decoder's Sigmoid produces.
tfm = transforms.Compose([transforms.ToTensor()])

DATA_ROOT = './download_data'
BATCH_SIZE = 128

# Build the train split first, then the test split (downloaded on first run).
train, test = (
    datasets.FashionMNIST(root=DATA_ROOT, train=is_train, download=True, transform=tfm)
    for is_train in (True, False)
)

# Only the training loader shuffles; evaluation order stays fixed.
train_loader, test_loader = (
    DataLoader(split, batch_size=BATCH_SIZE, shuffle=is_train, num_workers=2)
    for split, is_train in ((train, True), (test, False))
)

In [None]:
class Vae(nn.Module):
    """Fully-connected variational autoencoder for single-channel 28x28 images.

    The encoder maps an image to the mean and log-variance of a diagonal
    Gaussian over a ``latentDim``-dimensional latent space. The decoder maps a
    latent vector back to a (1, 28, 28) image whose pixels lie in [0, 1].

    Args:
        latentDim: size of the latent space (default 16).
    """

    def __init__(self, latentDim=16):
        super().__init__()
        # Encoder trunk: flatten each sample to 784 values and reduce to 128 features.
        self.enc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU()
        )
        # Two heads on the shared features: Gaussian mean and log-variance.
        self.mean = nn.Linear(128, latentDim)
        self.logVar = nn.Linear(128, latentDim)

        # Decoder network. It must not be stored as ``self.decode``: nn.Module
        # keeps submodules in ``_modules``, so ``self.decode`` would still
        # resolve to the ``decode`` method below, which would then recurse forever.
        self.dec = nn.Sequential(
            nn.Linear(latentDim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 28 * 28),
            nn.Sigmoid()  # squash pixel intensities into [0, 1]
        )

    def encode(self, input):
        """Encode images into posterior parameters.

        Args:
            input: batch of images; each sample must flatten to 28*28 values,
                e.g. shape (batch, 1, 28, 28).

        Returns:
            ``(mean, logvar)``, each of shape (batch, latentDim).
        """
        latent = self.enc(input)
        mean = self.mean(latent)
        logvar = self.logVar(latent)

        return mean, logvar

    def decode(self, latent):
        """Decode latent vectors of shape (batch, latentDim) into (batch, 1, 28, 28) images."""
        result = self.dec(latent)
        return result.view(-1, 1, 28, 28)

    def reparameterize(self, mean, logvar):
        """Draw z = mean + eps * exp(0.5 * logvar) with eps ~ N(0, I).

        Sampling the noise separately keeps z differentiable with respect to
        ``mean`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)
        rnd = torch.randn_like(std)

        return mean + rnd * std

    def forward(self, input_tensor):
        """Encode, sample a latent code and decode.

        Returns:
            ``(img, mean, logvar)``: the reconstruction of shape
            (batch, 1, 28, 28) plus the posterior parameters needed for the
            KL term of the loss.
        """
        mean, logvar = self.encode(input_tensor)
        result = self.reparameterize(mean, logvar)

        img = self.decode(result)
        return img, mean, logvar

In [None]:
# Model hyperparameters. `beta` is not used in this cell; presumably it weights
# the KL term of a beta-VAE loss defined further down — confirm.
latent_dim = 16
beta = 4.0

# Build the model on the configured device and attach an Adam optimiser.
vae = Vae(latentDim=latent_dim)
vae = vae.to(device)
opt = torch.optim.Adam(params=vae.parameters(), lr=1e-3)