In [None]:
import torchvision.transforms as transforms
import torchvision

# Preprocessing pipeline with a single step: convert each sample to a tensor.
transform = transforms.Compose([transforms.ToTensor()])

# MNIST training split, stored under ./data (download=True asks torchvision to
# fetch it when it is not already there).
training_images = torchvision.datasets.MNIST(
    "./data",
    train=True,
    download=True,
    transform=transform,
)

In [None]:
import torch
from models import Encoder, Decoder, CategoricalVAE

# Seed torch's global RNG so the shuffled batch order (and any torch-based
# sampling in later cells) is reproducible under "Restart Kernel -> Run All".
SEED = 0
torch.manual_seed(SEED)

# One image per batch: the plotting cells squeeze tensors down to a single
# 2-D image, which only works when the batch dimension is 1.
batch_size = 1
train_dataset = torch.utils.data.DataLoader(
    dataset=training_images, batch_size=batch_size, shuffle=True
)

image_shape = next(iter(train_dataset))[0][0].shape  # [1, 28, 28]
# NOTE(review): K and N presumably have to match the values the checkpoint
# below was trained with, otherwise load_state_dict should reject it - confirm.
K = 26  # number of classes
N = 3  # number of categorical distributions

encoder = Encoder(N, K, image_shape)
decoder = Decoder(N, K, image_shape)
model = CategoricalVAE(encoder, decoder)
# This notebook only runs inference, so switch off training-only behaviour
# (e.g. dropout, batch-norm statistic updates) in case the model has any.
model.eval()

# map_location="cpu": the notebook feeds CPU tensors to the model, and without
# it a checkpoint saved from a GPU run fails to load on a machine without CUDA.
state_dict = torch.load(
    "outputs/default/save_49999.pt", map_location="cpu", weights_only=True
)
# Kept as the cell's last expression so the key-matching report is still shown.
model.load_state_dict(state_dict)

In [None]:
# Draw one (images, labels) batch from a fresh pass over the shuffled loader;
# `batch` keeps the pair as returned, `x` / `labels` are its two elements.
batch = x, labels = next(iter(train_dataset))
print(x.size())

In [None]:
import matplotlib.pyplot as plt

# Show the sampled input digit. With batch_size 1 and a [1, 28, 28] image,
# squeeze() drops the singleton dims and leaves a 2-D array for imshow.
input_image = x.squeeze()
plt.imshow(input_image, cmap="gray")
plt.axis("off")
plt.show()

In [None]:
import matplotlib.pyplot as plt

# Run the model on the sampled image without autograd tracking (nothing is
# trained in this notebook); the call returns the pair (phi, x_hat).
with torch.no_grad():
    forward_out = model(x, temperature=1.0)
phi, x_hat = forward_out

# Show x_hat - presumably the reconstruction of x; confirm against models.py.
output_image = x_hat.squeeze()
plt.imshow(output_image, cmap="gray")
plt.axis("off")
plt.show()

In [None]:
# Render phi as a grayscale image.
# NOTE(review): presumably one row per categorical distribution (N) and one
# column per class (K) once the batch dim is squeezed - confirm in models.py.
phi_image = phi.squeeze()
plt.imshow(phi_image, cmap="gray")
plt.axis("off")
plt.show()

In [None]:
from models import gumbel_softmax

# Draw one hard sample from gumbel_softmax using phi, then show it with the
# same grayscale rendering as the phi plot so the two can be compared.
z_given_x = gumbel_softmax(
    phi,
    temperature=1.0,
    hard=True,
    batch=True,
)
sample_image = z_given_x.squeeze()
plt.imshow(sample_image, cmap="gray")
plt.axis("off")
plt.show()

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import numpy as np

# 4x4 grid of decoded images: every panel uses a fresh gumbel_softmax draw
# from the same phi, so the grid shows how the decoder output varies from
# sample to sample.
fig = plt.figure(figsize=(8.0, 8.0))
grid = ImageGrid(
    fig,
    111,  # single subplot slot, same meaning as subplot(111)
    nrows_ncols=(4, 4),  # 4 rows x 4 columns = 16 panels
    axes_pad=0.15,  # padding between panels, in inches
)

# Iterating over an ImageGrid yields its Axes one by one.
for ax in grid:
    with torch.no_grad():
        # Re-sampled on every iteration on purpose: one new draw per panel.
        z_given_x = gumbel_softmax(phi, temperature=5.0, hard=True, batch=True)
        x_hat = model.decoder(z_given_x)
    decoded_image = x_hat.squeeze()
    ax.imshow(decoded_image, cmap="gray")
    ax.axis("off")

plt.show()