# Face Generation Inference

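This notebook demonstrates inference with the trained face generation model. The encoder maps a face image to an embedding, and the generator turns an embedding plus a random noise vector into a face. The sections below cover:

- reconstructing test faces;
- varying the noise for a fixed identity;
- interpolating between two identities (also as a video);
- sampling faces from random embeddings.

Set `CHECKPOINT_PATH` and `TEST_DIR` in the configuration cell before running.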

In [None]:
import sys
import os
sys.path.append('..')

import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import glob
from tqdm import tqdm

from src.encoder import get_encoder
from src.generator import get_generator
from src.utils import show_tensor_images, load_image, interpolate_embeddings

## 1. Load Models

In [None]:
# Configuration
EMBEDDING_SIZE = 256
Z_DIM = 512
IMG_SIZE = 128
CHECKPOINT_PATH = '../output/checkpoints/checkpoint_epoch_100.pth'  # Update with your checkpoint path
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
TEST_DIR = '../data/test'  # Path to test images
SEED = 42  # Fixes every torch.randn draw below so figures are reproducible

# Every later cell samples noise with torch.randn. Seeding once here lets
# Restart & Run All reproduce the same images (torch.manual_seed seeds CPU and CUDA).
torch.manual_seed(SEED)

# Build the models
encoder = get_encoder(embedding_size=EMBEDDING_SIZE, device=DEVICE)
generator = get_generator(
    z_dim=Z_DIM,
    embedding_size=EMBEDDING_SIZE,
    img_size=IMG_SIZE,
    device=DEVICE,
)

# Restore trained generator weights if a checkpoint is available.
# NOTE(review): only the generator's state is restored; if training also fine-tuned
# the encoder and saved it in the checkpoint, it should be loaded here too — confirm.
if os.path.exists(CHECKPOINT_PATH):
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
    generator.load_state_dict(checkpoint['generator'])
    saved_epoch = checkpoint.get('epoch')  # tolerate checkpoints saved without an epoch
    if saved_epoch is None:
        print(f"Loaded checkpoint from {CHECKPOINT_PATH}")
    else:
        # The stored epoch is printed +1, i.e. presumably saved 0-based.
        print(f"Loaded checkpoint from epoch {saved_epoch+1}")
else:
    print(f"No checkpoint found at {CHECKPOINT_PATH}")

# Inference only: put both models in evaluation mode
encoder.eval()
generator.eval()

## 2. Generate Faces from Test Images

In [None]:
# Collect test images in a deterministic (sorted) order. Raw glob order depends on
# the filesystem, so "the first N images" would otherwise change between machines.
# Only image files are kept, so subdirectories or stray files (README, CSV labels, ...)
# never reach load_image.
NUM_TEST_IMAGES = 16  # Adjust number as needed (later sections index up to test_images[5])
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp', '.tif', '.tiff')

test_image_paths = sorted(
    path
    for path in glob.glob(os.path.join(TEST_DIR, '*'))
    if path.lower().endswith(IMAGE_EXTENSIONS)
)[:NUM_TEST_IMAGES]

if not test_image_paths:
    # torch.cat([]) would fail with an unhelpful error; say what is actually wrong.
    raise FileNotFoundError(f"No test images found in {TEST_DIR}")

# Load and batch the images. The results are concatenated along dim 0, so
# load_image presumably returns a (1, C, H, W) tensor — TODO confirm against src.utils.
test_images = torch.cat(
    [load_image(path, img_size=IMG_SIZE).to(DEVICE) for path in test_image_paths],
    dim=0,
)

# Display original test images
print("Original Test Images:")
show_tensor_images(test_images, num_images=len(test_image_paths), figsize=(12, 8))

In [None]:
# Reconstruct every test face: encode it, pair the embedding with fresh random
# noise, and decode with the generator. No gradients are needed for inference.
with torch.no_grad():
    embeddings = encoder(test_images)
    batch_size = embeddings.shape[0]
    z = torch.randn(batch_size, Z_DIM).to(DEVICE)  # one noise vector per face
    generated_images = generator(z, embeddings)

print("Generated Reconstructions:")
show_tensor_images(generated_images, num_images=16, figsize=(12, 8))

## 3. Style Variation - Different Noise Vectors with the Same Embedding

In [None]:
# Hold one face's embedding fixed and vary only the noise vector, to see how much
# of the output the noise controls independently of the embedding.
test_image = test_images[0].unsqueeze(0)
num_variations = 8

with torch.no_grad():
    # One embedding, tiled so each noise vector has a copy to pair with
    embedding = encoder(test_image).repeat(num_variations, 1)
    z_vectors = torch.randn(num_variations, Z_DIM).to(DEVICE)
    style_mixed_images = generator(z_vectors, embedding)

print("Original Image:")
show_tensor_images(test_image, num_images=1)

print("Same Identity with Different Styles:")
show_tensor_images(style_mixed_images, num_images=num_variations, figsize=(12, 4))

## 4. Interpolate Between Face Embeddings

In [None]:
# The first two test faces are the interpolation endpoints (each kept as a batch of one).
image1, image2 = (test_images[idx].unsqueeze(0) for idx in (0, 1))

print("Image 1 and Image 2 for Interpolation:")
show_tensor_images(torch.cat([image1, image2], dim=0), num_images=2)

with torch.no_grad():
    embedding1, embedding2 = encoder(image1), encoder(image2)
    interpolated_embeddings = interpolate_embeddings(embedding1, embedding2, steps=8)

    # A single noise vector shared by all frames isolates the effect of the embedding.
    num_frames = interpolated_embeddings.shape[0]
    z = torch.randn(1, Z_DIM).to(DEVICE).repeat(num_frames, 1)

    interpolated_images = generator(z, interpolated_embeddings)

print("Interpolation Between Two Faces:")
show_tensor_images(interpolated_images, num_images=8, figsize=(15, 4))

## 5. Generate Faces from Random Embeddings

In [None]:
# Sample embeddings from a standard normal instead of encoding real faces.
# NOTE(review): this presumes the encoder's embedding space is roughly N(0, I); if the
# encoder normalizes or rescales its outputs, these samples may be off-distribution — confirm.
num_samples = 16
embedding_shape, noise_shape = (num_samples, EMBEDDING_SIZE), (num_samples, Z_DIM)
random_embeddings = torch.randn(*embedding_shape).to(DEVICE)
z = torch.randn(*noise_shape).to(DEVICE)

with torch.no_grad():
    random_faces = generator(z, random_embeddings)

print("Faces Generated from Random Embeddings:")
show_tensor_images(random_faces, num_images=num_samples, figsize=(12, 8))

## 6. Face Manipulation - Mean Face and Latent Directions

In [None]:
# Many faces that share one noise vector and differ only in their (random) embedding.
num_faces = 100
z = torch.randn(1, Z_DIM).to(DEVICE).repeat(num_faces, 1)
random_embeddings = torch.randn(num_faces, EMBEDDING_SIZE).to(DEVICE)

with torch.no_grad():
    faces = generator(z, random_embeddings)

    # Average embedding across all sampled faces, decoded with the same noise.
    # NOTE(review): the embeddings are i.i.d. N(0, 1) draws, so their mean is close to
    # the zero vector; this "mean face" is roughly the generator's output for a
    # near-zero embedding rather than an average of real identities.
    mean_embedding = random_embeddings.mean(dim=0, keepdim=True)
    mean_face = generator(z[:1], mean_embedding)

print("Mean Face:")
show_tensor_images(mean_face, num_images=1)

In [None]:
# Preview a handful of the faces generated in the previous cell.
num_preview = 8
sample_faces = faces[:num_preview]

print("Sample Generated Faces:")
show_tensor_images(sample_faces, num_images=num_preview, figsize=(15, 4))

In [None]:
# Walk the mean embedding along one random unit-length direction (a crude stand-in
# for a PCA component) and decode each step with the shared noise vector.
direction = torch.randn(EMBEDDING_SIZE).to(DEVICE)
direction = direction / direction.norm()  # unit length, so `strength` is the step distance

strengths = [-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]

with torch.no_grad():
    variations = torch.cat(
        [generator(z[:1], mean_embedding + direction * strength) for strength in strengths],
        dim=0,
    )

print("Manipulating Face Attributes:")
show_tensor_images(variations, num_images=len(strengths), figsize=(15, 4))

## 7. Create Interpolation Video

In [None]:
# Install imageio if not already installed. Writing .mp4 files also needs an ffmpeg
# backend (the imageio-ffmpeg package); %pip installs into the kernel's environment.
# %pip install imageio[ffmpeg]

In [None]:
import imageio
from torchvision.utils import make_grid

def tensor_to_numpy(img_tensor):
    """Convert one CxHxW image tensor in [-1, 1] to an HxWxC uint8 array.

    Values are mapped from [-1, 1] to [0, 255]. They are clipped before the cast,
    so pixels slightly outside [-1, 1] saturate at 0/255 instead of producing
    garbage: float -> uint8 casts are not saturating.

    Args:
        img_tensor: Tensor of shape (C, H, W). It may live on any device and may
            require grad.

    Returns:
        numpy.ndarray of shape (H, W, C) and dtype uint8.
    """
    # Drop the autograd graph and move to host memory before converting.
    img = img_tensor.detach().cpu().numpy()
    # CxHxW -> HxWxC, the channel-last layout image/video writers expect.
    img = np.transpose(img, (1, 2, 0))
    # [-1, 1] -> [0, 1], clipped so out-of-range values cannot overflow uint8.
    img = np.clip((img + 1) / 2, 0.0, 1.0)
    return (img * 255).astype(np.uint8)

def create_interpolation_video(filename, image1, image2, steps=60, fps=30):
    """Render and save a forward-then-backward interpolation video between two faces.

    Both images are encoded and their embeddings are interpolated with
    ``interpolate_embeddings``. Each interpolated embedding is decoded with the same
    noise vector, so only the embedding changes across frames. The frames are played
    forward, then backward back to the first frame, and written with ``imageio.mimsave``.

    Args:
        filename: Output path. The extension selects the imageio writer; ``.mp4``
            needs an ffmpeg backend such as the ``imageio-ffmpeg`` package.
        image1: Start image tensor passed to ``encoder``. Presumably a (1, C, H, W)
            single-image batch — TODO confirm.
        image2: End image tensor, same format as ``image1``.
        steps: Number of interpolation steps requested from ``interpolate_embeddings``.
        fps: Frames per second of the written video.
    """
    with torch.no_grad():
        embedding1 = encoder(image1)
        embedding2 = encoder(image2)
        interpolated_embeddings = interpolate_embeddings(embedding1, embedding2, steps=steps)

        # One noise vector shared by every frame so only the embedding changes.
        z = torch.randn(1, Z_DIM).to(DEVICE)

        # Decode one frame at a time so peak memory does not grow with `steps`.
        frames = []
        for i in tqdm(range(interpolated_embeddings.shape[0])):
            gen_img = generator(z, interpolated_embeddings[i:i + 1])
            frames.append(tensor_to_numpy(gen_img[0]))

    # Forward then backward. frames[::-1] would repeat the last frame at the
    # turnaround, causing a visible stall; frames[-2::-1] starts one frame earlier
    # and still ends on the first face.
    frames_loop = frames + frames[-2::-1]

    imageio.mimsave(filename, frames_loop, fps=fps)
    print(f"Video saved to {filename}")

In [None]:
# Endpoints for the video: the first and fourth test faces (each a batch of one).
image1, image2 = (test_images[idx].unsqueeze(0) for idx in (0, 3))

print("Image 1 and Image 2 for Video Interpolation:")
show_tensor_images(torch.cat([image1, image2], dim=0), num_images=2)

# Make sure the output directory exists before writing the video.
output_video = "../output/face_interpolation.mp4"
output_dir = os.path.dirname(output_video)
os.makedirs(output_dir, exist_ok=True)

create_interpolation_video(output_video, image1, image2, steps=30, fps=15)

## 8. Generate Multiple Styles for a Single Identity

In [None]:
def generate_style_grid(reference_image, num_styles=16, figsize=(12, 8)):
    """Generate and display multiple style variations of a single face.

    The reference face is encoded once. Its embedding is paired with ``num_styles``
    independent random noise vectors and decoded by the generator, so every output
    shares the same embedding and differs only in its noise.

    Args:
        reference_image: Image tensor passed to ``encoder``. Presumably a (1, C, H, W)
            single-image batch such as ``test_images[i].unsqueeze(0)`` — TODO confirm.
        num_styles: Number of noise vectors, i.e. variations to generate.
        figsize: Figure size forwarded to ``show_tensor_images`` for the variations grid.

    Returns:
        The batch of generated variations returned by ``generator``.
    """
    with torch.no_grad():
        # One embedding, tiled so each noise vector has a copy to pair with
        embedding = encoder(reference_image).repeat(num_styles, 1)
        z_vectors = torch.randn(num_styles, Z_DIM).to(DEVICE)
        style_variations = generator(z_vectors, embedding)

    # show_tensor_images takes its own `figsize` (as in the other cells), so forward it
    # instead of pre-creating figures with plt.figure()/plt.title(). Previously `figsize`
    # never reached the grid, and the pre-made figures may end up as empty titled
    # figures if the helper opens its own. NOTE(review): confirm against src.utils.
    print("Reference Image:")
    show_tensor_images(reference_image, num_images=1)

    print("Style Variations:")
    show_tensor_images(style_variations, num_images=num_styles, figsize=figsize)

    return style_variations

In [None]:
# Use the sixth test face as the reference identity and render 16 noise variations.
reference_index = 5
reference_image = test_images[reference_index].unsqueeze(0)

variations = generate_style_grid(reference_image, num_styles=16)