In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Set device (use GPU if available)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define the VAE architecture
class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(VAE, self).__init__()

        # Encoder
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # Mean
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # Log variance

        # Decoder
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

# Hyperparameters
input_dim = 784  # MNIST images are 28x28
hidden_dim = 400
latent_dim = 20
epochs = 10
batch_size = 128
learning_rate = 1e-3

# Load MNIST dataset
train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Initialize VAE and optimizer
vae = VAE(input_dim, hidden_dim, latent_dim)
optimizer = optim.Adam(vae.parameters(), lr=learning_rate)

# Training loop
for epoch in range(epochs):
    for i, (data, _) in enumerate(train_loader):
        # Forward pass
        recon_batch, mu, logvar = vae(data)
        loss = loss_function(recon_batch, data, mu, logvar)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')

# Now you can use the trained VAE to estimate the probability density of a given input:


Epoch [1/10], Step [100/469], Loss: 23232.7246
Epoch [1/10], Step [200/469], Loss: 19584.7051
Epoch [1/10], Step [300/469], Loss: 18241.0879
Epoch [1/10], Step [400/469], Loss: 16724.7910
Epoch [2/10], Step [100/469], Loss: 15556.2041
Epoch [2/10], Step [200/469], Loss: 15780.4580
Epoch [2/10], Step [300/469], Loss: 15856.6367
Epoch [2/10], Step [400/469], Loss: 15351.9600
Epoch [3/10], Step [100/469], Loss: 14989.8105
Epoch [3/10], Step [200/469], Loss: 14813.7441
Epoch [3/10], Step [300/469], Loss: 14900.6631
Epoch [3/10], Step [400/469], Loss: 14345.0283
Epoch [4/10], Step [100/469], Loss: 14324.7754
Epoch [4/10], Step [200/469], Loss: 14322.2012
Epoch [4/10], Step [300/469], Loss: 13736.7090
Epoch [4/10], Step [400/469], Loss: 14153.5312
Epoch [5/10], Step [100/469], Loss: 14251.2109
Epoch [5/10], Step [200/469], Loss: 14218.9512
Epoch [5/10], Step [300/469], Loss: 13760.1836
Epoch [5/10], Step [400/469], Loss: 14373.0508
Epoch [6/10], Step [100/469], Loss: 13889.6406
Epoch [6/10],

In [5]:
def estimate_pdf(x, vae):
    with torch.no_grad():
        mu, logvar = vae.encode(x.view(-1, 784))
        # Calculate the probability density in the latent space
        std = torch.exp(0.5 * logvar)
        p = torch.exp(-0.5 * ((mu - mu) / std).pow(2)) / (std * torch.sqrt(torch.tensor(2 * torch.pi)))  # Convert to tensor
        return p

# Example usage:
test_image = train_dataset[0][0]  # Get a test image from the dataset
pdf_estimate = estimate_pdf(test_image, vae)
print(pdf_estimate)

tensor([[1.0158, 2.9095, 1.2391, 2.6533, 1.6319, 0.9642, 2.1973, 2.2548, 1.5594,
         0.6841, 1.2004, 1.6546, 0.8141, 1.5383, 1.3646, 1.8661, 0.9563, 1.2787,
         1.6255, 1.1552]])
