Write a PyTorch implementation of a Variational Autoencoder (VAE) for the MNIST dataset with a 2D latent space. Please ensure that the following requirements are satisfied:

- Scaling and normalization of the dataset.
- Division of the training dataset into training and validation sets.
- A convolutional neural network (CNN) for both the encoder and decoder.
- A Gaussian distribution for q(z|x) in the encoder.
- A Categorical distribution for p(x|z) in the decoder.
- A loss function defined as a separate method and adapted to the distribution used by the decoder.
- Use of a CUDA GPU or Apple MPS device, if available.


In [None]:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
import torchvision
import torchvision.transforms as transforms
import numpy as np

# Select the compute device: CUDA GPU first, then Apple MPS, otherwise CPU.
# getattr guards PyTorch builds that predate the torch.backends.mps module.
_mps_backend = getattr(torch.backends, 'mps', None)
device = torch.device('cuda' if torch.cuda.is_available()
                      else 'mps' if _mps_backend is not None and _mps_backend.is_available()
                      else 'cpu')
print(f"Using device: {device}")

# Number of categories for the Categorical likelihood p(x|z): one per 8-bit
# pixel-intensity level, 0..255.
num_classes = 256

# ToTensor() maps 8-bit pixels k to floats k/255; multiplying by 255 recovers k.
# Round before casting: .long() truncates toward zero, so a float round-off
# such as 254.99998 would otherwise become 254 and corrupt the class target.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: (x * 255).round().long()),  # integer levels in [0, 255]
])

# Load MNIST. With the transform above, each image is a 1 x 28 x 28 tensor of
# integer intensity levels, used directly as Categorical reconstruction targets.
from torchvision.datasets import MNIST

train_dataset_full = MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = MNIST(root='./data', train=False, transform=transform, download=True)

# Hold out 10% of the official training set for validation. A seeded generator
# makes the split identical across runs, so validation losses stay comparable.
split_seed = 42
train_size = int(0.9 * len(train_dataset_full))
val_size = len(train_dataset_full) - train_size
train_dataset, val_dataset = random_split(
    train_dataset_full,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# Create DataLoaders; only training batches are shuffled.
batch_size = 128

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Define the Encoder network using CNN
class Encoder(nn.Module):
    """Convolutional encoder that outputs the parameters of a Gaussian q(z|x).

    Three convolutions reduce a 1 x 28 x 28 image to a 128-dimensional feature
    vector; two linear heads then predict the mean and log-variance of z.

    Args:
        latent_dim: Dimensionality of the latent code z.
    """

    def __init__(self, latent_dim=2):
        super().__init__()
        self.latent_dim = latent_dim
        flat_features = 128 * 1 * 1
        # Spatial pipeline: 1x28x28 -> 32x14x14 -> 64x7x7 -> 128x1x1
        self.conv1 = nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=7)
        # Independent linear heads for the Gaussian mean and log-variance.
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

    def forward(self, x):
        """Return ``(mu, logvar)``; for 1 x 28 x 28 inputs each is (batch, latent_dim)."""
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            features = F.relu(conv(features))
        features = features.reshape(-1, self.fc_mu.in_features)  # flatten 128x1x1
        return self.fc_mu(features), self.fc_logvar(features)

# Define the Decoder network using CNN
class Decoder(nn.Module):
    """Convolutional decoder that outputs per-pixel Categorical logits for p(x|z).

    A linear layer lifts z to a 128 x 7 x 7 feature map. Two transposed
    convolutions upsample it to 32 x 28 x 28. A 1x1 convolution then maps every
    pixel to ``num_classes`` unnormalized log-probabilities.

    Args:
        latent_dim: Dimensionality of the latent code z.
        num_classes: Number of categories (intensity levels) per pixel.
    """

    def __init__(self, latent_dim=2, num_classes=256):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        self.fc = nn.Linear(latent_dim, 128 * 7 * 7)
        # Upsampling pipeline: 128x7x7 -> 64x14x14 -> 32x28x28
        self.deconv1 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
        self.deconv2 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)
        # 1x1 convolution: per-pixel logits over the num_classes categories.
        self.conv_out = nn.Conv2d(32, num_classes, kernel_size=1)

    def forward(self, z):
        """Return logits of shape (batch, num_classes, 28, 28)."""
        seed_map = F.relu(self.fc(z)).view(-1, 128, 7, 7)
        upsampled = F.relu(self.deconv2(F.relu(self.deconv1(seed_map))))
        return self.conv_out(upsampled)

# Define the VAE model combining the Encoder and Decoder
class VAE(nn.Module):
    """Variational autoencoder with a Gaussian encoder and a Categorical decoder.

    Inputs are integer pixel-intensity levels in ``[0, num_classes - 1]``. They
    serve directly as class targets for the Categorical likelihood p(x|z). The
    encoder instead receives a scaled-to-[0, 1] and standardized float copy,
    because ``Conv2d`` cannot consume integer tensors.

    Args:
        latent_dim: Dimensionality of the latent code z.
        num_classes: Number of intensity levels modelled per pixel (>= 2).
        input_mean: Mean subtracted from the [0, 1]-scaled encoder input.
            The default matches the normalization in the official PyTorch
            MNIST example.
        input_std: Standard deviation the centred encoder input is divided by
            (> 0). The default matches the official PyTorch MNIST example.

    Raises:
        ValueError: If ``num_classes < 2`` or ``input_std <= 0``.
    """

    def __init__(self, latent_dim=2, num_classes=256, input_mean=0.1307, input_std=0.3081):
        super().__init__()
        if num_classes < 2:
            raise ValueError(f"num_classes must be >= 2, got {num_classes}")
        if input_std <= 0:
            raise ValueError(f"input_std must be positive, got {input_std}")
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        self.input_mean = input_mean
        self.input_std = input_std
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim, num_classes)

    def _prepare_input(self, x):
        """Scale integer levels to [0, 1] and standardize them for the encoder."""
        scaled = x.float() / (self.num_classes - 1)
        return (scaled - self.input_mean) / self.input_std

    def reparameterize(self, mu, logvar):
        """Sample z = mu + sigma * eps with eps ~ N(0, I) (reparameterization trick)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample z, and decode it.

        Args:
            x: Integer intensity levels in ``[0, num_classes - 1]``.

        Returns:
            Tuple ``(x_recon_logits, mu, logvar)``: the decoder's per-pixel
            logits and the parameters of q(z|x).
        """
        mu, logvar = self.encoder(self._prepare_input(x))
        z = self.reparameterize(mu, logvar)
        x_recon_logits = self.decoder(z)
        return x_recon_logits, mu, logvar

    def loss_function(self, x_recon_logits, x, mu, logvar):
        """Return the negative ELBO summed over the batch (a scalar tensor).

        The reconstruction term is the Categorical negative log-likelihood,
        i.e. the cross-entropy between per-pixel logits and integer targets.
        The KL term is the closed form of KL(N(mu, sigma^2) || N(0, I)).
        """
        batch_size = x.size(0)
        # cross_entropy expects logits as (N, C, d) and int64 class indices as
        # (N, d); reshape tolerates non-contiguous inputs, .long() uint8 targets.
        targets = x.reshape(batch_size, -1).long()
        logits = x_recon_logits.reshape(batch_size, self.num_classes, -1)
        reconstruction_loss = F.cross_entropy(logits, targets, reduction='sum')
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return reconstruction_loss + kld

# Instantiate the VAE model and optimizer
latent_dim = 2
vae = VAE(latent_dim=latent_dim, num_classes=num_classes).to(device)
optimizer = optim.Adam(vae.parameters(), lr=1e-3)

num_epochs = 10


def run_epoch(model, loader, optimizer=None):
    """Make one pass over ``loader`` and return the mean per-image loss.

    With an ``optimizer`` the model is trained batch by batch. Without one it
    is evaluated in eval mode with gradient tracking disabled.

    Args:
        model: A VAE exposing ``loss_function(x_recon_logits, x, mu, logvar)``.
        loader: DataLoader yielding ``(images, labels)`` batches.
        optimizer: Optimizer to step on each batch, or ``None`` to evaluate.

    Returns:
        Total summed loss divided by the number of images in the dataset.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    with torch.set_grad_enabled(training):
        for data, _ in loader:
            data = data.to(device)
            x_recon_logits, mu, logvar = model(data)
            loss = model.loss_function(x_recon_logits, data, mu, logvar)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
    # loss_function sums over the batch, so normalize by the image count.
    return total_loss / len(loader.dataset)


# Training loop with per-epoch validation
for epoch in range(num_epochs):
    avg_train_loss = run_epoch(vae, train_loader, optimizer)
    print(f'Epoch {epoch+1}, Average Training Loss: {avg_train_loss:.4f}')

    avg_val_loss = run_epoch(vae, val_loader)
    print(f'Epoch {epoch+1}, Average Validation Loss: {avg_val_loss:.4f}')

# Final evaluation on the held-out test set (test_loader was otherwise unused).
avg_test_loss = run_epoch(vae, test_loader)
print(f'Average Test Loss: {avg_test_loss:.4f}')
```