# VAE Example Notebook

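This notebook demonstrates how to use the GalGenAI library to build a Variational Autoencoder (VAE) and sanity-check it before any training: it verifies forward-pass shapes, reconstructs MNIST test digits, and draws samples from the prior.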

In [None]:
import torch
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

from galgenai import VAE, get_device, get_device_name

## 1. Device Setup

In [None]:
# Pick the compute device with galgenai's helper and report which one is in use.
device = get_device()
device_label = get_device_name()
print(f"Using device: {device_label}")

## 2. Create VAE Model

In [None]:
# Seed torch's RNG (on all devices) before the first random draw so that weight
# initialisation and the later random tensors repeat across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Initialize model: single-channel 32x32 inputs, 16-dimensional latent space
model = VAE(in_channels=1, latent_dim=16, input_size=32)
model = model.to(device)

# Count trainable parameters
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model parameters: {num_params:,}")

## 3. Test Forward Pass

In [None]:
# Random batch shaped to match the model: (batch=4, channels=1, 32, 32).
# Drawn on the CPU and then moved, so under a fixed seed the values do not
# depend on which device is in use.
x = torch.randn(4, 1, 32, 32).to(device)
print(f"Input shape: {x.shape}")

# Forward pass. This is only a shape check, so skip building the autograd graph.
with torch.no_grad():
    reconstruction, mu, logvar = model(x)
print(f"Reconstruction shape: {reconstruction.shape}")
print(f"Latent mean shape: {mu.shape}")
print(f"Latent log variance shape: {logvar.shape}")

## 4. Load MNIST Dataset

In [None]:
# MNIST digits are 28x28; pad 2 px on every side so they match the model's
# input_size=32.
transform = transforms.Compose(
    [
        transforms.ToTensor(),  # PIL image -> float (C, H, W) tensor in [0, 1]
        transforms.Pad(2),  # 28x28 -> 32x32
    ]
)

# Load the test split (downloaded into ../data on first run)
test_dataset = datasets.MNIST(
    root="../data", train=False, download=True, transform=transform
)
print(f"Test dataset size: {len(test_dataset)}")

# Guard the contract the model relies on: one channel, 32x32 pixels.
assert test_dataset[0][0].shape == (1, 32, 32), test_dataset[0][0].shape

## 5. Visualize Some Images

In [None]:
# Preview the first ten padded test digits, with labels, on a 2 x 5 grid.
fig, axes = plt.subplots(2, 5, figsize=(10, 4))
for idx, ax in zip(range(10), axes.flat):
    image, digit = test_dataset[idx]
    ax.imshow(image[0], cmap="gray")  # drop the channel dimension for imshow
    ax.set_title(f"Label: {digit}")
    ax.axis("off")
plt.tight_layout()
plt.show()

## 6. Test Reconstruction (Untrained Model)

In [None]:
# Inference mode (affects dropout / batch-norm layers, if the model has any).
model.eval()

# How many test images to reconstruct; drives both the batch and the figure.
num_images = 8

# Batch of padded test images; dataset items are (image, label) pairs
images = torch.stack([test_dataset[i][0] for i in range(num_images)]).to(device)

# Get reconstructions
with torch.no_grad():
    reconstructions, _, _ = model(images)

# Top row: originals, bottom row: reconstructions.
# squeeze=False keeps `axes` 2-D even when num_images == 1.
fig, axes = plt.subplots(
    2, num_images, figsize=(2 * num_images, 4), squeeze=False
)
for i in range(num_images):
    axes[0, i].imshow(images[i, 0].cpu(), cmap="gray")
    axes[0, i].axis("off")

    axes[1, i].imshow(reconstructions[i, 0].cpu(), cmap="gray")
    axes[1, i].axis("off")

# Label each row once, on its first column
axes[0, 0].set_title("Original", fontsize=12)
axes[1, 0].set_title("Reconstructed", fontsize=12)

plt.tight_layout()
plt.show()
print("Note: This is an untrained model, so reconstructions will be poor.")

## 7. Test Sampling from Prior

In [None]:
# Generate samples from the prior and lay them out on a grid whose shape is
# derived from num_samples, so changing the count never indexes past the grid
# or silently drops samples.
num_samples = 16
ncols = 4
nrows = (num_samples + ncols - 1) // ncols  # ceiling division

# Set inference mode here instead of relying on an earlier cell having done it.
model.eval()
with torch.no_grad():
    samples = model.generate(num_samples, device)

# squeeze=False keeps `axes` 2-D even for a single row.
fig, axes = plt.subplots(
    nrows, ncols, figsize=(2 * ncols, 2 * nrows), squeeze=False
)
for idx, ax in enumerate(axes.flat):
    if idx < num_samples:
        ax.imshow(samples[idx, 0].cpu(), cmap="gray")
    ax.axis("off")  # also blanks unused cells in a partially filled last row

plt.suptitle("Samples from Prior (Untrained Model)", fontsize=14)
plt.tight_layout()
plt.show()

## 8. Next Steps

To train the model, run the training script:

```bash
python scripts/train_mnist.py
```

Or implement your own training loop using the `galgenai.training` module.