# SmileGAN-PTI: Counterfactual Image Generation for Visual Attribute Editing

## Author: Mohammad Mosaffa (mm3322@cornell.edu)

In [None]:
#@title Install Dependencies
!pip install ninja imageio-ffmpeg ipywidgets matplotlib torchvision gdown
!pip install face-alignment
!pip install scipy scikit-image
!pip install lpips
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git

## Setup: Clone e4e Repository & Download Pre‑trained Models

This section prepares the environment by cloning the **`encoder4editing` (e4e)** repository and downloading four pre‑trained assets that work together during inversion, fine‑tuning, and expression editing.

### Core Models

1. **`e4e_ffhq_encode.pt` (e4e Encoder)**
   * **Type:** Encoder  
   * **Purpose:** Converts a real input face image into a StyleGAN‑friendly latent code.  
   * **Key Feature:** Latent codes are deliberately positioned to remain *highly editable* while still preserving identity.  
   * **Dataset:** Trained on FFHQ (high‑quality human faces).

2. **`stylegan2‑ffhq‑config‑f.pt` (StyleGAN2 Generator)**
   * **Type:** Generator (Image Synthesizer)  
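   * **Purpose:** Takes a latent code (original or edited) and synthesizes a 1024 × 1024 face.  
   * **Key Feature:** Industry‑standard visual fidelity and detail.  
   * **Dataset:** Trained on FFHQ, matching the encoder.  
   * **Note:** The e4e checkpoint already contains the decoder (generator) weights, and the cells below build the generator from it through `pSp`; this separate file is not loaded by the code as written.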

### Add‑on Components

3. **`smile.pt` (Expression Direction)**
   * **Type:** Latent‑space edit vector (W⁺)  
   * **Purpose:** When added to the latent code, it biases mid‑level layers toward a *smiling* expression.  
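   * **Key Feature:** Learned with InterfaceGAN; intended to change the expression while leaving pose, hair and lighting largely unaffected.  
   * **Usage:** Scaled by a factor `α` (`SMILE_ALPHA`, 1.2 by default) and added only to the mid‑level W⁺ layers selected by `SMILE_START_LAYER` and `SMILE_END_LAYER` (4 to 6 by default; both are set in Block 6).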

4. **`model_ir_se50.pth` (ArcFace IR‑SE‑50)**
   * **Type:** Face‑recognition network  
   * **Purpose:** Provides an *identity loss* during PTI so the generator fine‑tunes without drifting from the person’s real‑world likeness.  
   * **Key Feature:** An IR‑SE‑50 ResNet that maps a 112 × 112 face crop to a 512‑d identity embedding; widely used as the identity loss in pSp, e4e and PTI pipelines.  
   * **Dataset:** Trained on MS‑1M (the refined MS‑Celeb‑1M face‑recognition set), not on FFHQ.
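Identity losses built on ArcFace embeddings typically score one minus the cosine similarity between the embedding of the generated face and that of the target. A minimal NumPy sketch of that formula, assuming the embeddings have already been extracted (random 512‑d vectors stand in for real IR‑SE‑50 features, and `identity_loss` is a hypothetical name). e4e's `criteria/id_loss.py` also crops and pools the images before embedding them; check that file for its exact preprocessing.

```python
import numpy as np

def identity_loss(gen_feats: np.ndarray, target_feats: np.ndarray) -> float:
    """Mean (1 - cosine similarity) over a batch of embedding pairs.

    gen_feats, target_feats: arrays of shape (batch, dim), e.g. dim = 512.
    """
    # L2-normalize each embedding so the dot product equals the cosine similarity
    gen = gen_feats / np.linalg.norm(gen_feats, axis=1, keepdims=True)
    tgt = target_feats / np.linalg.norm(target_feats, axis=1, keepdims=True)
    cos_sim = np.sum(gen * tgt, axis=1)
    return float(np.mean(1.0 - cos_sim))

rng = np.random.default_rng(0)
target = rng.normal(size=(2, 512))
print(identity_loss(target, target))   # identical embeddings: approximately 0.0
print(identity_loss(target, -target))  # opposite embeddings: approximately 2.0
```

A loss of 0 means the two faces have identical identity embeddings; PTI weights this term against the pixel and perceptual losses rather than minimizing it alone.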

---

### How They Work Together

1. **Encode** the photograph with the *e4e Encoder* → latent code *W⁺*.  
2. **Fine‑tune** the *StyleGAN2 Generator* via PTI, using *ArcFace IR‑SE‑50* to keep identity intact.  
3. **Edit** the code by adding `α × smile.pt` in expression layers.  
4. **Synthesize** the final 1024‑px smiling portrait with the tuned generator.

---

## ⚠️ Model‑Download Issues?

If automatic downloads fail:

| Asset | Direct Link | Target Path |
|-------|-------------|-------------|
| e4e Encoder | `https://drive.google.com/uc?id=1h2uBzmFhbuox2g9-vKBC9iTcxS5IS1pk` | `e4e/pretrained_models/e4e_ffhq_encode.pt` |
| StyleGAN2 Generator | `https://drive.google.com/uc?id=1V4rGkG8H7QtO_uVfrvy3aZhHFi3fY4UZ` | `e4e/pretrained_models/stylegan2-ffhq-config-f.pt` |
| Smile Direction | `https://raw.githubusercontent.com/yuval-alaluf/latent-diffusion-directions/main/edit_directions/smile.pt` | `e4e/pretrained_models/smile.pt` |
| ArcFace IR‑SE‑50 | `https://github.com/TreB1eN/InsightFace_Pytorch/raw/master/model_ir_se50.pth` | `e4e/pretrained_models/model_ir_se50.pth` |

Make sure all four files reside in **`e4e/pretrained_models/`** before running the pipeline.
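Before running the pipeline, a quick check can confirm that each file is present and non‑empty (a zero‑byte file usually means an interrupted or blocked download). A minimal sketch; `missing_model_files` and `REQUIRED_MODEL_FILES` are hypothetical names, and the default directory assumes the Colab layout used throughout this notebook (`/content/e4e`).

```python
import os

# Filenames expected in e4e/pretrained_models/ (taken from the table above)
REQUIRED_MODEL_FILES = [
    "e4e_ffhq_encode.pt",
    "stylegan2-ffhq-config-f.pt",
    "smile.pt",
    "model_ir_se50.pth",
]

def missing_model_files(model_dir: str = "/content/e4e/pretrained_models") -> list:
    """Return the required files that are absent or empty in model_dir."""
    missing = []
    for name in REQUIRED_MODEL_FILES:
        path = os.path.join(model_dir, name)
        # A zero-byte file usually means a failed or interrupted download
        if not os.path.isfile(path) or os.path.getsize(path) == 0:
            missing.append(name)
    return missing

missing = missing_model_files()
if missing:
    print("Missing or empty model files:", ", ".join(missing))
else:
    print("All four model files are present.")
```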


In [None]:
!git clone https://github.com/omertov/encoder4editing.git e4e
%cd e4e

import os
import shutil  # used below to copy the smile direction and ArcFace weights
import gdown

os.makedirs("pretrained_models", exist_ok=True)

# Download e4e encoder checkpoint
e4e_url = "https://drive.google.com/uc?id=1h2uBzmFhbuox2g9-vKBC9iTcxS5IS1pk"
e4e_output = "pretrained_models/e4e_ffhq_encode.pt"
gdown.download(e4e_url, e4e_output, quiet=False)

# Download StyleGAN2 FFHQ generator
stylegan_url = "https://drive.google.com/uc?id=1V4rGkG8H7QtO_uVfrvy3aZhHFi3fY4UZ"
stylegan_output = "pretrained_models/stylegan2-ffhq-config-f.pt"
gdown.download(stylegan_url, stylegan_output, quiet=False)

# smile latent direction
!git clone https://github.com/yuval-alaluf/latent-diffusion-directions.git latent_dirs

shutil.copy(
    "latent_dirs/edit_directions/smile.pt",
    "pretrained_models/smile.pt",
)
print("smile.pt copied to pretrained_models/")

# ArcFace IR‑SE‑50 (ID loss net)
!git clone https://github.com/TreB1eN/InsightFace_Pytorch.git arcface_tmp
shutil.copy(
    "arcface_tmp/model_ir_se50.pth",
    "pretrained_models/model_ir_se50.pth",
)
print("model_ir_se50.pth copied to pretrained_models/")


%cd /content/e4e

## Block 1 — Environment Setup & Imports

This cell loads the core libraries (PyTorch, LPIPS, face_alignment, NumPy, PIL, etc.) and tries to initialize the face‑landmark detector, falling back to `fa = None` if the package is missing or fails to load.

It also deletes the `output_smiles/` folder left by previous runs and recreates a clean folder structure: `input_images/` for your source photos, `output_smiles/` for the edited results, and `output_smiles/optimized_data/` for the optimized latents and tuned generator weights. Absolute paths to the e4e checkpoint, the smile direction and the ArcFace weights are defined here; the smile‑edit knobs (`SMILE_ALPHA` and the layer range) are set later, at the top of Block 6.

In [None]:
import torch
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
import os
import numpy as np
from skimage import transform as trans
import matplotlib.pyplot as plt
import lpips
import glob
import random
import sys
import copy
import shutil

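try:
    import face_alignment
except ImportError:
    face_alignment = None
    print("Warning: face_alignment is not installed; face alignment will be unavailable.")

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 68-point 2-D landmark detector used for alignment and for the mouth mask
fa = None
if face_alignment:
    try:
        fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, flip_input=False, device=device)
    except Exception as e_fa:
        print(f"Warning: could not initialize FaceAlignment ({e_fa}); face alignment will be unavailable.")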

BASE_PROJECT_PATH = "/content/"
E4E_BASE_PATH = os.path.join(BASE_PROJECT_PATH, "e4e")

INPUT_FOLDER_PATH = os.path.join(BASE_PROJECT_PATH, "input_images")
OUTPUT_FOLDER_PATH = os.path.join(BASE_PROJECT_PATH, "output_smiles")
OPTIMIZED_DATA_FOLDER = os.path.join(OUTPUT_FOLDER_PATH, "optimized_data")

print(f"Please ensure your input images are placed in: {INPUT_FOLDER_PATH}")

if os.path.exists(OUTPUT_FOLDER_PATH):
    try:
        shutil.rmtree(OUTPUT_FOLDER_PATH)
    except OSError:
        pass

os.makedirs(INPUT_FOLDER_PATH, exist_ok=True)
os.makedirs(OUTPUT_FOLDER_PATH, exist_ok=True)
os.makedirs(OPTIMIZED_DATA_FOLDER, exist_ok=True)

e4e_checkpoint_path = os.path.join(E4E_BASE_PATH, 'pretrained_models/e4e_ffhq_encode.pt')
smile_direction_path = os.path.join(E4E_BASE_PATH, 'pretrained_models/smile.pt')
id_loss_checkpoint_path_absolute = os.path.join(E4E_BASE_PATH, "pretrained_models/model_ir_se50.pth")

processed_image_pairs = []

Please ensure your input images are placed in: /content/input_images


## Block 2 — Face Alignment Function

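Block 2 defines a function that detects the 68 facial landmarks in an input image, estimates a similarity transform (rotation, uniform scale and translation) that maps the two eye centers and the nose tip to fixed positions, and warps the face into a standardized 256×256 crop. It also saves a landmark verification plot for visual checking. It returns both the aligned face image and the landmark coordinates in the original photo, which Block 6 reuses to paste the edited face back.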
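The code relies on scikit‑image's `SimilarityTransform.estimate` for the fit. To show what that fit computes, the NumPy sketch below solves the same least‑squares objective directly: each point pair contributes two linear equations in the four unknowns `a, b, tx, ty` of the map `u = a·x − b·y + tx`, `v = b·x + a·y + ty`. The function names and the source landmark coordinates are hypothetical; the targets are the ones `align_face` uses.

```python
import numpy as np

def fit_similarity(src: np.ndarray, dst: np.ndarray) -> np.ndarray:
    """Least-squares 2-D similarity transform mapping src points onto dst points.

    src, dst: arrays of shape (N, 2) with N >= 2.
    Returns a 3x3 homogeneous matrix [[a, -b, tx], [b, a, ty], [0, 0, 1]].
    """
    n = src.shape[0]
    A = np.zeros((2 * n, 4))
    rhs = np.zeros(2 * n)
    for i, ((x, y), (u, v)) in enumerate(zip(src, dst)):
        # u = a*x - b*y + tx
        A[2 * i] = [x, -y, 1.0, 0.0]
        rhs[2 * i] = u
        # v = b*x + a*y + ty
        A[2 * i + 1] = [y, x, 0.0, 1.0]
        rhs[2 * i + 1] = v
    a, b, tx, ty = np.linalg.lstsq(A, rhs, rcond=None)[0]
    return np.array([[a, -b, tx], [b, a, ty], [0.0, 0.0, 1.0]])

def apply_transform(matrix: np.ndarray, pts: np.ndarray) -> np.ndarray:
    """Apply a 3x3 homogeneous transform to (N, 2) points."""
    homog = np.hstack([pts, np.ones((pts.shape[0], 1))])
    return (homog @ matrix.T)[:, :2]

# Hypothetical landmarks in the original photo: left-eye center, right-eye center, nose tip
src_pts = np.array([[310.0, 402.0], [438.0, 398.0], [375.0, 495.0]])
# Targets used by align_face in the 256x256 aligned crop
dst_pts = np.array([[96.0, 112.0], [160.0, 112.0], [128.0, 160.0]])

M = fit_similarity(src_pts, dst_pts)
# Close to dst_pts but not exact: 3 points over-determine the 4 parameters
print(np.round(apply_transform(M, src_pts), 1))
```

Because only rotation, uniform scale and translation are allowed, the face is never sheared; any residual mismatch is spread across the three points in the least‑squares sense.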

In [None]:
def align_face(image_path, fa_model, output_path_prefix=""):
    if fa_model is None:
        raise RuntimeError("Face alignment model (fa) is not initialized. Cannot proceed with alignment.")

    print(f"Aligning face for image: {image_path}")
    try:
        image = Image.open(image_path).convert('RGB')
    except FileNotFoundError:
        raise FileNotFoundError(f"Input image not found at {image_path}")

    img_array = np.array(image)
    landmarks_list = fa_model.get_landmarks(img_array)

    if landmarks_list is None or len(landmarks_list) == 0:
        raise ValueError(f"No face detected in the image: {image_path}")

    landmarks = landmarks_list[0]

    verification_image_path = f"{output_path_prefix}_landmarks_verification.png"
    if output_path_prefix:
        verification_dir = os.path.dirname(verification_image_path)
        if verification_dir:
            os.makedirs(verification_dir, exist_ok=True)

    plt.figure(figsize=(5,5))
    plt.imshow(img_array)
    plt.scatter(landmarks[:, 0], landmarks[:, 1], c='red', s=10, marker='.')
    plt.title(f"Landmarks: {os.path.basename(image_path)}")
    plt.axis('off')
    try:
        plt.savefig(verification_image_path)
        print(f"Landmark verification image saved to: {verification_image_path}")
    except Exception as e_save:
        print(f"Warning: Could not save landmark verification image to {verification_image_path}. Error: {e_save}")
    plt.close()

    left_eye_center = landmarks[36:42].mean(axis=0)
    right_eye_center = landmarks[42:48].mean(axis=0)
    nose_tip = landmarks[30]

    target_left_eye = (96, 112)
    target_right_eye = (160, 112)
    target_nose_tip = (128, 160)

    src_pts = np.array([left_eye_center, right_eye_center, nose_tip], dtype=np.float32)
    dst_pts = np.array([target_left_eye, target_right_eye, target_nose_tip], dtype=np.float32)

    tform = trans.SimilarityTransform()
    tform.estimate(src_pts, dst_pts)
    aligned_array = trans.warp(img_array, tform.inverse, output_shape=(256, 256), mode='reflect', preserve_range=True)

    # preserve_range=True keeps the 0-255 scale of the uint8 input, so round, clip and cast back to uint8
    aligned_array = np.clip(np.rint(aligned_array), 0, 255).astype(np.uint8)

    aligned_pil_image = Image.fromarray(aligned_array)
    return aligned_pil_image, landmarks


## Block 3 — Image Preprocessing and Mouth Mask Creation

Block 3 defines the preprocessing transforms for each model input: the e4e encoder input, the 256×256 and 1024×1024 PTI targets, and the 112×112 identity‑loss input. It also provides a function that re‑detects landmarks on the aligned 256×256 image and builds a rectangular mask over the bounding box of the mouth landmarks (points 48–67). Block 5 uses this mask for an extra mouth‑focused L1 term during generator tuning.
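The mouth mask is simply the axis‑aligned bounding box of landmarks 48–67, the mouth points in the 68‑point scheme. A NumPy sketch of the same idea; `mouth_bbox_mask` and its `pad` margin are hypothetical, and it uses half‑open bounds clamped to the image size, so it is close to, but not identical with, `create_mouth_mask_from_aligned_image`.

```python
import numpy as np

MOUTH_IDX = slice(48, 68)  # mouth points in the 68-point landmark scheme

def mouth_bbox_mask(landmarks: np.ndarray, size: int = 256, pad: int = 0) -> np.ndarray:
    """Binary (size, size) mask covering the bounding box of the mouth landmarks.

    landmarks: (68, 2) array of (x, y) pixel coordinates on the aligned image.
    pad: optional margin in pixels around the box (hypothetical; the notebook uses none).
    """
    mouth = landmarks[MOUTH_IDX]
    x0 = max(0, int(np.floor(mouth[:, 0].min())) - pad)
    x1 = min(size, int(np.ceil(mouth[:, 0].max())) + pad)
    y0 = max(0, int(np.floor(mouth[:, 1].min())) - pad)
    y1 = min(size, int(np.ceil(mouth[:, 1].max())) + pad)
    mask = np.zeros((size, size), dtype=np.float32)
    if x0 < x1 and y0 < y1:
        mask[y0:y1, x0:x1] = 1.0  # rows index y, columns index x
    return mask
```

To feed such a mask into the notebook's losses, `torch.from_numpy(mask)[None, None].to(device)` would give the `(1, 1, 256, 256)` layout that Block 5 multiplies into the images.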

In [None]:
preprocess_transform_for_inversion = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # e4e expects encoder inputs in [-1, 1]
])

preprocess_transform_for_pti_targets = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

preprocess_transform_1024_for_pti_targets = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor()
])

preprocess_transform_for_id_loss = transforms.Compose([
    transforms.Resize((112, 112)),
    transforms.ToTensor(),
    #transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

def create_mouth_mask_from_aligned_image(aligned_pil_image, fa_model, device_to_use):
    if fa_model is None:
        print("Face alignment model not available for mouth mask. Using a zero mask.")
        return torch.zeros((1, 1, 256, 256), device=device_to_use)

    aligned_img_np = np.array(aligned_pil_image)
    landmarks_on_aligned_list = fa_model.get_landmarks(aligned_img_np)

    if not landmarks_on_aligned_list:
        print("Could not detect landmarks on the aligned image for mouth mask. Using a zero mask.")
        return torch.zeros((1, 1, 256, 256), device=device_to_use)

    landmarks_on_aligned = landmarks_on_aligned_list[0]
    mouth_lm_indices = list(range(48, 68))
    mouth_landmarks_points = landmarks_on_aligned[mouth_lm_indices]

    min_x = int(np.min(mouth_landmarks_points[:, 0]))
    max_x = int(np.max(mouth_landmarks_points[:, 0]))
    min_y = int(np.min(mouth_landmarks_points[:, 1]))
    max_y = int(np.max(mouth_landmarks_points[:, 1]))

    current_mouth_mask = torch.zeros((1, 1, 256, 256), device=device_to_use)
    min_y_clamped, max_y_clamped = max(0, min_y), min(255, max_y)
    min_x_clamped, max_x_clamped = max(0, min_x), min(255, max_x)

    if min_y_clamped < max_y_clamped and min_x_clamped < max_x_clamped:
        current_mouth_mask[:, :, min_y_clamped:max_y_clamped, min_x_clamped:max_x_clamped] = 1.0
        print("Mouth mask created for current image.")
    else:
        print("Warning: Mouth landmark coordinates invalid for mask. Using zero mask.")

    return current_mouth_mask


## Block 4 — Global Model Loading (e4e, LPIPS, Smile Direction, ID Loss)

Block 4 loads, once, the components shared by every image. It imports the pSp model class and the e4e `IDLoss` class, then reads the e4e checkpoint (first with `weights_only=True`, falling back to a full unpickle if that fails) to recover the training options (`opts`) from which every later `pSp(...)` instance is built. It also loads the smile direction vector, the latent‑space edit that Block 6 applies to create the smiling expression.

Next, it prepares the perceptual loss function (LPIPS) using the VGG backbone, ensuring it is ready for evaluating visual similarity during fine-tuning. Finally, the block attempts to load and initialize the identity loss model (ArcFace IR-SE-50), which plays a crucial role during PTI optimization by ensuring that edits preserve the subject’s identity. Careful handling of the current working directory (CWD) is included to ensure the IDLoss class finds its associated weight files.
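Block 4 switches into `E4E_BASE_PATH` by hand and restores the previous directory in a `finally` clause. The same pattern can be packaged as a reusable context manager. A minimal sketch (the name `working_directory` is hypothetical; on Python 3.11+ the standard library's `contextlib.chdir` does the same job):

```python
import contextlib
import os

@contextlib.contextmanager
def working_directory(path: str):
    """Temporarily switch the process CWD to `path`, restoring it even if an error occurs."""
    previous = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        os.chdir(previous)

# Hypothetical usage mirroring Block 4:
# with working_directory(E4E_BASE_PATH):
#     id_loss_fn_global = IDLoss().to(device).eval()
```

Keeping the `chdir` inside a context manager guarantees that a failing `IDLoss()` constructor cannot leave the notebook stranded in the wrong directory for later cells.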

In [None]:
class Namespace:
    def __init__(self, dictionary):
        self.__dict__.update(dictionary)

try:
    from models.psp import pSp
    from criteria.id_loss import IDLoss
except ModuleNotFoundError as e:
    print(f"Import Error: {e}. Trying to add e4e path...")
    e4e_repo_path_check = os.path.join(BASE_PROJECT_PATH, "e4e")
    if os.path.isdir(e4e_repo_path_check) and e4e_repo_path_check not in sys.path:
        sys.path.append(e4e_repo_path_check)
        print(f"Added {e4e_repo_path_check} to sys.path for model imports.")
        from models.psp import pSp
        from criteria.id_loss import IDLoss
    else:
        print("CRITICAL: Could not find pSp or IDLoss model class. Ensure 'e4e' directory is available.")
        raise

if not os.path.exists(e4e_checkpoint_path):
    raise FileNotFoundError(f"CRITICAL: e4e checkpoint not found at {e4e_checkpoint_path}.")

try:
    base_ckpt = torch.load(e4e_checkpoint_path, map_location='cpu', weights_only=True)
    print("Base e4e checkpoint loaded (weights_only=True).")
except Exception as e_load_base:
    print(f"Warning: Could not load base e4e checkpoint with weights_only=True ({e_load_base}). Attempting without...")
    base_ckpt = torch.load(e4e_checkpoint_path, map_location='cpu')
    print("Base e4e checkpoint loaded (full pickle).")

base_e4e_opts = Namespace(base_ckpt['opts'])
base_e4e_opts.checkpoint_path = e4e_checkpoint_path
base_e4e_opts.device = device

if not os.path.exists(smile_direction_path):
    raise FileNotFoundError(f"CRITICAL: Smile direction file not found at {smile_direction_path}")
smile_direction_global = torch.load(smile_direction_path, map_location=device)
print(f"Global smile direction loaded with shape: {smile_direction_global.shape}")

perceptual_loss_fn_global = lpips.LPIPS(net='vgg').to(device).eval()
print("Global LPIPS model loaded.")

id_loss_fn_global = None
if not os.path.exists(id_loss_checkpoint_path_absolute):
    print(f"WARNING: Identity Loss checkpoint not found at {id_loss_checkpoint_path_absolute}. PTI will run without ID Loss.")
else:
    original_cwd = os.getcwd()
    if os.path.isdir(E4E_BASE_PATH):
        print(f"Temporarily changing CWD to {E4E_BASE_PATH} for IDLoss initialization.")
        os.chdir(E4E_BASE_PATH)
    else:
        print(f"Warning: E4E_BASE_PATH ({E4E_BASE_PATH}) not found. IDLoss might fail if it uses relative paths.")

    try:
        print("Attempting to instantiate IDLoss()...")
        id_loss_fn_global = IDLoss()
        id_loss_fn_global.to(device)
        id_loss_fn_global.eval()
        print(f"Global Identity Loss model instantiated and moved to device. CWD during init: {os.getcwd()}")
    except Exception as e_id_loss:
        print(f"Error initializing or loading Identity Loss model: {e_id_loss}")
        print("   Ensure 'model_ir_se50.pth' is in 'e4e/pretrained_models/'. PTI will run without ID Loss.")
        id_loss_fn_global = None
    finally:
        if os.getcwd() != original_cwd:
            os.chdir(original_cwd)
            print(f"Restored CWD to: {original_cwd}")

## Block 5 — Alignment, Inversion, and PTI (Optimization Loop)

Block 5 is the technical heart of the pipeline, where the system processes each input image through three major stages: face alignment, latent inversion, and PTI (Pivotal Tuning Inversion) optimization.

First, it loops through all images in the input folder, detects and aligns each face, and saves the aligned result. Next, it performs inversion using the e4e encoder, which generates an initial latent code representing the face inside the StyleGAN latent space. This latent code is then passed into Stage 0: latent optimization, where the code itself is fine-tuned (without changing the generator) using a combination of pixel-level (L2) and perceptual (LPIPS) losses.
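Stage 0 updates only the latent while the generator stays frozen. The toy NumPy sketch below shows that structure with a fixed random linear map standing in for the generator and plain gradient descent on an L2 loss; it is not the notebook's setup (no LPIPS, no AdamW, no StyleGAN2), and every name and size in it is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
latent_dim, image_dim = 8, 32

# Frozen "generator": a fixed linear map from latent to flattened image (stand-in for StyleGAN2)
G = rng.normal(size=(image_dim, latent_dim)) / np.sqrt(latent_dim)

def generate(w: np.ndarray) -> np.ndarray:
    return G @ w

# Target image that the generator can reproduce exactly with some unknown latent
w_true = rng.normal(size=latent_dim)
target = generate(w_true)

# "Encoder" output: a perturbed initial guess, analogous to the e4e inversion
w = w_true + 0.5 * rng.normal(size=latent_dim)
initial_loss = np.mean((generate(w) - target) ** 2)

lr = 1.0
for step in range(500):
    residual = generate(w) - target
    # Gradient of the mean squared error with respect to the latent only; G is never updated
    grad = 2.0 * G.T @ residual / image_dim
    w -= lr * grad

final_loss = np.mean((generate(w) - target) ** 2)
print(f"L2 loss: {initial_loss:.4f} -> {final_loss:.2e}")
```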

After that, Stage 1 begins: generator fine-tuning, where the generator's weights are adjusted while the optimized latent stays fixed. This phase uses a multi-scale loss that combines L1 and LPIPS at 256×256 and 1024×1024, an extra mouth-region L1 term (when a mouth mask was created), and an identity-preservation term (when the ArcFace ID loss is available). Tuning runs for up to 1,800 steps and stops early once the loss has not improved for 20 consecutive steps. For each image, the block then saves the optimized latent and the tuned generator weights to `optimized_data/` and records their paths for Block 6.
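Written out, the Stage 1 objective minimized by the loop below is the following, with the weights hard-coded in the loop. Here $\ell_1^{\text{mouth}}$ is the L1 distance between the mask-multiplied 256×256 images (averaged over all pixels, not only the mouth), and the mouth and identity terms are zero when no mask or no ID model is available.

```latex
\mathcal{L}_{\text{PTI}} =
  \big(w_{\ell_1}\,\ell_1^{256} + w_{\text{lpips}}\,\mathrm{LPIPS}^{256}\big)
  + 0.5\,\big(w_{\ell_1}\,\ell_1^{1024} + w_{\text{lpips}}\,\mathrm{LPIPS}^{1024}\big)
  + w_{\text{mouth}}\,\ell_1^{\text{mouth}}
  + w_{\text{id}}\,\mathcal{L}_{\text{id}},
\qquad
w_{\ell_1} = 1.0,\; w_{\text{lpips}} = 1.3,\; w_{\text{mouth}} = 2.0,\; w_{\text{id}} = 0.1
```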

In [None]:
optimization_results = []

image_extensions = ('*.jpg', '*.jpeg', '*.png')
image_files = []
for ext in image_extensions:
    image_files.extend(glob.glob(os.path.join(INPUT_FOLDER_PATH, ext)))

if not image_files:
    print(f"No images found in {INPUT_FOLDER_PATH}. Please add images and try again.")
else:
    print(f"Found {len(image_files)} images for optimization.")

for current_input_image_path in image_files:
    current_image_filename = os.path.basename(current_input_image_path)
    current_filename_no_ext = os.path.splitext(current_image_filename)[0]
    print(f"\n--- Optimizing: {current_image_filename} ---")

    current_output_prefix_optimized = os.path.join(OPTIMIZED_DATA_FOLDER, current_filename_no_ext)
    optimized_latent_path = f"{current_output_prefix_optimized}_optimized_latent.pt"
    optimized_generator_path = f"{current_output_prefix_optimized}_optimized_generator.pt"
    aligned_image_for_opt_path = f"{current_output_prefix_optimized}_aligned_for_opt.jpg"

    try:
        align_verif_prefix = os.path.join(OUTPUT_FOLDER_PATH, current_filename_no_ext)
        aligned_image_pil, original_landmarks = align_face(current_input_image_path, fa, align_verif_prefix)
        aligned_image_pil.save(aligned_image_for_opt_path)
        print(f"Aligned image for optimization saved to: {aligned_image_for_opt_path}")

        input_tensor_for_inversion = preprocess_transform_for_inversion(aligned_image_pil).unsqueeze(0).to(device)
        mouth_mask_tensor = create_mouth_mask_from_aligned_image(aligned_image_pil, fa, device)

        inversion_e4e_model = pSp(base_e4e_opts)
        inversion_e4e_model.eval()
        inversion_e4e_model.to(device)

        with torch.no_grad():
            _, initial_latent_code = inversion_e4e_model(input_tensor_for_inversion, randomize_noise=False, return_latents=True)

        original_generator = inversion_e4e_model.decoder
        original_generator.eval()

        print("Starting PTI Stage 0: Latent Optimization...")
        optimized_latent_code = initial_latent_code.clone().detach().requires_grad_(True)
        optimizer_latent = optim.AdamW([optimized_latent_code], lr=0.01)
        num_steps_latent_opt = 10

        original_tensor_256_for_pti_0_1 = preprocess_transform_for_pti_targets(aligned_image_pil).unsqueeze(0).to(device)

        for step in range(num_steps_latent_opt):
            optimizer_latent.zero_grad()
            generated_output_latent_opt = original_generator([optimized_latent_code], input_is_latent=True, randomize_noise=False)
            generated_images_1024_raw_latent_opt = generated_output_latent_opt[0] if isinstance(generated_output_latent_opt, tuple) else generated_output_latent_opt
            generated_1024_latent_opt_0_1 = (generated_images_1024_raw_latent_opt.clamp(-1, 1) + 1) / 2.0
            generated_256_latent_opt_0_1 = F.interpolate(generated_1024_latent_opt_0_1, size=(256, 256), mode='bilinear', align_corners=False)

            l2_loss_latent_opt = F.mse_loss(generated_256_latent_opt_0_1, original_tensor_256_for_pti_0_1)
            lpips_input_gen_latent_opt = generated_256_latent_opt_0_1 * 2 - 1
            lpips_input_target_latent_opt = original_tensor_256_for_pti_0_1 * 2 - 1
            lpips_loss_latent_opt = perceptual_loss_fn_global(lpips_input_gen_latent_opt, lpips_input_target_latent_opt).mean()

            total_loss_latent_opt = l2_loss_latent_opt + lpips_loss_latent_opt
            total_loss_latent_opt.backward()
            optimizer_latent.step()

            if step % 100 == 0 or step == num_steps_latent_opt - 1:
                print(f"Latent Opt Step {step}/{num_steps_latent_opt} - Loss: {total_loss_latent_opt.item():.4f}")

        print("Latent Optimization completed.")
        current_latent_code = optimized_latent_code.detach()

        del original_generator, inversion_e4e_model
        if device == 'cuda':
            torch.cuda.empty_cache()

        print("Starting PTI Stage 1: Generator Tuning...")
        pti_model_instance = pSp(base_e4e_opts)
        pti_generator = pti_model_instance.decoder
        pti_generator.train()
        pti_generator.to(device)

        original_tensor_1024_for_pti_0_1 = preprocess_transform_1024_for_pti_targets(aligned_image_pil).unsqueeze(0).to(device)

        if id_loss_fn_global:
            target_image_for_id = preprocess_transform_for_id_loss(aligned_image_pil).unsqueeze(0).to(device)

        optimizer_pti = optim.AdamW(pti_generator.parameters(), lr=0.002, weight_decay=0.02)
        num_steps_pti_val = 1800

        prev_loss = float('inf')
        no_improvement = 0
        patience = 20
        min_delta = 1e-4

        print(f"Starting PTI Generator Tuning (targets in [0,1], lr={optimizer_pti.defaults['lr']}, steps={num_steps_pti_val})...")
        for step in range(num_steps_pti_val):
            optimizer_pti.zero_grad()
            generated_output_pti_loop = pti_generator([current_latent_code], input_is_latent=True, randomize_noise=False)
            generated_images_1024_raw_loop = generated_output_pti_loop[0] if isinstance(generated_output_pti_loop, tuple) else generated_output_pti_loop
            generated_1024_for_loss_0_1 = (generated_images_1024_raw_loop.clamp(-1, 1) + 1) / 2.0
            generated_256_for_loss_0_1 = F.interpolate(generated_1024_for_loss_0_1, size=(256, 256), mode='bilinear', align_corners=False)

            l1_256 = F.l1_loss(generated_256_for_loss_0_1, original_tensor_256_for_pti_0_1)
            l1_1024 = F.l1_loss(generated_1024_for_loss_0_1, original_tensor_1024_for_pti_0_1)

            lpips_input_gen_256 = generated_256_for_loss_0_1 * 2 - 1
            lpips_input_target_256 = original_tensor_256_for_pti_0_1 * 2 - 1
            lpips_input_gen_1024 = generated_1024_for_loss_0_1 * 2 - 1
            lpips_input_target_1024 = original_tensor_1024_for_pti_0_1 * 2 - 1
            lpips_256 = perceptual_loss_fn_global(lpips_input_gen_256, lpips_input_target_256).mean()
            lpips_1024 = perceptual_loss_fn_global(lpips_input_gen_1024, lpips_input_target_1024).mean()

            l1_mouth_val = torch.tensor(0.0).to(device)
            if mouth_mask_tensor.sum() > 0:
                l1_mouth_val = F.l1_loss(generated_256_for_loss_0_1 * mouth_mask_tensor, original_tensor_256_for_pti_0_1 * mouth_mask_tensor)

            id_loss_val = torch.tensor(0.0).to(device)
            if id_loss_fn_global:
                # e4e's IDLoss crops and pools 256x256 inputs in [-1, 1] internally,
                # so pass both images at 256x256 in that range instead of pre-resizing to 112
                generated_for_id = generated_256_for_loss_0_1 * 2 - 1
                target_for_id = original_tensor_256_for_pti_0_1 * 2 - 1
                id_loss_output = id_loss_fn_global(generated_for_id, target_for_id, target_for_id)
                id_loss_val = id_loss_output[0].mean() if isinstance(id_loss_output, tuple) else id_loss_output.mean()

            w_l1 = 1.0
            w_lpips = 1.3
            w_l1_mouth = 2.0
            id_loss_weight = 0.1

            total_loss_pti_loop = (w_l1 * l1_256 + w_lpips * lpips_256) + \
                                  0.5 * (w_l1 * l1_1024 + w_lpips * lpips_1024) + \
                                  w_l1_mouth * l1_mouth_val + \
                                  id_loss_weight * id_loss_val

            total_loss_pti_loop.backward()
            optimizer_pti.step()

            if step % 100 == 0 or step == num_steps_pti_val - 1:
                id_loss_print = f" | ID: {id_loss_val.item():.4f}" if id_loss_fn_global and isinstance(id_loss_val, torch.Tensor) and id_loss_val.numel() > 0 else ""
                print(f"PTI Gen Step {step}/{num_steps_pti_val} - Loss: {total_loss_pti_loop.item():.4f}{id_loss_print}")

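            # prev_loss tracks the best loss so far; stop after `patience` steps without beating it by min_delta
            current_loss_value = total_loss_pti_loop.item()
            if current_loss_value < prev_loss - min_delta:
                prev_loss = current_loss_value
                no_improvement = 0
            else:
                no_improvement += 1
                if no_improvement >= patience:
                    print(f"Early stopping at step {step} due to no improvement.")
                    break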

        print("PTI Generator Tuning completed.")

        torch.save(current_latent_code, optimized_latent_path)
        torch.save(pti_generator.state_dict(), optimized_generator_path)

        optimization_results.append({
            "original_image_path": current_input_image_path,
            "original_landmarks": original_landmarks,
            "optimized_latent_path": optimized_latent_path,
            "optimized_generator_path": optimized_generator_path,
            "output_prefix_final": os.path.join(OUTPUT_FOLDER_PATH, current_filename_no_ext)
        })
        print(f"Optimization outputs saved for {current_image_filename} to {OPTIMIZED_DATA_FOLDER}")

        del pti_model_instance, pti_generator
        if device == 'cuda':
            torch.cuda.empty_cache()

    except Exception as e_img_opt:
        print(f"Error optimizing {current_image_filename}: {e_img_opt}")
        import traceback
        traceback.print_exc()

print("\n--- Optimization stage (Alignment, Inversion, PTI) completed. ---")

## Block 6 — Smile Application, Blending, and Saving

Block 6 is the final execution stage where the system takes each optimized image (with its tuned generator and latent code) and applies the smile edit. It systematically loads the optimized latent and generator, injects the smile direction vector at the specified layers (typically mid-level layers controlling expression), and generates the smiling face image.
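The edit itself is vector arithmetic on the W⁺ code. A NumPy sketch of adding a scaled direction to an inclusive range of layers; random arrays stand in for the optimized latent and `smile.pt`, an 18×512 W⁺ code is assumed for the 1024 px generator, and both the helper name `apply_direction` and treating the end layer as inclusive are assumptions.

```python
import numpy as np

def apply_direction(latent: np.ndarray, direction: np.ndarray, alpha: float,
                    start_layer: int, end_layer: int) -> np.ndarray:
    """Return a copy of a W+ latent with alpha * direction added to layers start..end (inclusive).

    latent: (batch, num_layers, dim), e.g. (1, 18, 512) for a 1024 px StyleGAN2.
    direction: (dim,) shared by all edited layers, or (num_layers, dim) per layer.
    """
    edited = latent.copy()
    end = min(end_layer, latent.shape[1] - 1)  # clamp to the last available layer
    for layer in range(start_layer, end + 1):
        layer_dir = direction if direction.ndim == 1 else direction[layer]
        edited[:, layer, :] += alpha * layer_dir
    return edited

rng = np.random.default_rng(0)
w_plus = rng.normal(size=(1, 18, 512)).astype(np.float32)
smile_dir = rng.normal(size=512).astype(np.float32)  # random stand-in for smile.pt
edited = apply_direction(w_plus, smile_dir, alpha=1.2, start_layer=4, end_layer=6)
changed = np.where(np.any(edited != w_plus, axis=(0, 2)))[0]
print(changed)  # [4 5 6]
```

Restricting the edit to a few mid-level layers is what keeps coarse attributes (pose, face shape) and fine ones (color, texture) closer to the original while the expression changes.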

Next, it warps the generated smiling face back into the coordinate space of the original high-resolution photo, using the same eye-and-nose similarity transform as the alignment step, rescaled to the 1024×1024 output. A mask covering the warped crop is built (and optionally blurred through `sigma_blur_b`, which is 0 by default, i.e. a hard edge), and the smiling face is alpha-blended into the original image so that pixels outside the crop keep the original background. Before writing, the block deletes any earlier outputs with the same names; it then saves both the direct smiling crop and the final blended image and logs progress for each file.
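The compositing step is per-pixel alpha blending, `out = orig * (1 - m) + face * m`, where `m` is the warped mask. The sketch below shows one alternative way to soften the seam: a rectangular mask whose value ramps linearly from 0 to 1 over a hypothetical `feather` margin. It uses synthetic arrays and an axis-aligned rectangle; the notebook's warped crop is generally rotated, so this illustrates feathered blending rather than serving as a drop-in replacement.

```python
import numpy as np

def feathered_rect_mask(height, width, top, left, bottom, right, feather):
    """Mask that is 1 deep inside [top:bottom, left:right], 0 outside, with a linear ramp of `feather` px."""
    ys = np.arange(height)[:, None] + 0.5  # pixel centers
    xs = np.arange(width)[None, :] + 0.5
    # Distance from each pixel center to the nearest rectangle edge (negative outside)
    dist = np.minimum(np.minimum(ys - top, bottom - ys), np.minimum(xs - left, right - xs))
    return np.clip(dist / feather, 0.0, 1.0)

def alpha_blend(background, foreground, mask):
    """Per-pixel blend; images are (H, W, 3) in [0, 1], mask is (H, W) in [0, 1]."""
    m = mask[..., None]
    return np.clip(background * (1.0 - m) + foreground * m, 0.0, 1.0)

orig = np.zeros((64, 64, 3))  # stand-in for the original photo
face = np.ones((64, 64, 3))   # stand-in for the warped smiling crop
mask = feathered_rect_mask(64, 64, top=16, left=16, bottom=48, right=48, feather=8)
out = alpha_blend(orig, face, mask)
print(out[32, 32, 0], out[0, 0, 0], out[32, 20, 0])  # 1.0 0.0 0.5625
```

Ramping the mask inward, rather than blurring it outward, keeps the blend from pulling in pixels that lie outside the generated crop.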

In [None]:
SMILE_ALPHA = 1.2
SMILE_START_LAYER = 4
SMILE_END_LAYER = 6

if not optimization_results:
    print("No images were successfully optimized. Skipping smile application.")
else:
    print(f"\n--- Starting Smile Application and Blending for {len(optimization_results)} images ---")

    if 'base_e4e_opts' not in globals():
        raise RuntimeError("CRITICAL: base_e4e_opts not defined. Cannot reconstruct generator for smile application.")

    for image_data in optimization_results:
        current_input_image_path = image_data["original_image_path"]
        original_landmarks = image_data["original_landmarks"]
        optimized_latent_path = image_data["optimized_latent_path"]
        optimized_generator_path = image_data["optimized_generator_path"]
        current_output_prefix_final = image_data["output_prefix_final"]

        current_image_filename = os.path.basename(current_input_image_path)
        print(f"\n--- Applying smile to: {current_image_filename} (Alpha: {SMILE_ALPHA}) ---")

        final_smiling_crop_path = f"{current_output_prefix_final}_smiling_crop_alpha{SMILE_ALPHA}.jpg"
        final_blended_smile_path = f"{current_output_prefix_final}_smiling_final_alpha{SMILE_ALPHA}.jpg"

        for file_to_remove in [final_smiling_crop_path, final_blended_smile_path]:
            if os.path.exists(file_to_remove):
                try:
                    os.remove(file_to_remove)
                    print(f"Removed previous version: {file_to_remove}")
                except OSError as e_remove:
                    print(f"Error removing {file_to_remove}: {e_remove}")

        try:
            optimized_latent = torch.load(optimized_latent_path, map_location=device)

            smile_model_instance = pSp(base_e4e_opts)
            smile_generator = smile_model_instance.decoder
            smile_generator.load_state_dict(torch.load(optimized_generator_path, map_location=device))
            smile_generator.to(device)
            smile_generator.eval()

            edited_latent = optimized_latent.clone().detach()

            # Add SMILE_ALPHA * direction to every W+ layer from SMILE_START_LAYER to SMILE_END_LAYER (inclusive)
            num_latent_layers = edited_latent.shape[1]
            for layer_to_edit in range(SMILE_START_LAYER, SMILE_END_LAYER + 1):
                if layer_to_edit >= num_latent_layers:
                    print(f"Warning: layer {layer_to_edit} is out of bounds for a latent with {num_latent_layers} layers.")
                    break
                if smile_direction_global.ndim == 1 and smile_direction_global.shape[0] == edited_latent.shape[2]:
                    # One direction shared by all edited layers
                    edited_latent[:, layer_to_edit, :] += SMILE_ALPHA * smile_direction_global
                elif smile_direction_global.ndim == 2 and smile_direction_global.shape[0] == num_latent_layers and smile_direction_global.shape[1] == edited_latent.shape[2]:
                    # Per-layer W+ direction: use the row for this layer
                    edited_latent[:, layer_to_edit, :] += SMILE_ALPHA * smile_direction_global[layer_to_edit, :]
                elif smile_direction_global.ndim == 2 and smile_direction_global.shape[0] == 1 and smile_direction_global.shape[1] == edited_latent.shape[2]:
                    # (1, dim) direction: drop the leading axis
                    edited_latent[:, layer_to_edit, :] += SMILE_ALPHA * smile_direction_global.squeeze(0)
                else:
                    print(f"Warning: Smile direction shape {smile_direction_global.shape} not directly applicable to layer {layer_to_edit}.")
                    break

            with torch.no_grad():
                edited_result_gen_output = smile_generator([edited_latent], input_is_latent=True, randomize_noise=False)
                edited_image_tensor_raw = edited_result_gen_output[0] if isinstance(edited_result_gen_output, tuple) else edited_result_gen_output

            final_image_norm_0_1 = (edited_image_tensor_raw.clamp(-1, 1) + 1) / 2.0
            save_image(final_image_norm_0_1, final_smiling_crop_path)
            print(f"Smiling crop (alpha={SMILE_ALPHA}) saved to: {final_smiling_crop_path}")

            from scipy.ndimage import gaussian_filter

            left_eye_orig_b = original_landmarks[36:42].mean(axis=0)
            right_eye_orig_b = original_landmarks[42:48].mean(axis=0)
            nose_orig_b = original_landmarks[30]
            src_pts_orig_b = np.array([left_eye_orig_b, right_eye_orig_b, nose_orig_b], dtype=np.float32)

            scale_1024_from_256_b = 1024 / 256
            dst_pts_1024_b = np.array([
                (96 * scale_1024_from_256_b, 112 * scale_1024_from_256_b),
                (160 * scale_1024_from_256_b, 112 * scale_1024_from_256_b),
                (128 * scale_1024_from_256_b, 160 * scale_1024_from_256_b)
            ], dtype=np.float32)

            tform_original_to_crop_b = trans.SimilarityTransform()
            tform_original_to_crop_b.estimate(src_pts_orig_b, dst_pts_1024_b)

            original_full_pil_b = Image.open(current_input_image_path).convert('RGB')
            original_full_arr_0_1_b = np.array(original_full_pil_b) / 255.0

            smiling_crop_np_0_1_b = final_image_norm_0_1.squeeze(0).permute(1, 2, 0).cpu().numpy()

            warped_smiling_face_b = trans.warp(
                smiling_crop_np_0_1_b,
                tform_original_to_crop_b,
                output_shape=original_full_arr_0_1_b.shape[:2],
                mode='reflect',
                preserve_range=True
            )

            mask_crop_space_b = np.ones((1024, 1024), dtype=np.float32)
            warped_mask_b = trans.warp(
                mask_crop_space_b,
                tform_original_to_crop_b,
                output_shape=original_full_arr_0_1_b.shape[:2],
                preserve_range=True
            )
            warped_mask_b = np.clip(warped_mask_b, 0, 1)

            sigma_blur_b = 0  # std. dev. (px) of the Gaussian that feathers the mask edge; 0 leaves a hard edge
            blurred_mask_b = gaussian_filter(warped_mask_b, sigma=sigma_blur_b)
            blurred_mask_b = np.clip(blurred_mask_b, 0, 1)[..., None]

            blended_arr_0_1_b_final = original_full_arr_0_1_b * (1 - blurred_mask_b) + \
                                      warped_smiling_face_b * blurred_mask_b
            blended_arr_0_1_b_final = np.clip(blended_arr_0_1_b_final, 0, 1)

            blended_img_uint8_b_final = (blended_arr_0_1_b_final * 255).astype(np.uint8)
            final_blended_pil_to_save = Image.fromarray(blended_img_uint8_b_final)
            final_blended_pil_to_save.save(final_blended_smile_path)
            print(f"Final blended image (alpha={SMILE_ALPHA}) saved to: {final_blended_smile_path}")

            processed_image_pairs.append((current_input_image_path, final_blended_smile_path))

            del smile_model_instance, smile_generator
            if device == 'cuda':
                torch.cuda.empty_cache()

        except Exception as e_smile_apply:
            print(f"Error applying smile to {current_image_filename}: {e_smile_apply}")
            import traceback
            traceback.print_exc()

print("\n--- Smile application and blending stage completed. ---")


# The End!