In [1]:
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.distributions as tdist
import torch.utils.data
import torchvision
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets
from torchvision.transforms import v2
from torchvision.utils import save_image
from tqdm import tqdm

In [31]:
class VAE(nn.Module):
    """Convolutional variational autoencoder for single-channel 28x28 images.

    Parameters
    ----------
    latent_dim : int
        Size of the latent vector z.
    distribution : str
        Name of the decoder likelihood. Only the value
        "gaussian_with_learned_variance" changes the decoder head; every
        other value uses the single-channel sigmoid head.
    """

    def __init__(self, latent_dim=2, distribution="bernoulli"):
        super().__init__()
        self.latent_dim = latent_dim
        self.distribution = distribution

        # Every (transposed) convolution halves/doubles the spatial size.
        resample = dict(kernel_size=4, stride=2, padding=1)
        flat_features = 64 * 7 * 7

        # Encoder: (batch, 1, 28, 28) -> (batch, 32, 14, 14) -> (batch, 64, 7, 7)
        self.conv1 = nn.Conv2d(1, 32, **resample)
        self.conv2 = nn.Conv2d(32, 64, **resample)
        # Linear heads producing the posterior mean and log-variance
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

        # Decoder: latent vector -> (batch, 64, 7, 7) -> (batch, 32, 14, 14)
        self.fc = nn.Linear(latent_dim, flat_features)
        self.deconv1 = nn.ConvTranspose2d(64, 32, **resample)
        # Default head: (batch, 32, 14, 14) -> (batch, 1, 28, 28)
        self.deconv2 = nn.ConvTranspose2d(32, 1, **resample)
        # Learned-variance head: (batch, 32, 14, 14) -> (batch, 2, 28, 28)
        self.deconv3 = nn.ConvTranspose2d(32, 2, **resample)

    def encode(self, x):
        """Return the posterior parameters ``(mu, logvar)`` for a batch of images."""
        features = F.relu(self.conv2(F.relu(self.conv1(x))))
        flat = features.view(features.size(0), -1)
        return self.fc_mu(flat), self.fc_logvar(flat)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * sigma`` with ``eps ~ N(0, I)`` (reparameterization trick)."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent vectors to decoder outputs.

        Returns a ``(batch, 1, 28, 28)`` tensor of sigmoid activations, or,
        for "gaussian_with_learned_variance", a tuple of two
        ``(batch, 28, 28)`` tensors: the sigmoid of channel 0 and the
        softplus (hence non-negative) of channel 1.
        """
        hidden = F.relu(self.fc(z)).view(-1, 64, 7, 7)
        hidden = F.relu(self.deconv1(hidden))

        if self.distribution == "gaussian_with_learned_variance":
            params = self.deconv3(hidden)
            mean_map = params[:, 0]  # channel 0: mean of the reconstruction
            logvar_map = params[:, 1]  # channel 1: treated as log-variance downstream
            return (torch.sigmoid(mean_map), F.softplus(logvar_map))

        return torch.sigmoid(self.deconv2(hidden))

    def forward(self, x):
        """Encode, sample a latent vector, decode; returns ``(recon, mu, logvar)``."""
        mu, logvar = self.encode(x)
        latent = self.reparameterize(mu, logvar)
        return self.decode(latent), mu, logvar

In [3]:
# https://github.com/Robert-Aduviri/Continuous-Bernoulli-VAE
def sumlogC(x, eps=1e-5):
    """
    Sum of the log normalizing constant log C(x) of the Continuous
    Bernoulli distribution over all elements of x.

    Elements within ``eps`` of 0.5 use a 2nd-degree Taylor approximation,
    because the closed form is 0/0 there.

    Parameters
    ----------
    x : Tensor
        Values in (0, 1); clamped to [eps, 1 - eps] before use.
    eps : float
        Clamp margin and half-width of the Taylor region around 0.5.

    Returns
    -------
    Tensor
        Scalar tensor with the summed log-constant.
    """
    lam = torch.clamp(x, eps, 1.0 - eps)
    away_from_half = torch.abs(lam - 0.5).ge(eps)

    lam_far = lam[away_from_half]
    lam_near = lam[~away_from_half]

    # Closed form: log( (log(1 - l) - log(l)) / (1 - 2l) )
    log_odds_gap = torch.log(1.0 - lam_far) - torch.log(lam_far)
    far_total = torch.log(log_odds_gap / (1.0 - 2.0 * lam_far)).sum()

    # Taylor expansion around 0.5: log(2) + log(1 + (1 - 2l)^2 / 3)
    near_terms = torch.log(torch.tensor(2.0)) + torch.log(
        1.0 + (1.0 - 2.0 * lam_near) ** 2 / 3.0
    )
    return far_total + near_terms.sum()


In [4]:
def sumlogC_optimized(x, eps=1e-5):
    """
    Vectorized, numerically stable sum of the log normalizing constant
    log C(x) of the Continuous Bernoulli distribution.

    Elements with ``|x - 0.5| < eps`` use a 2nd-degree Taylor approximation;
    all others use the closed form. Both the value and the gradient stay
    finite for inputs equal to 0.5.

    Parameters
    ----------
    x : Tensor
        Values in (0, 1); clamped to [eps, 1 - eps] before use.
    eps : float, optional
        Clamp margin and half-width of the Taylor region around 0.5.

    Returns
    -------
    Tensor
        Scalar tensor with the summed log-constant.
    """
    # Clamp x to avoid log(0)
    x = torch.clamp(x, eps, 1.0 - eps)

    # True where the closed form is safe to evaluate
    mask = torch.abs(x - 0.5) >= eps

    # The closed form is 0/0 at x == 0.5. torch.where would hide the resulting
    # NaN in the forward pass, but autograd still multiplies the zero upstream
    # gradient by the NaN/inf local derivative, giving NaN gradients. Feed the
    # far branch a harmless placeholder wherever it will not be selected.
    x_far = torch.where(mask, x, torch.full_like(x, 0.25))
    numerator = torch.log1p(-x_far) - torch.log(x_far)  # log(1 - x) - log(x)
    far_values = torch.log(numerator / (1.0 - 2.0 * x_far))

    # Taylor approximation around 0.5: log(2) + log(1 + (1 - 2x)^2 / 3)
    log2 = torch.log(torch.tensor(2.0))
    one_minus_2x = 1.0 - 2.0 * x
    close_values = log2 + torch.log(1.0 + (one_minus_2x**2) / 3.0)

    # Pick the appropriate formula per element and sum
    values = torch.where(mask, far_values, close_values)
    return values.sum()


In [5]:
def vae_loss(recon_x, x, mu, logvar):
    """Bernoulli VAE loss: summed binary cross-entropy plus the KL term.

    recon_x and x must have the same shape; recon_x holds probabilities.
    mu and logvar are the parameters of the diagonal Gaussian posterior.
    Returns a scalar tensor summed (not averaged) over the batch.
    """
    reconstruction_term = F.binary_cross_entropy(recon_x, x, reduction="sum")
    # Closed-form KL( N(mu, sigma^2) || N(0, I) )
    kl_term = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return reconstruction_term + kl_term


In [7]:
def KL_divergence(mu, logvar):
    """KL( N(mu, exp(logvar)) || N(0, I) ), summed over every element.

    Computed through ``torch.distributions`` instead of the closed form.
    """
    posterior = tdist.Normal(loc=mu, scale=torch.exp(0.5 * logvar))
    standard_normal = tdist.Normal(
        loc=torch.zeros_like(mu), scale=torch.ones_like(logvar)
    )
    return tdist.kl_divergence(posterior, standard_normal).sum()


In [8]:
def cb_lambda_loss(recon_x, x, mu, logvar):
    """Continuous Bernoulli VAE loss (negative ELBO), summed over the batch.

    The Continuous Bernoulli log-density is
    ``log C(lambda) + x log(lambda) + (1 - x) log(1 - lambda)``,
    so the negative log-likelihood is the binary cross-entropy MINUS the
    log normalizing constant.

    Parameters
    ----------
    recon_x : Tensor
        Decoder output, used directly as the parameter lambda in (0, 1).
    x : Tensor
        Target images, same shape as recon_x.
    mu, logvar : Tensor
        Parameters of the diagonal Gaussian posterior.

    Returns
    -------
    Tensor
        Scalar loss.
    """
    # Score lambda itself. The previous version first replaced it with the
    # distribution's mean, so the optimised quantity was not the likelihood
    # of the distribution the decoder parameterises.
    BCE = F.binary_cross_entropy(recon_x, x, reduction="sum")
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    logC = sumlogC(recon_x)
    # log C enters the log-likelihood with a plus sign, hence a minus in the loss.
    return BCE + KLD - logC


In [9]:
def gaussian_loss(recon_x, x, mu, logvar):
    """Fixed-variance Gaussian VAE loss: summed squared error plus the KL term.

    recon_x and x are flattened to (batch, -1) before comparison.
    Returns a scalar tensor summed (not averaged) over the batch.
    """
    n_items = x.size(0)
    flat_recon = recon_x.view(n_items, -1)
    flat_target = x.view(n_items, -1)

    # With a fixed decoder variance the reconstruction term reduces to a
    # sum of squared errors.
    squared_error = F.mse_loss(flat_recon, flat_target, reduction="sum")

    kl_term = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return squared_error + kl_term

In [29]:
def reconstruction_loss(x_reconstructed_mu, x_reconstructed_logvar, x_true):
    """Gaussian negative log-likelihood with a per-pixel log-variance.

    Each argument is flattened to (batch, -1). The result is
    ``0.5 * sum(logvar + (x - mu)^2 / exp(logvar) + log(2 * pi))``,
    summed over every pixel and every batch item (no averaging).
    """
    target = x_true.view(x_true.size(0), -1)
    mean = x_reconstructed_mu.view(x_reconstructed_mu.size(0), -1)
    log_var = x_reconstructed_logvar.view(x_reconstructed_logvar.size(0), -1)

    log_two_pi = torch.log(torch.tensor(2) * torch.pi)
    per_pixel = log_var + ((target - mean) ** 2) / torch.exp(log_var) + log_two_pi
    return 0.5 * torch.sum(per_pixel)


def kl_divergence(mu_z, logvar_z):
    """Closed-form KL( N(mu_z, exp(logvar_z)) || N(0, I) ), summed over all elements."""
    per_element = 1 + logvar_z - mu_z.pow(2) - logvar_z.exp()
    return -0.5 * torch.sum(per_element)


def loss_cont_gaussian(recon_x, x_true, mu_z, logvar_z):
    """Loss for the learned-variance Gaussian decoder.

    recon_x is the ``(mean, log-variance)`` pair returned by the decoder;
    the result is the Gaussian reconstruction term plus the latent KL term.
    """
    decoder_mean, decoder_logvar = recon_x
    return reconstruction_loss(decoder_mean, decoder_logvar, x_true) + kl_divergence(
        mu_z, logvar_z
    )

In [12]:
def beta_loss(alphas, betas, x, mu, logvar, beta_reg):
    """Binary cross-entropy between x and the mean of Beta(alphas, betas), plus KL.

    Both x and the Beta mean are reshaped to (-1, 784). ``beta_reg`` is
    accepted but not used by this function.
    """
    target = x.view(-1, 784)
    beta_mean = tdist.Beta(alphas, betas).mean.view(-1, 784)
    reconstruction_term = F.binary_cross_entropy(beta_mean, target, reduction="sum")
    kl_term = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return reconstruction_term + kl_term


In [13]:
def loss_fct(recon_x, x, mu, logvar, distribution):
    """Dispatch to the loss function matching the decoder likelihood.

    Parameters
    ----------
    recon_x : Tensor or tuple of Tensor
        Decoder output (a ``(mean, log-variance)`` pair for
        "gaussian_with_learned_variance").
    x : Tensor
        Target images.
    mu, logvar : Tensor
        Parameters of the latent posterior.
    distribution : str
        One of "bernoulli", "continuous_bernoulli", "gaussian",
        "gaussian_with_learned_variance".

    Raises
    ------
    ValueError
        If ``distribution`` is not one of the supported names. Previously
        the function fell through and returned None, which only surfaced
        later as an error on ``loss.backward()``.
    """
    if distribution == "bernoulli":
        return vae_loss(recon_x, x, mu, logvar)
    elif distribution == "continuous_bernoulli":
        return cb_lambda_loss(recon_x, x, mu, logvar)
    elif distribution == "gaussian":
        return gaussian_loss(recon_x, x, mu, logvar)
    elif distribution == "gaussian_with_learned_variance":
        return loss_cont_gaussian(recon_x, x, mu, logvar)
    raise ValueError(f"Unknown distribution: {distribution!r}")

In [14]:
# Number of training epochs run for each model.
EPOCHS = 100
# Mini-batch size used by both the training and the test DataLoader.
BATCH_SIZE = 128

# Seed PyTorch's global RNG; the trailing semicolon suppresses the cell output.
torch.manual_seed(1);


## Check for GPU or MPS availability. Use CPU if neither is available


In [None]:
# Prefer CUDA, then Apple MPS, and fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


In [16]:
# Convert PIL images to float32 tensors scaled to [0, 1].
transform = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])

# On CUDA, use a worker process and pinned (page-locked) host memory to speed
# up host-to-GPU copies. `device` is a torch.device, and comparing it with the
# string "cuda" is always False, so the check must go through `device.type`.
kwargs = {"num_workers": 1, "pin_memory": True} if device.type == "cuda" else {}

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST("../data", train=True, download=True, transform=transform),
    batch_size=BATCH_SIZE,
    shuffle=True,
    **kwargs,
)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST("../data", train=False, transform=transform),
    batch_size=BATCH_SIZE,
    shuffle=False,
    **kwargs,
)

In [17]:
def train(model, optimizer, loss_fn, epoch, distribution):
    """Run one training epoch over ``train_loader`` and print the average loss.

    ``loss_fn`` is called as ``loss_fn(recon, data, mu, logvar, distribution)``.
    The printed average divides the summed loss by the number of items in
    the training dataset.
    """
    model.train()
    running_loss = 0
    progress = tqdm(
        train_loader,
        desc=f"Distribution: {distribution} - Training Epoch {epoch}/{EPOCHS}",
    )
    for images, _ in progress:  # labels are not used
        images = images.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(images)
        loss = loss_fn(recon_batch, images, mu, logvar, distribution)
        loss.backward()
        running_loss += loss.item()
        optimizer.step()

    average_loss = running_loss / len(train_loader.dataset)
    print(f"Epoch: {epoch} Average loss: {average_loss:.4f}")


In [18]:
def test(model, loss_fn, epoch, distribution):
    """Evaluate on ``test_loader``, print the average loss, save a comparison image.

    For the first batch, up to 8 inputs and their reconstructions are
    written to ``../images/<model.distribution>/reconstruction_<epoch>.png``.
    ``loss_fn`` is called as ``loss_fn(recon, data, mu, logvar, distribution)``.
    """
    model.eval()
    test_loss = 0.0

    with torch.no_grad():
        for i, (data, _) in enumerate(
            tqdm(
                test_loader,
                desc=f"Distribution: {distribution} - Test Epoch {epoch}/{EPOCHS}",
            )
        ):
            data = data.to(device)

            recon_batch, mu, logvar = model(data)
            loss = loss_fn(recon_batch, data, mu, logvar, distribution)
            # Accumulate a Python float, as train() does, not a device tensor.
            test_loss += loss.item()

            if i == 0:
                n = min(data.size(0), 8)
                if distribution == "gaussian_with_learned_variance":
                    # The decoder returns (mean, log-variance); show the mean.
                    recon_batch = recon_batch[0]
                # Use the actual batch size: the batch may be smaller than
                # BATCH_SIZE, in which case view(BATCH_SIZE, ...) would raise.
                recon_batch = recon_batch.view(data.size(0), 1, 28, 28)
                comparison = torch.cat([data[:n], recon_batch[:n]])

                save_image(
                    comparison.cpu(),
                    f"../images/{model.distribution}/reconstruction_"
                    + str(epoch)
                    + ".png",
                    nrow=n,
                )

    test_loss /= len(test_loader.dataset)
    print(f"Test set loss: {test_loss:.4f}")


In [19]:
# Latent size used when constructing the VAE and when drawing prior samples.
latent_dim = 2


In [None]:
# Decoder likelihoods to train; uncomment entries to train more models.
distributions = [
    "gaussian_with_learned_variance",
    # "gaussian",
    # "bernoulli",
    # "continuous_bernoulli",
]

for distribution in distributions:
    model = VAE(latent_dim=latent_dim, distribution=distribution).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(1, EPOCHS + 1):
        train(model, optimizer, loss_fct, epoch, model.distribution)
        test(model, loss_fct, epoch, model.distribution)

        # The learned-variance decoder returns a tuple, so no sample grid
        # is written for it.
        if distribution == "gaussian_with_learned_variance":
            continue

        # Decode 64 draws from the prior and save them as an image grid.
        with torch.no_grad():
            prior_z = torch.randn(64, latent_dim).to(device)
            sample = model.decode(prior_z).cpu()
            save_image(
                sample.view(64, 1, 28, 28),
                f"../images/{model.distribution}/sample_" + str(epoch) + ".png",
            )

    # Save the whole model object under a timestamped name.
    timestr = time.strftime("%Y%m%d-%H%M%S")
    torch.save(model, f"../models/{model.distribution}_{timestr}.pt")


In [None]:
def imshow(img, title="MNIST Samples"):
    """Show a channels-first image tensor in a new 7x7 figure with a title.

    ``img`` is converted with ``.numpy()`` and its axes are reordered from
    (C, H, W) to (H, W, C) for matplotlib.
    """
    pixels = img.numpy()
    plt.figure(figsize=(7, 7))
    plt.title(title, fontsize=20)
    plt.axis("off")
    plt.imshow(pixels.transpose(1, 2, 0), interpolation="nearest")


# Visualize dataset
dataiter = iter(test_loader)
mnist_images, mnist_labels = next(dataiter)


# Load trained models.
# The checkpoints were written with torch.save(model, ...), i.e. they are full
# pickled modules, so weights_only=False is required to restore them.
# NOTE(security): unpickling executes arbitrary code; only load checkpoint
# files you created yourself or otherwise trust.
# map_location lets checkpoints saved on another device load on this one.
model_vae = torch.load(
    "../models/bernoulli_20250105-170534.pt",
    map_location=device,
    weights_only=False,
)
model_cbvae = torch.load(
    "../models/continuous_bernoulli_20250105-172440.pt",
    map_location=device,
    weights_only=False,
)
model_gvae = torch.load(
    "../models/gaussian_20250105-180253.pt",
    map_location=device,
    weights_only=False,
)
model_gvvae = torch.load(
    "../models/gaussian_with_learned_variance_20250105-185346.pt",
    map_location=device,
    weights_only=False,
)

# Sample from models
model_vae.eval()
model_cbvae.eval()
model_gvae.eval()
model_gvvae.eval()

num_samples = 12

# No gradients are needed for sampling.
with torch.no_grad():
    # Step 1: draw latent vectors from the prior, shared by all models
    z = torch.randn(num_samples, latent_dim).to(device)

    # Step 2: decode
    sample_cbvae = model_cbvae.decode(z).cpu().view(num_samples, 1, 28, 28)
    sample_vae = model_vae.decode(z).cpu().view(num_samples, 1, 28, 28)
    sample_gvae = model_gvae.decode(z).cpu().view(num_samples, 1, 28, 28)

    mu_x, logvar_x = model_gvvae.decode(z)
    # Step 3: Sample x ~ N(μ_x, diag(σ_x^2)); the noise must be scaled by the
    # standard deviation, not by the log-variance.
    std_x = torch.exp(0.5 * logvar_x)
    eps = torch.randn_like(std_x)
    x_sampled = mu_x + eps * std_x
    sample_gvvae = x_sampled.cpu().view(num_samples, 1, 28, 28)

# Plot

imshow(
    torchvision.utils.make_grid(mnist_images[:num_samples], num_samples),
    r"MNIST Data Samples",
)
imshow(
    torchvision.utils.make_grid(sample_cbvae[:num_samples], num_samples),
    r"Samples from $\mathcal{CB}$-VAE",
)
imshow(
    torchvision.utils.make_grid(sample_vae[:num_samples], num_samples),
    r"Samples from $\mathcal{B}$-VAE",
)
imshow(
    torchvision.utils.make_grid(sample_gvae[:num_samples], num_samples),
    r"Samples from $\mathcal{G}$-VAE",
)
imshow(
    torchvision.utils.make_grid(sample_gvvae[:num_samples], num_samples),
    r"Samples from $\mathcal{G}$-VAE learned variance",
)