Write a PyTorch implementation of a Variational Autoencoder (VAE) for the MNIST dataset with a 2D latent space. Please ensure that the following requirements are satisfied:

- Scale and normalize the dataset
- Division of the training dataset into training and validation sets.
- A convolutional neural network (CNN) for both the encoder and decoder.
- A Gaussian distribution for q(z|x) in the encoder
- A Continuous Bernoulli distribution for p(x|z) in the decoder
- The loss function is defined as a separate method, adjusted for the distribution used for the decoder.
- Use of GPU or MPS if available.


Here is a PyTorch implementation of a Variational Autoencoder (VAE) for the MNIST dataset with a 2D latent space, satisfying all the specified requirements:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import numpy as np

# Check for GPU or MPS (Mac) availability
device = torch.device("cuda" if torch.cuda.is_available() else
                      "mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

# =====================
# Dataset Preparation
# =====================
# Transformations: Convert images to tensors and normalize to [0,1]
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Load MNIST dataset
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Split training dataset into training and validation sets (90% train, 10% validation)
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Data loaders
batch_size = 128
# Pinned host memory only helps host-to-GPU copies on CUDA; skip it on MPS and CPU
pin_memory = device.type == "cuda"
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=pin_memory)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=pin_memory)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=pin_memory)


# =====================
# Model Definition
# =====================
# Encoder using CNN
class Encoder(nn.Module):
    def __init__(self, latent_dim):
        super(Encoder, self).__init__()

        self.conv1 = nn.Conv2d(1, 32, 4, stride=2, padding=1)   # Output: [batch, 32, 14, 14]
        self.conv2 = nn.Conv2d(32, 64, 4, stride=2, padding=1)  # Output: [batch, 64, 7, 7]
        self.conv3 = nn.Conv2d(64, 128, 7)                      # Output: [batch, 128, 1, 1]

        # Fully connected layers for mean and log variance
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_logvar = nn.Linear(128, latent_dim)

    def forward(self, x):
        # Convolutional layers with ReLU activation
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))  # Output: [batch, 128, 1, 1]
        x = x.view(x.size(0), -1)  # Flatten to [batch, 128]
        mu = self.fc_mu(x)
        log_var = self.fc_logvar(x)
        return mu, log_var


# Decoder using CNN
class Decoder(nn.Module):
    def __init__(self, latent_dim):
        super(Decoder, self).__init__()

        # Fully connected layer to map latent space to features
        self.fc = nn.Linear(latent_dim, 128)

        # Transposed convolutional layers
        self.deconv1 = nn.ConvTranspose2d(128, 64, kernel_size=7)              # Output: [batch, 64, 7, 7]
        self.deconv2 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)  # Output: [batch, 32, 14, 14]
        self.deconv3 = nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1)   # Output: [batch, 1, 28, 28]

    def forward(self, z):
        x = self.fc(z)                     # [batch, 128]
        x = x.view(-1, 128, 1, 1)          # Reshape to [batch, 128, 1, 1]
        x = F.relu(self.deconv1(x))        # [batch, 64, 7, 7]
        x = F.relu(self.deconv2(x))        # [batch, 32, 14, 14]
        x = self.deconv3(x)                # [batch, 1, 28, 28]
        x = torch.sigmoid(x)               # Outputs in (0,1) for Continuous Bernoulli
        return x


# Variational Autoencoder combining Encoder and Decoder
class VAE(nn.Module):
    def __init__(self, latent_dim):
        super(VAE, self).__init__()

        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def reparameterize(self, mu, log_var):
        # Reparameterization trick to sample z
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        z = mu + eps * std
        return z

    def forward(self, x):
        # Encode input to get mean and log variance
        mu, log_var = self.encoder(x)
        # Sample from latent space
        z = self.reparameterize(mu, log_var)
        # Decode to reconstruct input
        x_recon = self.decoder(z)
        return x_recon, mu, log_var


# =====================
# Loss Function
# =====================
def continuous_bernoulli_log_likelihood(x, lambda_param, eps=1e-6):
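    """
    Computes the negative log-likelihood under the Continuous Bernoulli distribution.
    Despite the function name, the returned value is the negative log-likelihood.

    Args:
        x: Target images, values in [0,1]
        lambda_param: Decoder outputs, parameter λ ∈ (0,1)
        eps: Clamping margin that keeps the logarithms finite

    Returns:
        Negative log-likelihood per sample, shape [batch]
    """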
    # Clamp values to avoid numerical issues
    lambda_param = torch.clamp(lambda_param, eps, 1 - eps)
    x = torch.clamp(x, eps, 1 - eps)

    # x * log(λ) + (1 - x) * log(1 - λ)
    x_term = x * torch.log(lambda_param) + (1 - x) * torch.log(1 - lambda_param)

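    # Compute log-normalization constant log C(λ) = log(2 * atanh(1 - 2λ) / (1 - 2λ))
    lambda_diff = 1 - 2 * lambda_param
    # The closed form is 0/0 at λ = 0.5, so handle λ close to 0.5 separately
    close_to_half = torch.abs(lambda_diff) < 1e-2
    # Feed a safe value to the closed form where it is masked out; otherwise
    # torch.where can still propagate NaN gradients from the unused branch
    safe_diff = torch.where(close_to_half, torch.full_like(lambda_diff, 0.5), lambda_diff)
    # For λ close to 0.5, use the Taylor expansion log C(λ) ≈ log(2) + (1 - 2λ)² / 3
    log_C_lambda = torch.where(
        close_to_half,
        torch.log(torch.tensor(2.0, device=lambda_param.device)) + lambda_diff.pow(2) / 3,
        torch.log((2 * torch.atanh(safe_diff)) / safe_diff)
    )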

    # Negative log-likelihood
    nll = - (log_C_lambda + x_term)
    # Sum over all dimensions except batch
    nll = nll.view(nll.size(0), -1).sum(dim=1)
    return nll


def loss_function(x_recon, x, mu, log_var):
    # Reconstruction loss using Continuous Bernoulli negative log-likelihood
    recon_loss = continuous_bernoulli_log_likelihood(x, x_recon)
    # KL divergence between q(z|x) and p(z)
    kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1)
    # Total loss per sample
    total_loss = recon_loss + kl_div
    # Return average over batch
    return total_loss.mean(), recon_loss.mean(), kl_div.mean()


# =====================
# Training Loop
# =====================
# Initialize model, optimizer, and move model to device
latent_dim = 2
model = VAE(latent_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Training parameters
epochs = 20

for epoch in range(1, epochs + 1):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        # Forward pass
        x_recon, mu, log_var = model(data)
        # Compute loss
        loss, recon_loss, kl_div = loss_function(x_recon, data, mu, log_var)
        # Backpropagation
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

        if batch_idx % 100 == 0:
            print(f'Epoch [{epoch}/{epochs}] Batch [{batch_idx}/{len(train_loader)}] '
                  f'Loss: {loss.item():.4f} Recon Loss: {recon_loss.item():.4f} KL Div: {kl_div.item():.4f}')

    # Validation loop
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for data, _ in val_loader:
            data = data.to(device)
            x_recon, mu, log_var = model(data)
            loss, _, _ = loss_function(x_recon, data, mu, log_var)
            val_loss += loss.item()
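    val_loss /= len(val_loader)
    # Report the accumulated training loss alongside validation so the two can be compared
    avg_train_loss = train_loss / len(train_loader)
    print(f'====> Epoch: {epoch} Average training loss: {avg_train_loss:.4f} '
          f'Average validation loss: {val_loss:.4f}')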

print('Training complete.')
```
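
The shape comments in the encoder and decoder can be checked by hand with the standard output-size formulas for convolutions and transposed convolutions. The sketch below is a minimal check in plain Python, assuming dilation 1 and no output padding (the defaults the code relies on). The helper names `conv_out` and `deconv_out` are illustrative and do not appear in the implementation.

```python
# Illustrative helpers (not part of the implementation above).
# Both assume dilation=1; deconv_out also assumes output_padding=0.

def conv_out(size, kernel, stride=1, padding=0):
    """Output size of a Conv2d layer along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def deconv_out(size, kernel, stride=1, padding=0):
    """Output size of a ConvTranspose2d layer along one spatial dimension."""
    return (size - 1) * stride - 2 * padding + kernel


# Encoder layers as (kernel, stride, padding), starting from a 28x28 image
encoder_sizes = [28]
for kernel, stride, padding in [(4, 2, 1), (4, 2, 1), (7, 1, 0)]:
    encoder_sizes.append(conv_out(encoder_sizes[-1], kernel, stride, padding))

# Decoder layers as (kernel, stride, padding), starting from a 1x1 feature map
decoder_sizes = [1]
for kernel, stride, padding in [(7, 1, 0), (4, 2, 1), (4, 2, 1)]:
    decoder_sizes.append(deconv_out(decoder_sizes[-1], kernel, stride, padding))

print("encoder:", encoder_sizes)
print("decoder:", decoder_sizes)
```

The decoder mirrors the encoder layer by layer, which is why the reconstruction comes back at the input resolution without any cropping or interpolation.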

**Explanation:**

- **Device Selection:** The code checks for availability of CUDA (GPU) or MPS (Apple Silicon GPU) and sets the computation device accordingly.

- **Data Loading and Preprocessing:**
  - Uses `transforms.ToTensor()` to scale the MNIST images to the [0,1] range.
  - Splits the training dataset into training and validation sets (90% training, 10% validation).

- **Model Architecture:**
  - **Encoder:** Consists of convolutional layers that downsample the input image and fully connected layers that output the mean (`mu`) and log variance (`log_var`) of the latent Gaussian distribution.
  - **Decoder:** Uses transposed convolutional layers to reconstruct the image from the latent representation. The output layer applies a sigmoid so that every pixel output lies in (0,1), the valid range for the Continuous Bernoulli parameter λ.
  - **VAE Class:** Combines the encoder and decoder and includes the reparameterization trick necessary for backpropagation through stochastic nodes.

- **Loss Function:**
  - The `loss_function` computes the total loss as the sum of the reconstruction loss and the KL divergence.
  - The reconstruction loss is calculated using the negative log-likelihood of the Continuous Bernoulli distribution.
  - The KL divergence measures how much the learned latent distribution deviates from the prior (standard normal distribution).

- **Training Loop:**
  - For each epoch, the model is trained using batches from the training loader.
  - After each epoch, the model is evaluated on the validation set to monitor overfitting.
  - The losses are printed periodically for monitoring purposes.

- **GPU/MPS Usage:**
  - The model and data are moved to the selected device (`cuda`, `mps`, or `cpu`) to leverage hardware acceleration if available.
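
The KL term described under the loss function has a closed form because both q(z|x) and the prior are Gaussian. As a sanity check, the sketch below compares that closed form for a single latent dimension against a numerical integral of q(z) (log q(z) - log p(z)). It uses only the standard library; the function names and the test values `mu = 0.7`, `log_var = -1.2` are illustrative assumptions, not taken from the implementation.

```python
import math


def kl_closed_form(mu, log_var):
    """KL( N(mu, exp(log_var)) || N(0, 1) ) for one latent dimension."""
    return -0.5 * (1 + log_var - mu ** 2 - math.exp(log_var))


def kl_numeric(mu, log_var, steps=20000, width=10.0):
    """Trapezoidal estimate of the same KL over mu +/- width standard deviations."""
    std = math.exp(0.5 * log_var)
    lo, hi = mu - width * std, mu + width * std
    h = (hi - lo) / steps

    def integrand(z):
        log_q = -0.5 * math.log(2 * math.pi) - 0.5 * log_var - (z - mu) ** 2 / (2 * std ** 2)
        log_p = -0.5 * math.log(2 * math.pi) - 0.5 * z ** 2
        return math.exp(log_q) * (log_q - log_p)

    total = 0.5 * (integrand(lo) + integrand(hi))
    for i in range(1, steps):
        total += integrand(lo + i * h)
    return total * h


# Illustrative values; any finite mu and log_var should agree
mu, log_var = 0.7, -1.2
print(kl_closed_form(mu, log_var), kl_numeric(mu, log_var))
```

In `loss_function` the same expression is summed over the latent dimensions, which is valid because the encoder's Gaussian has a diagonal covariance.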

**Notes:**

- This implementation uses the Continuous Bernoulli distribution for modeling the pixel intensities in the decoder, which is more appropriate for continuous data in the [0,1] range compared to the standard Bernoulli distribution.
- For numerical stability, the code clamps λ away from 0 and 1 before taking logarithms and treats λ near 0.5 as a special case, because the closed-form normalization constant of the Continuous Bernoulli distribution evaluates to 0/0 there.
- The latent dimension is set to 2 to allow for visualization of the latent space if desired.
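
To make the normalization constant concrete, the sketch below evaluates log C(λ) = log(2 · atanh(1 − 2λ) / (1 − 2λ)) in plain Python and checks that the resulting density integrates to roughly 1 over [0,1]. Near λ = 0.5 it falls back to a second-order Taylor expansion, log 2 + (1 − 2λ)² / 3. The function names, the threshold, and the chosen λ values are assumptions made for illustration, not part of the implementation.

```python
import math


def log_norm_const(lam, threshold=1e-2):
    """log C(lam) for the Continuous Bernoulli, with a Taylor fallback near lam = 0.5."""
    u = 1.0 - 2.0 * lam
    if abs(u) < threshold:
        # The closed form is 0/0 at lam = 0.5, so use the expansion around that point
        return math.log(2.0) + u * u / 3.0
    return math.log(2.0 * math.atanh(u) / u)


def cb_density(x, lam):
    """Continuous Bernoulli density p(x | lam) for x in [0, 1] and lam in (0, 1)."""
    log_p = log_norm_const(lam) + x * math.log(lam) + (1.0 - x) * math.log(1.0 - lam)
    return math.exp(log_p)


def integrate_density(lam, steps=10000):
    """Trapezoidal integral of the density over [0, 1]; should be close to 1."""
    h = 1.0 / steps
    total = 0.5 * (cb_density(0.0, lam) + cb_density(1.0, lam))
    for i in range(1, steps):
        total += cb_density(i * h, lam)
    return total * h


# Illustrative parameter values, including two that hit the Taylor branch
for lam in (0.1, 0.499, 0.5, 0.9):
    print(lam, integrate_density(lam))
```

Without the log C(λ) term the integral would differ from 1 for every λ other than the limiting cases, which is the reason the loss includes it rather than using plain binary cross-entropy.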

**Dependencies:**

- PyTorch
- torchvision
- NumPy

Make sure to install all necessary packages and have access to the MNIST dataset (the code will download it if not present).