In [None]:
#pip install torch torchvision

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [5]:
class VAE(nn.Module):
    """Fully connected variational autoencoder with one hidden layer.

    Encoder: input_dim -> hidden_dim (ReLU) -> two linear heads giving the mean
    and log-variance of a diagonal Gaussian q(z|x) over latent_dim dimensions.
    Decoder: latent_dim -> hidden_dim (ReLU) -> input_dim (sigmoid), so
    reconstructions lie in [0, 1].

    Args:
        input_dim: size of each flattened input vector.
        hidden_dim: width of the hidden layer in both encoder and decoder.
        latent_dim: dimensionality of the latent code z.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()

        # Encoder layers
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # mean head
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # log-variance head

        # Decoder layers
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Return ``(mu, log_var)`` of q(z|x) for a batch of flattened inputs."""
        h = F.relu(self.fc1(x))
        return self.fc21(h), self.fc22(h)

    def reparameterize(self, mu, log_var):
        """Draw z = mu + eps * sigma with eps ~ N(0, I).

        Expressing the sample this way (the reparameterization trick) keeps it
        differentiable with respect to ``mu`` and ``log_var``.
        """
        std = torch.exp(0.5 * log_var)  # sigma = exp(log_var / 2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes back to input space; the sigmoid keeps outputs in [0, 1]."""
        h = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h))

    def forward(self, x):
        """Encode ``x``, sample z, and decode it.

        Returns:
            Tuple ``(reconstruction, mu, log_var)``.
        """
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var

In [6]:
def vae_loss(recon_x, x, mu, log_var):
    """Negative ELBO: binary cross-entropy reconstruction plus KL to N(0, I).

    Args:
        recon_x: decoder output, expected to be probabilities in [0, 1].
        x: targets with the same shape as ``recon_x``, values in [0, 1].
        mu: mean of the diagonal Gaussian q(z|x).
        log_var: log-variance of q(z|x).

    Returns:
        Scalar tensor. Both terms are summed (not averaged) over the batch.
    """
    reconstruction = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)) summed over all elements
    variance = log_var.exp()
    kl_term = -0.5 * torch.sum(1 + log_var - mu.pow(2) - variance)
    return reconstruction + kl_term

In [7]:
# Model / experiment configuration (MNIST images are 28x28, flattened to vectors)
SEED = 0
# Seed torch's global RNG so weight init, DataLoader shuffling and the
# reparameterization noise are reproducible across Restart & Run All.
torch.manual_seed(SEED)

input_dim = 784  # 28 * 28 pixels per flattened MNIST image
hidden_dim = 256
latent_dim = 64

In [8]:
# Build the model from the configuration constants
vae = VAE(input_dim=input_dim, hidden_dim=hidden_dim, latent_dim=latent_dim)

In [9]:
# Adam over every VAE parameter, learning rate 1e-3
optimizer = optim.Adam(params=vae.parameters(), lr=1e-3)

In [10]:
# Training hyperparameters (used by the DataLoader and training-loop cells)
num_epochs, batch_size = 50, 128

In [11]:
# Create DataLoader for MNIST; batch_size comes from the training-hyperparameter cell.
# ToTensor() scales pixels from [0, 255] to [0.0, 1.0]. No Normalize() is applied on
# purpose: vae_loss uses binary cross-entropy, which expects targets in [0, 1].
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)


In [12]:
for epoch in range(num_epochs):
    for batch_idx, (data, _) in enumerate(dataloader):
        # Flatten each image into a vector of length input_dim
        x = data.view(-1, input_dim)

        # Forward pass
        recon_x, mu, log_var = vae(x)

        # vae_loss is summed over the whole batch
        loss = vae_loss(recon_x, x, mu, log_var)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Print progress as a per-sample loss. Divide by the actual batch length:
        # the last batch of an epoch can be smaller than batch_size.
        if batch_idx % 100 == 0:
            per_sample_loss = loss.item() / x.size(0)
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx}/{len(dataloader)}], Loss: {per_sample_loss:.4f}")


Epoch [1/50], Step [0/469], Loss: 545.6916
Epoch [1/50], Step [100/469], Loss: 194.9076
Epoch [1/50], Step [200/469], Loss: 164.1754
Epoch [1/50], Step [300/469], Loss: 161.9988
Epoch [1/50], Step [400/469], Loss: 145.6302
Epoch [2/50], Step [0/469], Loss: 138.6867
Epoch [2/50], Step [100/469], Loss: 134.8753
Epoch [2/50], Step [200/469], Loss: 136.1489
Epoch [2/50], Step [300/469], Loss: 132.1799
Epoch [2/50], Step [400/469], Loss: 123.6593
Epoch [3/50], Step [0/469], Loss: 129.8332
Epoch [3/50], Step [100/469], Loss: 118.4102
Epoch [3/50], Step [200/469], Loss: 118.5678
Epoch [3/50], Step [300/469], Loss: 121.3571
Epoch [3/50], Step [400/469], Loss: 122.5141
Epoch [4/50], Step [0/469], Loss: 124.2179
Epoch [4/50], Step [100/469], Loss: 120.5249
Epoch [4/50], Step [200/469], Loss: 115.2439
Epoch [4/50], Step [300/469], Loss: 114.7247
Epoch [4/50], Step [400/469], Loss: 114.8214
Epoch [5/50], Step [0/469], Loss: 114.6236
Epoch [5/50], Step [100/469], Loss: 113.7574
Epoch [5/50], Step [

In [13]:
import matplotlib.pyplot as plt

In [23]:
def plot_latent_space(vae, dataloader, dims=(0, 1)):
    """Scatter-plot two coordinates of the encoder mean, coloured by label.

    Only the encoder runs. ``mu`` is what gets plotted, so sampling z and
    decoding it (as a full ``vae(x)`` call would) is unnecessary work.

    Args:
        vae: model whose ``encode`` returns ``(mu, log_var)``.
        dataloader: iterable of ``(images, labels)`` batches.
        dims: pair of latent indices to plot. When the latent space has more
            than two dimensions, this is a 2-D slice, not a projection of all of it.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    dim_x, dim_y = dims
    latent_batches, label_batches = [], []
    with torch.no_grad():
        for images, targets in dataloader:
            # Flatten per sample; infer the feature size instead of using a global
            mu, _ = vae.encode(images.flatten(start_dim=1))
            latent_batches.append(mu)
            label_batches.append(targets)
    if not latent_batches:
        raise ValueError("dataloader yielded no batches")
    latents = torch.cat(latent_batches, dim=0).cpu().numpy()
    labels = torch.cat(label_batches, dim=0).cpu().numpy()

    fig, ax = plt.subplots(figsize=(10, 8))
    points = ax.scatter(latents[:, dim_x], latents[:, dim_y], c=labels, cmap='rainbow')
    fig.colorbar(points, ax=ax, label='Label')
    ax.set_xlabel(f'Latent Dimension {dim_x + 1}')
    ax.set_ylabel(f'Latent Dimension {dim_y + 1}')
    ax.set_title('Latent Space Visualization')
    plt.show()

In [14]:
def generate_samples(vae, num_samples):
    """Decode ``num_samples`` draws from the N(0, I) prior and show them as a grid.

    Args:
        vae: a ``VAE`` whose decoder emits flattened 28x28 images in [0, 1]
            (``decode`` ends in a sigmoid).
        num_samples: number of images to generate; must be >= 1.

    Raises:
        ValueError: if ``num_samples`` < 1.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    # fc3 is the decoder's input layer, so its in_features is the latent size;
    # this avoids depending on the global latent_dim.
    latent_size = vae.fc3.in_features
    with torch.no_grad():
        z = torch.randn(num_samples, latent_size)
        samples = vae.decode(z)

    # Sigmoid output is already in [0, 1] and the training data was not
    # normalized, so no "* 0.5 + 0.5" step: it would squash pixels into
    # [0.5, 1] and wash the images out.
    samples = samples.view(-1, 1, 28, 28)

    grid = torchvision.utils.make_grid(samples, nrow=int(num_samples ** 0.5), padding=2)
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(grid.permute(1, 2, 0))
    ax.axis('off')
    ax.set_title('Randomly Generated Samples')
    plt.show()

In [15]:
# Visualise the latent means of the training set. The plot_latent_space
# definition cell must be executed before this one.
plot_latent_space(vae=vae, dataloader=dataloader)

NameError: name 'plot_latent_space' is not defined

In [3]:
# Decode 25 random latent vectors (shown as a 5x5 grid)
num_samples = 25
generate_samples(vae=vae, num_samples=num_samples)

NameError: name 'vae' is not defined