In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms

In [2]:
# Run on the GPU when a CUDA device is usable, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [3]:
class VAE(nn.Module):
    """Variational auto-encoder built around a user-supplied encoder and decoder.

    The encoder output is projected by two heads into the parameters of a
    diagonal Gaussian posterior: ``self.mu`` gives the mean and ``self.sigma``
    (a linear layer followed by Softplus) gives a strictly positive standard
    deviation. A latent sample is drawn with the reparametrization trick and
    passed to the decoder.

    Args:
        encoder: Module applied to the input. Its output is fed to
            ``nn.Linear(encoding_dim, latent_dim)`` heads, so its last
            dimension must be ``encoding_dim``.
        decoder: Module applied to latent samples whose last dimension is
            ``latent_dim``.
        encoding_dim: Size of the encoder output features.
        latent_dim: Size of the latent space.
        recons_criterion: Callable ``(x_tilde, x) -> scalar`` used as the
            reconstruction term of the loss. Defaults to a summed MSE.
    """

    # Lower bound applied to sigma before taking its log, so that a Softplus
    # output that underflows to 0 cannot produce -inf in the KL term.
    _MIN_SIGMA = 1e-8

    def __init__(
        self,
        encoder: nn.Module,
        decoder: nn.Module,
        encoding_dim: int,
        latent_dim: int,
        recons_criterion = torch.nn.MSELoss(reduction='sum'),
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        # Keep the sizes on the instance: they are needed to build the heads
        # below and are convenient for callers (e.g. sampling from the prior).
        self.encoding_dim = encoding_dim
        self.latent_dim = latent_dim

        self.mu = nn.Linear(encoding_dim, latent_dim)
        self.sigma = nn.Sequential(
            nn.Linear(encoding_dim, latent_dim), nn.Softplus()
        )
        self.recons_criterion = recons_criterion

    def encode(self, x: torch.Tensor):
        """Return ``(mu, sigma)`` of the approximate posterior for ``x``.

        ``sigma`` is a standard deviation (positive, from the Softplus head),
        not a log-variance.
        """
        h = self.encoder(x)
        return self.mu(h), self.sigma(h)

    def decode(self, z: torch.Tensor):
        """Map latent samples ``z`` through the decoder."""
        return self.decoder(z)

    def forward(self, x: torch.Tensor):
        """Encode ``x``, sample a latent code and decode it.

        Returns:
            Tuple ``(x_tilde, kl_div)``: the decoder output and the KL
            divergence to the standard normal prior, summed over all elements.
        """
        # Encode the inputs
        mu, sigma = self.encode(x)
        # latent() is parametrized by the log-variance, so convert the
        # standard deviation: log(sigma^2) = 2 * log(sigma).
        log_var = 2.0 * torch.log(sigma.clamp_min(self._MIN_SIGMA))
        # Obtain latent samples and latent loss
        z_tilde, kl_div = self.latent(mu, log_var)
        # Decode the samples
        x_tilde = self.decode(z_tilde)
        return x_tilde, kl_div

    def latent(self, mu: torch.Tensor, log_var: torch.Tensor):
        """Sample ``z ~ N(mu, exp(log_var))`` and compute its KL to ``N(0, I)``.

        Args:
            mu: Posterior means.
            log_var: Posterior log-variances, same shape as ``mu``.

        Returns:
            Tuple ``(z, kl_div)`` where ``z`` has the shape of ``mu`` and
            ``kl_div`` is a scalar tensor (sum over every element).
        """
        # Reparametrization: z = mu + std * eps keeps the sample
        # differentiable with respect to mu and log_var.
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        z = eps * std + mu

        # Closed-form KL(N(mu, sigma^2) || N(0, 1)):
        #   -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

        return z, kl_div

    def loss(self, x, x_tilde, kl_div, beta: float = 1.0):
        """Return ``recons_criterion(x_tilde, x) + beta * kl_div``.

        ``beta`` weights the KL term; the default of 1.0 gives equal weight to
        both terms.
        """
        return self.recons_criterion(x_tilde, x) + beta * kl_div

In [6]:
# Load MNIST and build the train / validation / test loaders.
valid_ratio = 0.2  # Fraction of the official training split held out for validation
split_seed = 42  # Fixes the train/validation partition across kernel restarts

# Load the dataset for the training/validation sets
train_valid_dataset = torchvision.datasets.MNIST(
    root="../data",
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)
# Split it into training and validation sets. The training size is derived
# from the validation size so that the two always sum to the dataset length
# (two independent int() truncations are not guaranteed to).
nb_valid = int(valid_ratio * len(train_valid_dataset))
nb_train = len(train_valid_dataset) - nb_valid
train_dataset, valid_dataset = torch.utils.data.random_split(
    train_valid_dataset,
    [nb_train, nb_valid],
    generator=torch.Generator().manual_seed(split_seed),
)
# Load the test set (download=True so this does not rely on the files having
# been fetched as a side effect of the call above)
test_dataset = torchvision.datasets.MNIST(
    root="../data",
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)
# Prepare the loaders
num_threads = 4  # Number of DataLoader worker processes
batch_size = 128  # Using minibatches of 128 samples
# Only the training loader shuffles: each epoch should see the samples in a
# new order, while evaluation order is kept fixed.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_threads
)
valid_loader = torch.utils.data.DataLoader(
    dataset=valid_dataset, batch_size=batch_size, shuffle=False, num_workers=num_threads
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_threads
)

5.3%

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data\MNIST\raw\train-images-idx3-ubyte.gz


100.0%


Extracting ../data\MNIST\raw\train-images-idx3-ubyte.gz to ../data\MNIST\raw


100.0%


Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data\MNIST\raw\train-labels-idx1-ubyte.gz
Extracting ../data\MNIST\raw\train-labels-idx1-ubyte.gz to ../data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz



75.5%

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data\MNIST\raw\t10k-images-idx3-ubyte.gz


100.0%
100.0%

Extracting ../data\MNIST\raw\t10k-images-idx3-ubyte.gz to ../data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data\MNIST\raw\t10k-labels-idx1-ubyte.gz
Extracting ../data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ../data\MNIST\raw






In [None]:
# Build a small fully-connected VAE for 1x28x28 MNIST images.
# NOTE(review): the layer sizes below are placeholder choices, tune as needed.
input_dim = 28 * 28  # Flattened MNIST image
encoding_dim = 256  # Width of the encoder output fed to the mu / sigma heads
latent_dim = 16  # Size of the latent space

encoder = nn.Sequential(
    nn.Flatten(),
    nn.Linear(input_dim, encoding_dim),
    nn.ReLU(),
)
decoder = nn.Sequential(
    nn.Linear(latent_dim, encoding_dim),
    nn.ReLU(),
    nn.Linear(encoding_dim, input_dim),
    nn.Sigmoid(),  # ToTensor() images lie in [0, 1], so bound the reconstruction too
    nn.Unflatten(1, (1, 28, 28)),  # Back to image shape so it matches the input in the loss
)
# VAE requires the encoder, the decoder and both dimensions.
model = VAE(encoder, decoder, encoding_dim, latent_dim).to(device)