# VAE MNIST Generation

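This notebook loads a trained VAE (weights from `../../vae_model.pth`) and uses it to generate new MNIST-like digits. Run it from a directory two levels below the project root so that both the `src` package and the checkpoint path resolve.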

In [None]:
import sys
sys.path.append('../../')

import torch
import matplotlib.pyplot as plt
from src.models import BasicVAE
from src.generate_images import generate_images, plot_generated_images

In [None]:
# Set up the device and load the trained model weights.
MODEL_PATH = '../../vae_model.pth'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Architecture hyperparameters must match the ones used to train the checkpoint,
# otherwise load_state_dict (strict by default) raises on mismatched keys/shapes.
model = BasicVAE(input_dim=784, hidden_dim=400, latent_dim=20).to(device)
# map_location remaps saved tensors onto the current device; without it a
# checkpoint saved from a GPU cannot be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
# This notebook only runs inference, so switch layers such as dropout/batch-norm
# (if BasicVAE has any) to evaluation behaviour. `;` suppresses the module repr.
model.eval();

In [None]:
# Sample new digits from the model and show them in a grid.
SEED = 0  # fixed seed so "Restart & Run All" reproduces the same samples
n_rows, n_cols = 4, 4
num_images = n_rows * n_cols  # derived from the grid so the two cannot disagree

# NOTE(review): assumes generate_images draws its latent samples from torch's
# global RNG, which manual_seed controls — confirm in src/generate_images.py.
torch.manual_seed(SEED)
generated_images = generate_images(model, num_images, device)

# Plot the generated images
plot_generated_images(generated_images, n_rows, n_cols)

## Latent Space Visualization

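To see how smoothly the decoder maps latent space to images, we pick two random latent vectors and decode evenly spaced points on the straight line between them. Gradual changes from one digit to the other suggest a well-organized latent space.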

In [None]:
def interpolate_latent_space(model, start, end, steps=10, device=None):
    """Decode evenly spaced points on the line segment from ``start`` to ``end``.

    Args:
        model: VAE exposing ``fc_mu`` (whose ``out_features`` is the latent
            size) and a ``decode`` method.
        start, end: latent vectors with ``model.fc_mu.out_features`` elements
            each, e.g. shape ``(1, latent_dim)`` or ``(latent_dim,)``; they may
            live on any device.
        steps: number of interpolation points, including both endpoints.
        device: device to decode on. Defaults to the device of the model's
            parameters (CPU if the model has none), so the function does not
            depend on a notebook-global ``device`` variable.

    Returns:
        The output of ``model.decode`` for the ``steps`` interpolated latent
        vectors, moved to the CPU (assumed to have one row per step).
    """
    latent_dim = model.fc_mu.out_features
    param = next(model.parameters(), None)
    if device is None:
        device = param.device if param is not None else torch.device("cpu")
    # Match the model's dtype so decode does not see a dtype mismatch.
    dtype = param.dtype if param is not None else torch.get_default_dtype()

    with torch.no_grad():
        # (steps, 1) weights broadcast against (1, latent_dim) endpoints
        # -> (steps, latent_dim), replacing the per-row Python loop.
        alphas = torch.linspace(0, 1, steps, device=device, dtype=dtype).unsqueeze(1)
        start = start.reshape(1, latent_dim).to(device=device, dtype=dtype)
        end = end.reshape(1, latent_dim).to(device=device, dtype=dtype)
        interpolation = start * (1 - alphas) + end * alphas
        return model.decode(interpolation).cpu()

# Interpolate between two random points in latent space and show the decoded
# images side by side.
N_STEPS = 10
# Dedicated generator so the endpoints are reproducible regardless of how
# much of the global RNG earlier cells consumed.
latent_gen = torch.Generator().manual_seed(0)

latent_dim = model.fc_mu.out_features
z1 = torch.randn(1, latent_dim, generator=latent_gen)
z2 = torch.randn(1, latent_dim, generator=latent_gen)

interpolated_images = interpolate_latent_space(model, z1, z2, steps=N_STEPS)

# One panel per interpolation step; squeeze=False keeps `axes` 2-D even when
# N_STEPS == 1, so the loop works for any step count.
fig, axes = plt.subplots(1, N_STEPS, figsize=(2 * N_STEPS, 2), squeeze=False)
for ax, image in zip(axes[0], interpolated_images):
    # Assumes each decoded image has 28*28 = 784 values (MNIST size);
    # reshape, unlike view, also works on non-contiguous tensors.
    ax.imshow(image.reshape(28, 28), cmap='gray')
    ax.axis('off')
fig.suptitle('Linear interpolation between two random latent vectors')
plt.show()