Adapted from the TensorFlow tutorial "Convolutional Variational Autoencoder": https://www.tensorflow.org/tutorials/generative/cvae

In [None]:
from typing import Optional

import auto_compyute as ac
from auto_compyute import Array  
import auto_compyute.nn.functional as F
from auto_compyute import nn

# Fix the RNG seed first so parameter initialization and sampling are repeatable.
ac.backends.set_random_seed(0)

# Prefer the GPU when the backend reports one, otherwise fall back to the CPU.
device = "cpu"
if ac.backends.gpu_available():
    device = "cuda"

In [None]:
import pandas as pd

# Load MNIST from local CSV copies in ../data/ (nothing is downloaded here; fetch the
# files from the URLs below first). Column 0 of each row is dropped, presumably the
# digit label, leaving the 784 pixel values that preprocess_images reshapes to 28x28.
# train_url = "https://pjreddie.com/media/files/mnist_train.csv"
# test_url = "https://pjreddie.com/media/files/mnist_test.csv"


def _load_mnist_pixels(path: str) -> Array:
    """Read a headerless MNIST CSV and return its pixel columns (all but column 0)."""
    frame = pd.read_csv(path, header=None)
    return ac.array(frame.to_numpy())[:, 1:]


train_images = _load_mnist_pixels("../data/mnist_train.csv")
test_images = _load_mnist_pixels("../data/mnist_test.csv")

In [None]:
def preprocess_images(images: Array):
    """Reshape flat pixel rows to (N, 1, 28, 28) and binarize them.

    Pixels are divided by 255 (raw values presumably in 0..255) and then thresholded:
    the result is a float Array holding 1.0 where the scaled value exceeds 0.5,
    else 0.0.
    """
    num_images = images.shape[0]
    scaled = images.view(num_images, 1, 28, 28) / 255.0
    binarized = ac.where(scaled > .5, 1.0, 0.0)
    return binarized.float()


train_images, test_images = map(preprocess_images, (train_images, test_images))

In [None]:
batch_size = 256

# Both loaders share the same batching and target device.
loader_options = dict(batch_size=batch_size, device=device)
train_dl = nn.Dataloader((train_images,), **loader_options)
test_dl = nn.Dataloader((test_images,), **loader_options)

In [None]:
class CVAE(nn.Module):
    """Convolutional variational autoencoder.

    ``encode`` maps an image batch of shape (B, 1, 28, 28) to the mean and
    log-variance of a diagonal Gaussian q(z|x) with ``latent_dim`` components.
    ``decode`` maps latent codes of shape (B, latent_dim) to per-pixel logits of
    shape (B, 1, 28, 28).
    """

    def __init__(self, latent_dim) -> None:
        super().__init__()
        self.latent_dim = latent_dim

        # Two 3x3 stride-2 convs without padding (assuming that is Conv2D's
        # default) shrink 28x28 -> 13x13 -> 6x6, giving 64 * 6 * 6 = 2304
        # flattened features. The last layer emits mean and log-variance side
        # by side (2 * latent_dim values).
        self.encoder = nn.Sequential(
            nn.Conv2D(1, 32, kernel_size=3, stride=2), nn.ReLU(),
            nn.Conv2D(32, 64, kernel_size=3, stride=2), nn.ReLU(),
            nn.Flatten(),
            nn.Linear(2304, latent_dim + latent_dim)
        )

        # Project to a 32x7x7 feature map, upsample 7 -> 14 -> 28 with stride-2
        # transposed convs, then map to a single channel of logits.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 32*7*7), nn.ReLU(),
            nn.Reshape((32, 7, 7)),
            nn.ConvTranspose2D(32, 64, kernel_size=3, stride=2, padding=1, output_padding=1), nn.ReLU(),
            nn.ConvTranspose2D(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1), nn.ReLU(),
            nn.ConvTranspose2D(32, 1, kernel_size=3, stride=1, padding=1)
        )

    def sample(self, eps: Optional[Array] = None) -> Array:
        """Decode ``eps`` into pixel probabilities.

        If ``eps`` is omitted, 100 standard-normal latent vectors are drawn on the
        model's device.
        """
        if eps is None:
            eps = ac.randn(100, self.latent_dim).to(self.device)
        return self.decode(eps, apply_sigmoid=True)

    def encode(self, x: Array) -> tuple[Array, Array]:
        """Return ``(mean, logvar)`` of q(z|x), each of shape (B, latent_dim)."""
        h = self.encoder(x)
        # Slice explicitly instead of ``split(2, dim=1)``. In PyTorch-style APIs
        # the first argument of ``split`` is the chunk *size*, which only yields
        # the two halves when latent_dim happens to be 2.
        mean = h[:, :self.latent_dim]
        logvar = h[:, self.latent_dim:]
        return mean, logvar

    def reparameterize(self, mean: Array, logvar: Array) -> Array:
        """Draw z = mean + exp(0.5 * logvar) * eps with eps ~ N(0, I).

        Writing the sample this way keeps it differentiable w.r.t. mean and logvar.
        """
        eps = ac.randn_like(mean)
        return eps * (logvar * 0.5).exp() + mean

    def decode(self, z: Array, apply_sigmoid: bool = False) -> Array:
        """Map latent codes to logits, or to probabilities if ``apply_sigmoid``."""
        logits = self.decoder(z)
        if apply_sigmoid:
            probs = F.sigmoid(logits)
            return probs
        return logits

    def forward(self, x: Array) -> Array:
        """Full pass: encode, sample z from q(z|x), and return reconstruction logits."""
        mean, logvar = self.encode(x)
        z = self.reparameterize(mean, logvar)
        return self.decode(z)

The hyperparameters, loss, and training loop below follow the TensorFlow CVAE tutorial: https://www.tensorflow.org/tutorials/generative/cvae

In [None]:
# Training length, latent size, and how many images to plot per epoch.
epochs, latent_dim, num_examples_to_generate = 10, 2, 16

# Fixed latent vectors; not referenced by the cells shown here.
random_vector_for_generation = ac.randn(num_examples_to_generate, latent_dim).to(device)
model = CVAE(latent_dim).to(device)

In [None]:
import math

# Adam over all model parameters with a learning rate of 1e-4.
optimizer = nn.optimizers.Adam(model.parameters(), learning_rate=1e-4)


def log_normal_pdf(sample: Array, mean: Array | float, logvar: Array | float, dim=1):
    log2pi = math.log(2.0 * math.pi)
    return (-0.5 * ((sample - mean) ** 2.0 * ac.array(-logvar).exp() + logvar + log2pi)).sum(dim)


def compute_loss(model, x):
    """Return the batch-averaged negative ELBO (single-sample Monte Carlo estimate).

    For each example i, ELBO_i = log p(x_i|z_i) + log p(z_i) - log q(z_i|x_i), where
    z_i is drawn from q(z|x_i) with the reparameterization trick and p(z) = N(0, I).
    """
    mean, logvar = model.encode(x)
    z = model.reparameterize(mean, logvar)
    x_logit = model.decode(z)

    # bce_loss(reduction="sum") sums over every pixel of every example in the batch.
    # Dividing by the batch size gives the batch mean of the per-example
    # log-likelihoods, putting it on the same per-example scale as logpz and logqz_x
    # below (summed over latent dims only). Without this, the reconstruction term is
    # overweighted by a factor of the batch size.
    logpx_z = -F.bce_loss(x_logit, x, reduction="sum") / x.shape[0]
    logpz = log_normal_pdf(z, 0., 0.)
    logqz_x = log_normal_pdf(z, mean, logvar)
    return -(logpx_z + logpz - logqz_x).mean()


def train_step(model: nn.Module, x: Array, optimizer: nn.optimizers.Optimizer):
    """Perform one optimization step on the batch ``x``.

    Puts the model in training mode, backpropagates the loss, applies the update,
    and clears the parameter gradients for the next call.
    """
    model.train()
    compute_loss(model, x).backward()
    # Apply the accumulated gradients, then reset them so they don't carry over.
    optimizer.update_params()
    optimizer.reset_param_grads()

In [None]:
import matplotlib.pyplot as plt

def generate_and_save_images(model, epoch, test_sample, save_path: Optional[str] = None):
    """Reconstruct ``test_sample`` through the model and plot the results in a grid.

    Each image is encoded, a latent code is sampled from q(z|x), and the code is
    decoded to pixel probabilities.

    Args:
        model: the CVAE being trained.
        epoch: epoch number, shown as the figure title.
        test_sample: image batch; channel 0 of each image is plotted.
        save_path: if given, the figure is also written to this path.
    """
    import math

    # Plotting only: skip building an autograd graph for these forward passes.
    with ac.no_autograd_tracing():
        mean, logvar = model.encode(test_sample.to(device))
        z = model.reparameterize(mean, logvar)
        predictions = model.sample(z)

    num_images = predictions.shape[0]
    grid_side = math.ceil(math.sqrt(num_images))  # e.g. 16 images -> 4x4 grid
    fig = plt.figure(figsize=(4, 4))
    fig.suptitle(f"Epoch {epoch}")

    for i in range(num_images):
        plt.subplot(grid_side, grid_side, i + 1)
        plt.imshow(predictions[i, 0, :, :].cpu().data, cmap='gray')
        plt.axis('off')

    if save_path is not None:
        fig.savefig(save_path)
    plt.show()

In [None]:
# Pick a fixed sample of the test set for generating output images. The same sample
# is reconstructed after every epoch so progress can be compared visually.
# It is sliced from test_images directly, so the relevant check is the test-set size
# (not the batch size).
if test_images.shape[0] < num_examples_to_generate:
    raise ValueError(
        f"need at least {num_examples_to_generate} test images, got {test_images.shape[0]}"
    )
test_sample = test_images[0:num_examples_to_generate, :, :, :]

generate_and_save_images(model, 0, test_sample)

In [None]:
import time
from IPython import display

for epoch in range(1, epochs + 1):
    # Train for one epoch; the reported time covers only the optimization steps.
    start_time = time.time()
    for (train_x,) in train_dl():
        train_step(model, train_x, optimizer)
    end_time = time.time()

    # Evaluate on the test set. compute_loss returns a batch mean (.mean()), so each
    # batch is weighted by its size to get the exact mean over all test examples.
    # Otherwise a smaller final batch would be over-weighted.
    model.eval()
    total_loss = 0.0
    num_examples = 0
    with ac.no_autograd_tracing():
        for (test_x,) in test_dl():
            n = test_x.shape[0]
            total_loss += compute_loss(model, test_x).item() * n
            num_examples += n
    elbo = -(total_loss / num_examples)

    display.clear_output(wait=False)
    print(f"Epoch: {epoch}, Test set ELBO: {elbo:.2f}, time elapse for current epoch: {end_time - start_time:.2f} s")
    generate_and_save_images(model, epoch, test_sample)