In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch.nn.functional as F

In [2]:
# Step 1: Define the VAE architecture

class VAE(nn.Module):
    def __init__(self, latent_dim):
        super(VAE, self).__init__()
        
        # Encoder
        self.encoder_conv1 = nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1)
        self.encoder_conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
        self.encoder_conv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.fc_mu = nn.Linear(128 * 4 * 4, latent_dim)
        self.fc_logvar = nn.Linear(128 * 4 * 4, latent_dim)
        
        # Decoder
        self.decoder_fc = nn.Linear(latent_dim, 128 * 4 * 4)
        self.decoder_conv1 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
        self.decoder_conv2 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)
        self.decoder_conv3 = nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1)

    def encode(self, x):
        x = F.relu(self.encoder_conv1(x))
        x = F.relu(self.encoder_conv2(x))
        x = F.relu(self.encoder_conv3(x))
        x = x.view(x.size(0), -1)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        epsilon = torch.randn_like(std)
        return mu + epsilon * std

    def decode(self, z):
        z = self.decoder_fc(z)
        z = z.view(z.size(0), 128, 4, 4)
        z = F.relu(self.decoder_conv1(z))
        z = F.relu(self.decoder_conv2(z))
        x_reconstructed = torch.sigmoid(self.decoder_conv3(z))
        return x_reconstructed

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_reconstructed = self.decode(z)
        return x_reconstructed, mu, logvar


In [3]:
# Step 2: Load and preprocess the CelebA dataset

from torchvision.datasets import ImageFolder

transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
])


# Define the path to the directory containing the CelebA images
celeba_data_path = 'data/celeba/'  # Replace with the actual path

# Create a custom dataset
train_dataset = ImageFolder(
    root=celeba_data_path,
    transform=transform  # Use the same transform as before
)

batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [4]:
# Step 3: Define loss and optimizer

latent_dim = 128
vae = VAE(latent_dim)
optimizer = optim.Adam(vae.parameters(), lr=0.001)

In [5]:
# Step 4: Training loop

def loss_function(reconstructed_x, x, mu, logvar):
    # Define your VAE loss function here
    # For example, you can use the reconstruction loss and the KL divergence term
    pass

num_epochs = 10
for epoch in range(num_epochs):
    vae.train()
    total_loss = 0
    for batch_idx, data in enumerate(train_loader):
        optimizer.zero_grad()
        x = data[0]
        x_reconstructed, mu, logvar = vae(x)
        loss = loss_function(x_reconstructed, x, mu, logvar)
        loss.backward()
        total_loss += loss.item()
        optimizer.step()
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / (batch_idx+1)}')

RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x8192 and 2048x128)

In [None]:
# Step 5: Image interpolation

def interpolate_images(image1, image2, num_steps):
    # Encode both images
    mu1, logvar1 = vae.encode(image1)
    mu2, logvar2 = vae.encode(image2)

    # Interpolate in latent space
    interpolated_images = []
    for i in range(num_steps):
        t = i / (num_steps - 1)
        latent_interpolated = (1 - t) * mu1 + t * mu2
        reconstructed_interpolated = vae.decode(latent_interpolated)
        interpolated_images.append(reconstructed_interpolated)

    return interpolated_images

In [None]:
# Example of how to use interpolate_images
image1 = train_dataset[0][0].unsqueeze(0)  # Choose two images from the dataset
image2 = train_dataset[1][0].unsqueeze(0)
interpolated_images = interpolate_images(image1, image2, num_steps=10)

# Display the interpolated images
plt.figure(figsize=(12, 5))
for i, image in enumerate(interpolated_images):
    plt.subplot(2, 5, i+1)
    plt.imshow(image.squeeze().permute(1, 2, 0).detach().numpy())
    plt.axis('off')
plt.show()