In [1]:
import torch
torch.cuda.set_device('cuda:0')
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(device)

# pip install torch-summary

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image

import torch

# import torchsummary
import torchvision as tv
from torchvision import transforms, datasets
from torchvision.transforms import v2

from torch.utils.data import DataLoader, Dataset
from torchvision.utils import make_grid

from torch.distributions.normal import Normal
import torch.nn.functional as F

import torch.nn as nn
from torch.nn import ReLU
from torch.optim.lr_scheduler import _LRScheduler

from time import time

In [2]:

class CelebAEncoder(nn.Module):
    def __init__(self, latent_dim=1024):
        super(CelebAEncoder, self).__init__()
        # Four 2D convolutional layers.
        self.conv1 = nn.Conv2d(3, 2048, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(2048, 1024, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(1024, 512, kernel_size=5, stride=2, padding=2)
        self.conv4 = nn.Conv2d(512, 256, kernel_size=5, stride=2, padding=2)
        # For a 64×64 input, after 4 conv layers the feature map becomes 4×4.
        self.fc = nn.Linear(256 * 4 * 4, 2048)
    
    def forward(self, x):
        x = F.relu(self.conv1(x))   # [B, 2048, 32, 32]
        x = F.relu(self.conv2(x))   # [B, 1024, 16, 16]
        x = F.relu(self.conv3(x))   # [B, 512, 8, 8]
        x = F.relu(self.conv4(x))   # [B, 256, 4, 4]
        x = x.view(x.size(0), -1)     # Flatten
        x = self.fc(x)              # [B, 2048]
        # Split into two halves: one for μ and one for log-σ (pre-softplus)
        mu, log_sigma = torch.chunk(x, 2, dim=1)
        # Ensure σ > 0 using softplus.
        sigma = F.softplus(log_sigma)
        return mu, sigma

class CelebADecoder(nn.Module):
    def __init__(self, latent_dim=1024):
        super(CelebADecoder, self).__init__()
        # Transform the latent vector into a seed feature map.
        self.fc = nn.Linear(latent_dim, 2048 * 4 * 4)
        # Four transposed convolutional layers.
        self.deconv1 = nn.ConvTranspose2d(2048, 1024, kernel_size=5, stride=2,
                                          padding=2, output_padding=1)
        self.deconv2 = nn.ConvTranspose2d(1024, 512, kernel_size=5, stride=2,
                                          padding=2, output_padding=1)
        self.deconv3 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2,
                                          padding=1, output_padding=1)
        self.deconv4 = nn.ConvTranspose2d(256, 3, kernel_size=3, stride=2,
                                          padding=1, output_padding=1)
    
    def forward(self, z):
        x = self.fc(z)
        x = x.view(-1, 2048, 4, 4)
        x = F.relu(self.deconv1(x))
        x = F.relu(self.deconv2(x))
        x = F.relu(self.deconv3(x))
        # Use sigmoid to map the output to [0, 1].
        x = torch.sigmoid(self.deconv4(x))
        return x

class CelebAVAE(nn.Module):
    def __init__(self, latent_dim=1024):
        super(CelebAVAE, self).__init__()
        self.encoder = CelebAEncoder(latent_dim)
        self.decoder = CelebADecoder(latent_dim)
    
    def reparameterize(self, mu, sigma):
        eps = torch.randn_like(sigma)
        return mu + eps * sigma
    
    def forward(self, x):
        mu, sigma = self.encoder(x)
        z = self.reparameterize(mu, sigma)
        recon = self.decoder(z)
        return recon, mu, sigma

#########################################
# Loss Function for the VAE
#########################################

def loss_function(recon_x, x, mu, sigma):
    # Reconstruction loss (BCE) summed over all pixels.
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # KL divergence between the approximate posterior and the unit Gaussian.
    KL = -0.5 * torch.sum(1 + 2 * torch.log(sigma) - mu.pow(2) - sigma.pow(2))
    return BCE + KL

In [3]:

class ActorNetwork(nn.Module):
    def __init__(self, z_dim, attr_dim=None):
        """
        If attr_dim is provided, the network is conditioned on attribute labels.
        """
        super(ActorNetwork, self).__init__()
        self.z_dim = z_dim
        self.use_attr = attr_dim is not None
        input_dim = z_dim
        if self.use_attr:
            # Map attribute labels to a 2048-dimensional embedding.
            self.attr_fc = nn.Linear(attr_dim, 2048)
            input_dim += 2048

        # Four fully connected layers with 2048 outputs each.
        self.fc1 = nn.Linear(input_dim, 2048)
        self.fc2 = nn.Linear(2048, 2048)
        self.fc3 = nn.Linear(2048, 2048)
        self.fc4 = nn.Linear(2048, 2048)
        # Final layer produces 2*z_dim outputs (to split into δz and gate logits)
        self.fc_out = nn.Linear(2048, 2 * z_dim)
    
    def forward(self, z, y=None):
        """
        Args:
            z: latent vector [batch, z_dim]
            y: attribute labels [batch, attr_dim] (optional)
        Returns:
            Transformed z'
        """
        # Keep a copy of the original z for the residual update.
        z_orig = z
        if self.use_attr and y is not None:
            y_emb = F.relu(self.attr_fc(y))
            x = torch.cat([z, y_emb], dim=1)
        else:
            x = z
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        x = self.fc_out(x)
        delta_z, gate_logits = torch.chunk(x, 2, dim=1)
        gates = torch.sigmoid(gate_logits)
        # Compute the updated latent vector:
        z_transformed = (1 - gates) * z_orig + gates * delta_z
        return z_transformed


class CriticNetwork(nn.Module):
    def __init__(self, z_dim, attr_dim=None):
        """
        If attr_dim is provided, the critic is conditioned on attribute labels.
        """
        super(CriticNetwork, self).__init__()
        self.z_dim = z_dim
        self.use_attr = attr_dim is not None
        input_dim = z_dim
        if self.use_attr:
            self.attr_fc = nn.Linear(attr_dim, 2048)
            input_dim += 2048

        # Four fully connected layers with 2048 outputs each.
        self.fc1 = nn.Linear(input_dim, 2048)
        self.fc2 = nn.Linear(2048, 2048)
        self.fc3 = nn.Linear(2048, 2048)
        self.fc4 = nn.Linear(2048, 2048)
        # Final layer produces a single output.
        self.fc_out = nn.Linear(2048, 1)
    
    def forward(self, z, y=None):
        """
        Args:
            z: latent vector [batch, z_dim]
            y: attribute labels [batch, attr_dim] (optional)
        Returns:
            Critic score in [0, 1]
        """
        if self.use_attr and y is not None:
            y_emb = F.relu(self.attr_fc(y))
            x = torch.cat([z, y_emb], dim=1)
        else:
            x = z
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        x = self.fc_out(x)
        output = torch.sigmoid(x)
        return output



In [4]:
import torch
import matplotlib.pyplot as plt

# Hyperparameters and device settings
latent_dim = 1024
attr_dim = 40      # e.g., if you have 40 attribute values per sample
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the networks and move them to the device.
actor = ActorNetwork(z_dim=latent_dim, attr_dim=attr_dim).to(device)
critic = CriticNetwork(z_dim=latent_dim, attr_dim=attr_dim).to(device)
vae = CelebAVAE().to(device)  # This is assumed to convert latent codes to images

# Load the pre-saved weights.
actor.load_state_dict(torch.load('actor_model.pt', map_location=device))
critic.load_state_dict(torch.load('critic_model.pt', map_location=device))
vae.load_state_dict(torch.load('celeba_vae.pth', map_location=device))  # adjust the filename as needed

decoder = vae.decoder

# Set the models to evaluation mode.
actor.eval()
critic.eval()
decoder.eval()

# data = pd.read_csv("./data/list_attr_celeba.csv").iloc[0,1:]
# data = np.array(data.replace(-1,0).values)
# print(data)

# Create a sample latent vector and an attribute vector.
# z_latent is a tensor of shape [1, latent_dim], and attributes is [1, attr_dim].
z_latent = torch.randn(1, latent_dim).to(device)
attributes = torch.zeros(1, attr_dim).to(device)  
# attributes = torch.tensor(data).to(device)

attributes[0][14] = 1

# Use the actor to transform the latent vector.
with torch.no_grad():
    z_transformed = actor(z_latent, attributes)

# Feed the transformed latent vector to the CelebA decoder to generate an image.
with torch.no_grad():
    generated_img = decoder(z_transformed)  # Expected output shape: [1, 3, H, W]

# Process the generated image for plotting.
# Remove the batch dimension and move to CPU.
image = generated_img.squeeze(0).cpu()  # shape: [3, H, W]

# Plot the image using matplotlib.
plt.figure(figsize=(6, 6))
# Convert from [C, H, W] to [H, W, C] for imshow.
plt.imshow(image.permute(1, 2, 0).numpy())
plt.axis('off')
plt.title("Generated CelebA Image")
plt.show()


  actor.load_state_dict(torch.load('actor_model.pt', map_location=device))


RuntimeError: Error(s) in loading state_dict for ActorNetwork:
	size mismatch for attr_fc.weight: copying a param with shape torch.Size([2048, 10]) from checkpoint, the shape in current model is torch.Size([2048, 40]).