Write a PyTorch implementation of a Variational Autoencoder (VAE) for the MNIST dataset with a 2D latent space. Please ensure that the following requirements are satisfied:

- Scale and normalize the dataset
- Division of the training dataset into training and validation sets.
- A convolutional neural network (CNN) for both the encoder and decoder.
- A Gaussian distribution for q(z|x) in the encoder
- A Categorical distribution for p(x|z) in the decoder over the 10 different digits in the dataset, using a softmax function
- The loss function is defined as a separate method, adjusted for the distribution used for the decoder.
- Use of GPU or MPS if available.

Here's a complete PyTorch implementation of a Variational Autoencoder (VAE) for the MNIST dataset with a 2D latent space, satisfying all the specified requirements:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt

# Check if GPU or MPS is available
if torch.cuda.is_available():
    device = torch.device('cuda')
    print("Using CUDA")
elif torch.backends.mps.is_available():
    device = torch.device('mps')
    print("Using MPS")
else:
    device = torch.device('cpu')
    print("Using CPU")

# 1. Load and preprocess the MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),                      # Scale pixel values from [0, 255] to [0, 1]
    transforms.Normalize((0.1307,), (0.3081,))  # Standardize with the MNIST mean and std
])

# Download MNIST dataset
mnist_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Split training data into training (80%) and validation (20%) sets;
# the seeded generator makes the split reproducible across runs
train_size = int(0.8 * len(mnist_dataset))
val_size = len(mnist_dataset) - train_size
train_dataset, val_dataset = random_split(
    mnist_dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42)
)

test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

batch_size = 128

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# 2. Define the VAE model
class VAE(nn.Module):
    def __init__(self, latent_dim=2, num_classes=10):
        super().__init__()
        # Encoder CNN
        self.encoder = nn.Sequential(
            # Input channels = 1, output channels = 32
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),  # -> [batch, 32, 14, 14]
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # -> [batch, 64, 7, 7]
            nn.ReLU(),
            nn.Flatten()                                            # -> [batch, 64 * 7 * 7]
        )
        # Mean and log-variance of the Gaussian q(z|x) (used by the reparameterization trick)
        self.fc_mu = nn.Linear(64 * 7 * 7, latent_dim)
        self.fc_logvar = nn.Linear(64 * 7 * 7, latent_dim)

        # Decoder: maps z to logits of the categorical distribution over the 10 digits
        self.decoder_fc = nn.Linear(latent_dim, 64 * 7 * 7)
        self.decoder = nn.Sequential(
            nn.Unflatten(1, (64, 7, 7)),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),  # -> [batch, 32, 14, 14]
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),  # -> [batch, 16, 28, 28]
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),  # Global average pooling -> [batch, 16, 1, 1]
            nn.Flatten(),
            nn.Linear(16, num_classes)     # Unnormalized logits for the 10 digits
            # No nn.Softmax here: the loss applies log-softmax itself, and a second
            # softmax would distort the categorical likelihood
        )

    def encode(self, x):
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        # Returns logits; softmax(logits) gives p(x|z) over the 10 digits
        h = self.decoder_fc(z)
        return self.decoder(h)

    def predict_proba(self, z):
        # Softmax turns the decoder logits into class probabilities
        return F.softmax(self.decode(z), dim=1)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        logits = self.decode(z)
        return logits, mu, logvar

# 3. Define the loss function (negative ELBO for a categorical decoder)
def loss_function(logits, y, mu, logvar):
    # Reconstruction term: negative log-likelihood of the true digit under
    # softmax(logits); log_softmax is the numerically stable log(softmax(.))
    NLL = F.nll_loss(F.log_softmax(logits, dim=1), y, reduction='sum')
    # KL divergence between q(z|x) = N(mu, sigma^2) and the prior p(z) = N(0, I)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return NLL + KLD

# Initialize the model and optimizer
model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# 4. Training loop
num_epochs = 10

for epoch in range(1, num_epochs + 1):
    model.train()
    train_loss = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, target, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    print(f'Epoch {epoch}, Training loss: {train_loss / len(train_loader.dataset):.4f}')

    # Validation loop
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for data, target in val_loader:
            data = data.to(device)
            target = target.to(device)
            recon_batch, mu, logvar = model(data)
            loss = loss_function(recon_batch, target, mu, logvar)
            val_loss += loss.item()
    val_loss /= len(val_loader.dataset)
    print(f'Validation loss: {val_loss:.4f}')

# 5. Testing and visualization (optional)
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
    for data, target in test_loader:
        data = data.to(device)
        target = target.to(device)
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, target, mu, logvar)
        test_loss += loss.item()
        # Classify from the posterior mean instead of a random sample of z,
        # so the reported accuracy is deterministic
        pred = model.decode(mu).argmax(dim=1)
        correct += (pred == target).sum().item()

test_loss /= len(test_loader.dataset)
accuracy = 100. * correct / len(test_loader.dataset)
print(f'Test loss: {test_loss:.4f}, Test accuracy: {accuracy:.2f}%')

# Visualize latent space
def plot_latent_space(model, loader, device, num_batches=100):
    model.eval()
    zs = []
    ys = []
    with torch.no_grad():
        for i, (data, target) in enumerate(loader):
            if i >= num_batches:
                break
            mu, logvar = model.encode(data.to(device))
            # Plot the posterior means so the picture is not blurred by sampling noise
            zs.append(mu.cpu().numpy())
            ys.append(target.numpy())
    zs = np.concatenate(zs)
    ys = np.concatenate(ys)
    plt.figure(figsize=(8, 6))
    scatter = plt.scatter(zs[:, 0], zs[:, 1], c=ys, cmap='tab10', s=5, alpha=0.7)
    plt.colorbar(scatter, label='Digit')
    plt.title('Latent Space Visualization (posterior means)')
    plt.xlabel('z[0]')
    plt.ylabel('z[1]')
    plt.show()

plot_latent_space(model, test_loader, device)
```

This code performs the following steps:

1. **Data Loading and Preprocessing**:

   - Downloads the MNIST dataset.
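   - Scales pixel values to [0, 1] with `transforms.ToTensor` and standardizes them with `transforms.Normalize`, using the MNIST mean (0.1307) and standard deviation (0.3081).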
   - Splits the training dataset into training and validation sets (80% training, 20% validation).

2. **Model Definition**:

   - Defines a `VAE` class with a convolutional encoder and decoder.
   - The encoder outputs the mean and log variance of the latent Gaussian distribution `q(z|x)`.
   - The decoder parameterizes a categorical distribution over the 10 digits; a softmax maps its outputs to class probabilities.

3. **Loss Function**:

   - Defines `loss_function`, which computes the cross-entropy, i.e. the negative log-likelihood of the true digit labels under the decoder's softmax (categorical) distribution.
   - Adds the KL divergence between the approximate posterior `q(z|x)` and the standard normal prior `p(z) = N(0, I)`.

4. **Training Loop**:

   - Trains the VAE for 10 epochs.
   - Uses the optimizer `Adam` with a learning rate of `1e-3`.
   - Moves all tensors to the appropriate device (GPU/MPS/CPU).

5. **Testing and Visualization**:

   - Evaluates the model on the test set and computes the test loss and accuracy.
   - Visualizes the 2D latent space by plotting the latent variables colored by their true digit labels.
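The two terms of the loss can be checked in isolation. The sketch below uses random stand-in values for `mu`, `logvar`, the logits, and the labels (hypothetical, not produced by the model) to compare the closed-form KL term from `loss_function` with `torch.distributions.kl_divergence`, and to show that `F.cross_entropy` on raw logits equals `F.nll_loss` on `F.log_softmax(logits)`. The last comparison illustrates why the decoder should output logits: passing softmax probabilities into `F.cross_entropy` applies the softmax a second time and gives a different value.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)

# Hypothetical encoder outputs for a batch of 4 images with a 2D latent space
mu = torch.randn(4, 2)
logvar = torch.randn(4, 2)

# Closed-form KL(N(mu, sigma^2) || N(0, I)), as written in loss_function
kld_closed = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

# Same quantity via torch.distributions (Normal takes the standard deviation)
q = Normal(mu, torch.exp(0.5 * logvar))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kld_dist = kl_divergence(q, p).sum()

# Hypothetical decoder logits for the 10 digits and the true labels
logits = torch.randn(4, 10)
y = torch.tensor([3, 1, 4, 1])

# Categorical negative log-likelihood computed two equivalent ways
nll_ce = F.cross_entropy(logits, y, reduction='sum')
nll_explicit = F.nll_loss(F.log_softmax(logits, dim=1), y, reduction='sum')

# Feeding probabilities (instead of logits) into cross_entropy double-applies softmax
nll_double = F.cross_entropy(F.softmax(logits, dim=1), y, reduction='sum')

print(torch.allclose(kld_closed, kld_dist, rtol=1e-4, atol=1e-4))
print(torch.allclose(nll_ce, nll_explicit, atol=1e-5))
print(torch.allclose(nll_ce, nll_double))
```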

**Notes**:

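- The decoder predicts the digit label from the latent code rather than reconstructing the image, so the model effectively performs classification: the `p(x|z)` of the requirements becomes a categorical distribution over labels, and the KL term regularizes a stochastic 2D bottleneck.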
- The latent space is two-dimensional, allowing for visualization.
- The code checks for GPU or MPS availability and uses it if possible.
- The training and validation losses are printed for each epoch.
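Because the latent space is two-dimensional, the decoder can also be probed directly: decoding a regular grid of `z` values shows which digit the model predicts in each region. This is a minimal sketch, assuming `decode` returns the 10 class logits for a batch of 2D latent vectors; the helper name `latent_label_map` and the small stand-in network are hypothetical and only let the example run without a trained model.

```python
import torch
import torch.nn as nn

def latent_label_map(decode, lim=3.0, n=25, device='cpu'):
    """Decode an n x n grid of 2D latent points (hypothetical helper).

    Returns the most likely digit and its softmax probability at each point,
    assuming `decode` maps [N, 2] latent vectors to [N, 10] logits.
    """
    axis = torch.linspace(-lim, lim, n)
    # Cartesian grid; with 'ij' indexing, labels[i, j] corresponds to z = (axis[i], axis[j])
    z0, z1 = torch.meshgrid(axis, axis, indexing='ij')
    z = torch.stack([z0.reshape(-1), z1.reshape(-1)], dim=1).to(device)
    with torch.no_grad():
        probs = torch.softmax(decode(z), dim=1)  # [n * n, 10]
    conf, labels = probs.max(dim=1)
    return labels.view(n, n).cpu(), conf.view(n, n).cpu()

# Untrained stand-in for model.decode so the example is self-contained (hypothetical)
torch.manual_seed(0)
stand_in_decoder = nn.Sequential(nn.Linear(2, 32), nn.ReLU(), nn.Linear(32, 10))
labels, conf = latent_label_map(stand_in_decoder, n=25)
print(labels.shape, conf.shape)
```

With the trained model, `latent_label_map(model.decode, device=device)` would produce the label map, which can be displayed with `plt.imshow(labels.T, origin='lower', extent=(-3, 3, -3, 3), cmap='tab10')` so that the horizontal axis is `z[0]` and the vertical axis is `z[1]`.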

**Dependencies**:

- `torch`: PyTorch library.
- `torchvision`: For datasets and transforms.
- `numpy`: For concatenating the collected latent codes and labels.
- `matplotlib`: For plotting the latent space.

**Instructions**:

- Make sure all the required libraries are installed.
- Run the code in an environment where you have access to a GPU or MPS for better performance.
- Adjust `num_epochs` and `batch_size` as needed.
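When `num_epochs` or `batch_size` are adjusted, comparisons between runs are cleaner if the train/validation split itself stays fixed. `random_split` accepts a `generator` argument for this. The sketch below wraps it in a hypothetical helper, `split_train_val`, and uses a small `TensorDataset` as a stand-in for MNIST so it runs without downloading anything.

```python
import torch
from torch.utils.data import TensorDataset, random_split

def split_train_val(dataset, val_fraction=0.2, seed=42):
    """Hypothetical helper: reproducible train/validation split."""
    val_size = int(val_fraction * len(dataset))
    train_size = len(dataset) - val_size
    # A freshly seeded generator yields the same permutation on every call
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, [train_size, val_size], generator=generator)

# Stand-in for the MNIST training set: 100 blank 28x28 images with labels 0..9
data = TensorDataset(torch.zeros(100, 1, 28, 28), torch.arange(100) % 10)
train_a, val_a = split_train_val(data)
train_b, val_b = split_train_val(data)
print(len(train_a), len(val_a))        # 80 20
print(val_a.indices == val_b.indices)  # True
```

The same call works on `mnist_dataset`; `seed=42` is an arbitrary choice.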

This implementation addresses each of the listed requirements and shows how to build and train a VAE with a categorical decoder on the MNIST dataset.
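Before a full training run, a quick shape check can confirm the spatial sizes annotated in the comments of the `VAE` class (28 → 14 → 7 in the encoder, 7 → 14 → 28 in the decoder). The sketch below rebuilds only the convolutional layers with the same hyperparameters and passes a random batch through them; it does not download MNIST or train anything.

```python
import torch
import torch.nn as nn

# Convolutional stacks with the same hyperparameters as the VAE class
encoder = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),   # 28 -> 14
    nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # 14 -> 7
    nn.ReLU(),
)
decoder = nn.Sequential(
    # Transposed conv output: (in - 1) * stride - 2 * padding + kernel_size + output_padding
    nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),  # 7 -> 14
    nn.ReLU(),
    nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),  # 14 -> 28
    nn.ReLU(),
)

x = torch.randn(8, 1, 28, 28)  # Dummy batch shaped like MNIST
h = encoder(x)
y = decoder(h)
print(h.shape)  # torch.Size([8, 64, 7, 7])
print(y.shape)  # torch.Size([8, 16, 28, 28])
```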