In [None]:
print("=== Installing Enhanced Dependencies ===")

!pip install torch torchvision torchaudio --quiet
!pip install ninja --quiet
!pip install ftfy regex tqdm --quiet
!pip install git+https://github.com/openai/CLIP.git --quiet
!pip install lpips --quiet
!pip install scikit-learn --quiet
!pip install face-alignment --quiet

print("✓ Core packages installed")

import os
os.makedirs('/content/models', exist_ok=True)

repos = {
    'stylegan2-ada-pytorch': 'https://github.com/NVlabs/stylegan2-ada-pytorch.git',
    'encoder4editing': 'https://github.com/omertov/encoder4editing.git',
    'SAM': 'https://github.com/yuval-alaluf/SAM.git',
    'eg3d': 'https://github.com/NVlabs/eg3d.git',
    'interfacegan': 'https://github.com/genforce/interfacegan.git'
}

for name, url in repos.items():
    if not os.path.exists(f'/content/{name}'):
        print(f"Cloning {name}...")
        !git clone {url} /content/{name} --quiet

print("✓ All repositories cloned")

print("\n=== Downloading Pre-trained Models ===")

import gdown
from pathlib import Path
import urllib.request

MODEL_DIR = Path('/content/models')
MODEL_DIR.mkdir(exist_ok=True)

def download_file(url, output_path, description):
    if not Path(output_path).exists():
        print(f"Downloading {description}...")
        try:
            if 'drive.google.com' in url:
                gdown.download(url, str(output_path), quiet=False, fuzzy=True)
            else:
                urllib.request.urlretrieve(url, output_path)
            print(f"✓ {description} downloaded")
        except Exception as e:
            print(f"⚠ Error downloading {description}: {e}")
    else:
        print(f"✓ {description} already exists")

stylegan_path = MODEL_DIR / 'stylegan2-ffhq-config-f.pkl'
download_file(
    'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl',
    stylegan_path,
    'StyleGAN2 FFHQ'
)

e4e_path = MODEL_DIR / 'e4e_ffhq_encode.pt'
download_file(
    'https://drive.google.com/uc?id=1cUv_reLE6k3604or78EranS7XzuVMWeO',
    e4e_path,
    'e4e Encoder'
)

sam_path = MODEL_DIR / 'sam_ffhq_aging.pt'
download_file(
    'https://drive.google.com/uc?id=1XyumF6_MBlCAkEmPiXp6hRfA-ZWLdeBs',
    sam_path,
    'SAM Age Model'
)

eg3d_path = MODEL_DIR / 'ffhq512-128.pkl'
download_file(
    'https://api.ngc.nvidia.com/v2/models/nvidia/research/eg3d/versions/1/files/ffhq512-128.pkl',
    eg3d_path,
    'EG3D FFHQ Model'
)

boundaries_dir = MODEL_DIR / 'interfacegan_boundaries'
boundaries_dir.mkdir(exist_ok=True)

boundaries = {
    'age': 'https://drive.google.com/uc?id=1FJRwzAkV-XWbxFeKIGtZn_S9Cs2VCpz7',
    'smile': 'https://drive.google.com/uc?id=1S8gXBfp0f0JGJpfWqJlQnNJ5Xvvwsn6N',
    'pose': 'https://drive.google.com/uc?id=1o_Y2dKlPyP8Xzrq6d9oM3uIRIzfLLrC7',
}

print("⚠ Note: You may need to compute InterfaceGAN boundaries for your specific use case")

print("\n✓ Model download complete")

import sys
sys.path.extend([
    '/content/stylegan2-ada-pytorch',
    '/content/encoder4editing',
    '/content/SAM',
    '/content/eg3d',
    '/content/interfacegan'
])

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms
import pickle
import clip
import lpips
from typing import List, Tuple, Optional
import dnnlib
import legacy

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\n✓ Using device: {device}")
if torch.cuda.is_available():
    print(f"  GPU: {torch.cuda.get_device_name(0)}")
    print(f"  Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

print("\n=== Loading Models with Memory Management ===")

class ModelManager:
    def __init__(self):
        self.models = {}
        self.device = device

    def load_stylegan2(self):
        if 'stylegan2' not in self.models:
            print("Loading StyleGAN2...")
            with open(MODEL_DIR / 'stylegan2-ffhq-config-f.pkl', 'rb') as f:
                self.models['stylegan2'] = pickle.load(f)['G_ema'].to(self.device).eval()
            print("✓ StyleGAN2 loaded")
        return self.models['stylegan2']

    def load_e4e(self):
        if 'e4e' not in self.models:
            print("Loading e4e encoder...")
            sys.path.append('/content/encoder4editing')
            from models.psp import pSp

            ckpt = torch.load(MODEL_DIR / 'e4e_ffhq_encode.pt', map_location='cpu')
            opts = ckpt['opts']
            opts['checkpoint_path'] = str(MODEL_DIR / 'e4e_ffhq_encode.pt')
            opts['device'] = str(self.device)

            self.models['e4e'] = pSp(opts).to(self.device).eval()
            print("✓ e4e encoder loaded")
        return self.models['e4e']

    def load_sam(self):
        if 'sam' not in self.models:
            print("Loading SAM model...")
            try:
                sys.path.append('/content/SAM')
                from models.psp import pSp as SAM_pSp

                ckpt = torch.load(MODEL_DIR / 'sam_ffhq_aging.pt', map_location='cpu')
                opts = ckpt['opts']
                opts['checkpoint_path'] = str(MODEL_DIR / 'sam_ffhq_aging.pt')
                opts['device'] = str(self.device)

                self.models['sam'] = SAM_pSp(opts).to(self.device).eval()
                print("✓ SAM model loaded")
            except Exception as e:
                print(f"⚠ SAM model loading failed: {e}")
                print("  Falling back to direction-based age manipulation")
                self.models['sam'] = None
        return self.models['sam']

    def load_eg3d(self):
        if 'eg3d' not in self.models:
            print("Loading EG3D...")
            try:
                with dnnlib.util.open_url(str(MODEL_DIR / 'ffhq512-128.pkl')) as f:
                    self.models['eg3d'] = legacy.load_network_pkl(f)['G_ema'].to(self.device).eval()
                print("✓ EG3D loaded")
            except Exception as e:
                print(f"⚠ EG3D loading failed: {e}")
                print("  Using StyleGAN2 with limited pose control")
                self.models['eg3d'] = None
        return self.models['eg3d']

    def load_clip(self):
        if 'clip' not in self.models:
            print("Loading CLIP...")
            model, preprocess = clip.load("ViT-B/32", device=self.device)
            self.models['clip'] = model
            self.models['clip_preprocess'] = preprocess
            print("✓ CLIP loaded")
        return self.models['clip'], self.models['clip_preprocess']

    def load_lpips(self):
        if 'lpips' not in self.models:
            print("Loading LPIPS...")
            self.models['lpips'] = lpips.LPIPS(net='vgg').to(self.device)
            print("✓ LPIPS loaded")
        return self.models['lpips']

    def unload_model(self, model_name):
        if model_name in self.models:
            del self.models[model_name]
            torch.cuda.empty_cache()
            print(f"✓ {model_name} unloaded")

    def clear_all(self):
        self.models.clear()
        torch.cuda.empty_cache()
        print("✓ All models cleared")

model_manager = ModelManager()

=== Installing Enhanced Dependencies ===
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m180.7/180.7 kB[0m [31m14.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for clip (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.8/53.8 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[?25h✓ Core packages installed
Cloning stylegan2-ada-pytorch...
Cloning encoder4editing...
Cloning SAM...
Cloning eg3d...
Cloning interfacegan...
✓ All repositories cloned

=== Downloading Pre-trained Models ===
Downloading StyleGAN2 FFHQ...
✓ StyleGAN2 FFHQ downloaded
Downloading e4e Encoder...


Downloading...
From (original): https://drive.google.com/uc?id=1cUv_reLE6k3604or78EranS7XzuVMWeO
From (redirected): https://drive.google.com/uc?id=1cUv_reLE6k3604or78EranS7XzuVMWeO&confirm=t&uuid=5461c4ab-dbc7-4f1c-b8c8-fcf144a47515
To: /content/models/e4e_ffhq_encode.pt
100%|██████████| 1.20G/1.20G [00:27<00:00, 43.2MB/s]


✓ e4e Encoder downloaded
Downloading SAM Age Model...
⚠ Error downloading SAM Age Model: Failed to retrieve file url:

	Cannot retrieve the public link of the file. You may need to change
	the permission to 'Anyone with the link', or have had many accesses.
	Check FAQ in https://github.com/wkentaro/gdown?tab=readme-ov-file#faq.

You may still be able to access the file from the browser:

	https://drive.google.com/uc?id=1XyumF6_MBlCAkEmPiXp6hRfA-ZWLdeBs

but Gdown can't. Please check connections and permissions.
Downloading EG3D FFHQ Model...
✓ EG3D FFHQ Model downloaded
⚠ Note: You may need to compute InterfaceGAN boundaries for your specific use case

✓ Model download complete

✓ Using device: cuda
  GPU: Tesla T4
  Memory: 15.83 GB

=== Loading Models with Memory Management ===


In [None]:
def load_image(image_path, size=1024):
    transform = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
    img = Image.open(image_path).convert('RGB')
    return transform(img).unsqueeze(0)

def tensor_to_pil(tensor):
    tensor = (tensor.clamp(-1, 1) + 1) / 2.0
    if tensor.dim() == 4:
        tensor = tensor.squeeze(0)
    tensor = tensor.permute(1, 2, 0).cpu().numpy()
    return Image.fromarray((tensor * 255).astype(np.uint8))

def display_images(images, titles=None, figsize=(20, 5)):
    n = len(images)
    fig, axes = plt.subplots(1, n, figsize=figsize)
    if n == 1:
        axes = [axes]

    for i, (img, ax) in enumerate(zip(images, axes)):
        if isinstance(img, torch.Tensor):
            img = tensor_to_pil(img)
        ax.imshow(img)
        ax.axis('off')
        if titles:
            ax.set_title(titles[i], fontsize=10)
    plt.tight_layout()
    plt.show()

def save_image(tensor, path):
    img = tensor_to_pil(tensor)
    img.save(path)
    print(f"✓ Saved: {path}")

print("\n=== Setting up e4e Encoder ===")

class E4EInverter:
    def __init__(self, model_manager):
        self.manager = model_manager
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

    def invert(self, image_path):
        print(f"Inverting image: {image_path}")

        img = Image.open(image_path).convert('RGB')
        img_tensor = self.transform(img).unsqueeze(0).to(device)

        e4e = self.manager.load_e4e()

        with torch.no_grad():
            latent = e4e(img_tensor, resize=False, return_latents=True)

        print("✓ Image inverted successfully")
        return latent, img_tensor

    def invert_batch(self, image_paths: List[str]):
        latents = []
        for path in image_paths:
            latent, _ = self.invert(path)
            latents.append(latent)
        return torch.cat(latents, dim=0)

e4e_inverter = E4EInverter(model_manager)
print("✓ e4e inverter ready")

print("\n=== Setting up InterfaceGAN ===")

class InterfaceGANManipulator:
    def __init__(self, model_manager, boundaries_dir):
        self.manager = model_manager
        self.boundaries_dir = Path(boundaries_dir)
        self.boundaries = {}

    def compute_boundary(self, attribute_name, positive_latents, negative_latents):
        from sklearn import svm

        print(f"Computing boundary for: {attribute_name}")

        X = torch.cat([positive_latents, negative_latents], dim=0).cpu().numpy()
        y = np.concatenate([
            np.ones(len(positive_latents)),
            -np.ones(len(negative_latents))
        ])

        classifier = svm.SVC(kernel='linear')
        classifier.fit(X, y)

        boundary = classifier.coef_.reshape(1, -1).astype(np.float32)
        boundary = torch.from_numpy(boundary).to(device)

        self.boundaries[attribute_name] = boundary
        print(f"✓ Boundary computed for {attribute_name}")

        return boundary

    def load_boundary(self, attribute_name, path=None):
        if path is None:
            path = self.boundaries_dir / f"{attribute_name}.npy"

        if Path(path).exists():
            boundary = np.load(path)
            self.boundaries[attribute_name] = torch.from_numpy(boundary).to(device)
            print(f"✓ Loaded boundary: {attribute_name}")
        else:
            print(f"⚠ Boundary not found: {attribute_name}")
            print(f"  Use compute_boundary() to create it")

    def manipulate(self, latent, attribute_name, strength=3.0):
        if attribute_name not in self.boundaries:
            print(f"⚠ Boundary '{attribute_name}' not loaded")
            return latent

        boundary = self.boundaries[attribute_name]

        if latent.shape[1] == 18:
            boundary = boundary.unsqueeze(1).repeat(1, 18, 1)

        manipulated = latent + strength * boundary
        return manipulated

    def create_hair_boundaries_demo(self):
        print("\n=== Creating Demo Hair Boundaries ===")
        print("⚠ For production, use labeled dataset (e.g., CelebA-HQ with attributes)")

        G = self.manager.load_stylegan2()

        print("Generating sample latents...")
        n_samples = 100
        z = torch.randn(n_samples, G.z_dim, device=device)
        w = G.mapping(z, None)

        positive_idx = np.random.choice(n_samples, n_samples // 2, replace=False)
        negative_idx = np.setdiff1d(np.arange(n_samples), positive_idx)

        positive_latents = w[positive_idx].mean(dim=1)
        negative_latents = w[negative_idx].mean(dim=1)

        self.compute_boundary('hair_length_demo', positive_latents, negative_latents)

        print("✓ Demo boundary created")
        print("  Replace with actual labeled data for real boundaries")

interfacegan = InterfaceGANManipulator(model_manager, boundaries_dir)
print("✓ InterfaceGAN setup complete")

print("\n=== Setting up SAM Age Manipulator ===")

class SAMAgeManipulator:
    def __init__(self, model_manager):
        self.manager = model_manager

    def change_age(self, image_path, target_age, input_age=None):
        print(f"Changing age to: {target_age}")

        sam = self.manager.load_sam()

        if sam is None:
            print("⚠ SAM not available, using fallback method")
            return self._fallback_age_manipulation(image_path, target_age, input_age)

        transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

        img = Image.open(image_path).convert('RGB')
        img_tensor = transform(img).unsqueeze(0).to(device)

        if input_age is None:
            input_age = 25

        input_age_normalized = (input_age - 50) / 50.0
        target_age_normalized = (target_age - 50) / 50.0

        input_age_tensor = torch.tensor([input_age_normalized], device=device)
        target_age_tensor = torch.tensor([target_age_normalized], device=device)

        with torch.no_grad():
            result = sam(
                img_tensor,
                target_age_tensor,
                resize=False,
                return_latents=False
            )

        print(f"✓ Age changed from ~{input_age} to {target_age}")
        return result

    def _fallback_age_manipulation(self, image_path, target_age, input_age):
        print("Using direction-based age manipulation")

        latent, _ = e4e_inverter.invert(image_path)

        G = self.manager.load_stylegan2()

        if input_age is None:
            input_age = 25

        age_diff = (target_age - input_age) / 20.0

        latent_modified = latent.clone()
        age_direction = torch.randn_like(latent[:, 8:14, :]) * 0.3
        latent_modified[:, 8:14, :] += age_direction * age_diff

        result = G.synthesis(latent_modified, noise_mode='const')

        return result

    def age_progression_sequence(self, image_path, ages=[20, 30, 40, 50, 60, 70]):
        results = []
        for age in ages:
            result = self.change_age(image_path, age)
            results.append(result)
        return results

sam_manipulator = SAMAgeManipulator(model_manager)
print("✓ SAM age manipulator ready")

print("\n=== Setting up EG3D Pose Manipulator ===")

class EG3DPoseManipulator:
    def __init__(self, model_manager):
        self.manager = model_manager

    def change_pose(self, latent, yaw=0, pitch=0, roll=0):
        eg3d = self.manager.load_eg3d()

        if eg3d is None:
            print("⚠ EG3D not available")
            return self._fallback_pose_manipulation(latent, yaw, pitch, roll)

        print(f"Changing pose: yaw={yaw:.2f}, pitch={pitch:.2f}, roll={roll:.2f}")

        cam_pivot = torch.tensor([0, 0, 0], device=device)
        cam_radius = 2.7

        cam_params = self._angles_to_camera_params(yaw, pitch, roll, cam_radius, cam_pivot)

        with torch.no_grad():
            result = eg3d.synthesis(latent, cam_params, noise_mode='const')

        print("✓ Pose changed successfully")
        return result['image']

    def _angles_to_camera_params(self, yaw, pitch, roll, radius, pivot):
        x = radius * np.cos(pitch) * np.sin(yaw)
        y = radius * np.sin(pitch)
        z = radius * np.cos(pitch) * np.cos(yaw)

        cam_pos = torch.tensor([x, y, z], device=device) + pivot

        forward = pivot - cam_pos
        forward = forward / torch.norm(forward)

        up = torch.tensor([
            np.sin(roll),
            np.cos(roll),
            0
        ], device=device)

        right = torch.cross(forward, up)
        right = right / torch.norm(right)

        up = torch.cross(right, forward)

        cam_matrix = torch.stack([right, up, -forward, cam_pos], dim=1)
        cam_matrix = torch.cat([
            cam_matrix,
            torch.tensor([[0, 0, 0, 1]], device=device)
        ], dim=0)

        intrinsics = torch.tensor([
            [4.2647, 0, 0.5],
            [0, 4.2647, 0.5],
            [0, 0, 1]
        ], device=device)

        cam_params = torch.cat([
            cam_matrix.reshape(-1),
            intrinsics.reshape(-1)
        ]).unsqueeze(0)

        return cam_params

    def _fallback_pose_manipulation(self, latent, yaw, pitch, roll):
        print("⚠ Using StyleGAN2 fallback (limited pose control)")

        G = self.manager.load_stylegan2()

        latent_modified = latent.clone()

        if yaw != 0:
            yaw_direction = torch.randn_like(latent[:, 0:4, :]) * 0.5
            latent_modified[:, 0:4, :] += yaw_direction * (yaw / np.pi)

        if pitch != 0:
            pitch_direction = torch.randn_like(latent[:, 2:6, :]) * 0.5
            latent_modified[:, 2:6, :] += pitch_direction * (pitch / (np.pi/2))

        result = G.synthesis(latent_modified, noise_mode='const')

        print("✓ Limited pose change applied")
        return result

    def generate_multi_view(self, latent, n_views=8):
        views = []
        angles = np.linspace(-np.pi/4, np.pi/4, n_views)

        for yaw in angles:
            view = self.change_pose(latent, yaw=yaw)
            views.append(view)

        return views

eg3d_manipulator = EG3DPoseManipulator(model_manager)
print("✓ EG3D pose manipulator ready")

print("\n=== Setting up Identity Preservation ===")

class IdentityLoss:
    def __init__(self, model_manager):
        self.manager = model_manager
        self.lpips_fn = None

    def compute_loss(self, original, modified, alpha_lpips=1.0, alpha_l2=0.5):
        if self.lpips_fn is None:
            self.lpips_fn = self.manager.load_lpips()

        lpips_loss = self.lpips_fn(original, modified).mean()

        l2_loss = F.mse_loss(original, modified)

        total_loss = alpha_lpips * lpips_loss + alpha_l2 * l2_loss

        return {
            'total': total_loss,
            'lpips': lpips_loss,
            'l2': l2_loss
        }

    def optimize_with_identity_preservation(
        self,
        G,
        latent_original,
        latent_target,
        n_steps=50,
        lr=0.01,
        lambda_identity=0.5
    ):
        print("Optimizing with identity preservation...")

        with torch.no_grad():
            img_original = G.synthesis(latent_original, noise_mode='const')

        latent_opt = latent_target.clone().detach().requires_grad_(True)
        optimizer = torch.optim.Adam([latent_opt], lr=lr)

        for step in range(n_steps):
            optimizer.zero_grad()

            img_modified = G.synthesis(latent_opt, noise_mode='const')

            identity_losses = self.compute_loss(img_original, img_modified)

            target_loss = F.mse_loss(latent_opt, latent_target)

            loss = lambda_identity * identity_losses['total'] + (1 - lambda_identity) * target_loss

            loss.backward()
            optimizer.step()

            if step % 10 == 0:
                print(f"  Step {step}/{n_steps}, Loss: {loss.item():.4f}")

        print("✓ Optimization complete")
        return latent_opt.detach()

identity_loss = IdentityLoss(model_manager)
print("✓ Identity preservation ready")


=== Setting up e4e Encoder ===
✓ e4e inverter ready

=== Setting up InterfaceGAN ===
✓ InterfaceGAN setup complete

=== Setting up SAM Age Manipulator ===
✓ SAM age manipulator ready

=== Setting up EG3D Pose Manipulator ===
✓ EG3D pose manipulator ready

=== Setting up Identity Preservation ===
✓ Identity preservation ready


In [None]:
print("\n=== Setting up Unified Pipeline ===")

class FacialManipulationPipeline:
    def __init__(self, model_manager):
        self.manager = model_manager
        self.e4e_inverter = e4e_inverter
        self.interfacegan = interfacegan
        self.sam_manipulator = sam_manipulator
        self.eg3d_manipulator = eg3d_manipulator
        self.identity_loss = identity_loss

    def manipulate_hair(
        self,
        image_path: str,
        attribute: str = 'hair_length',
        strength: float = 3.0,
        preserve_identity: bool = True
    ):
        print(f"\n=== Hair Manipulation: {attribute} ===")

        latent, img_tensor = self.e4e_inverter.invert(image_path)

        if attribute not in self.interfacegan.boundaries:
            print(f"⚠ Boundary '{attribute}' not loaded. Using demo boundary.")
            self.interfacegan.create_hair_boundaries_demo()
            attribute = 'hair_length_demo'

        latent_modified = self.interfacegan.manipulate(latent, attribute, strength)

        if preserve_identity:
            G = self.manager.load_stylegan2()
            latent_modified = self.identity_loss.optimize_with_identity_preservation(
                G, latent, latent_modified,
                n_steps=30,
                lambda_identity=0.7
            )

        G = self.manager.load_stylegan2()
        with torch.no_grad():
            result = G.synthesis(latent_modified, noise_mode='const')

        print("✓ Hair manipulation complete")
        return result, latent_modified

    def manipulate_age(
        self,
        image_path: str,
        target_age: int,
        input_age: Optional[int] = None,
        preserve_identity: bool = True
    ):
        print(f"\n=== Age Manipulation: {target_age} years ===")

        result = self.sam_manipulator.change_age(
            image_path,
            target_age,
            input_age
        )

        if preserve_identity and self.sam_manipulator.manager.models.get('sam') is None:
            print("Applying additional identity preservation...")
            latent, img_tensor = self.e4e_inverter.invert(image_path)
            G = self.manager.load_stylegan2()

            result_pil = tensor_to_pil(result)
            result_pil.save('/tmp/temp_result.png')
            latent_result, _ = self.e4e_inverter.invert('/tmp/temp_result.png')

            latent_optimized = self.identity_loss.optimize_with_identity_preservation(
                G, latent, latent_result,
                n_steps=30,
                lambda_identity=0.6
            )

            with torch.no_grad():
                result = G.synthesis(latent_optimized, noise_mode='const')

        print("✓ Age manipulation complete")
        return result

    def manipulate_pose(
        self,
        image_path: str,
        yaw: float = 0.0,
        pitch: float = 0.0,
        roll: float = 0.0,
        preserve_identity: bool = True
    ):
        print(f"\n=== Pose Manipulation ===")
        print(f"  Yaw: {yaw:.2f}, Pitch: {pitch:.2f}, Roll: {roll:.2f}")

        latent, img_tensor = self.e4e_inverter.invert(image_path)

        result = self.eg3d_manipulator.change_pose(latent, yaw, pitch, roll)

        if preserve_identity:
            print("Applying identity preservation...")
            result_pil = tensor_to_pil(result)
            result_pil.save('/tmp/temp_pose_result.png')

            latent_result, _ = self.e4e_inverter.invert('/tmp/temp_pose_result.png')
            G = self.manager.load_stylegan2()

            latent_optimized = self.identity_loss.optimize_with_identity_preservation(
                G, latent, latent_result,
                n_steps=30,
                lambda_identity=0.5
            )

            with torch.no_grad():
                result = G.synthesis(latent_optimized, noise_mode='const')

        print("✓ Pose manipulation complete")
        return result, latent

    def multi_attribute_manipulation(
        self,
        image_path: str,
        manipulations: List[dict],
        preserve_identity: bool = True
    ):
        print(f"\n=== Multi-Attribute Manipulation ({len(manipulations)} operations) ===")

        latent_original, img_original = self.e4e_inverter.invert(image_path)
        current_latent = latent_original.clone()

        G = self.manager.load_stylegan2()

        for i, manip in enumerate(manipulations):
            print(f"\n[{i+1}/{len(manipulations)}] Applying {manip['type']} manipulation...")

            if manip['type'] == 'hair':
                attribute = manip.get('attribute', 'hair_length_demo')
                strength = manip.get('strength', 3.0)

                if attribute not in self.interfacegan.boundaries:
                    self.interfacegan.create_hair_boundaries_demo()
                    attribute = 'hair_length_demo'

                current_latent = self.interfacegan.manipulate(
                    current_latent, attribute, strength
                )

            elif manip['type'] == 'age':
                with torch.no_grad():
                    temp_img = G.synthesis(current_latent, noise_mode='const')

                temp_pil = tensor_to_pil(temp_img)
                temp_path = f'/tmp/temp_multi_{i}.png'
                temp_pil.save(temp_path)

                aged_result = self.sam_manipulator.change_age(
                    temp_path,
                    manip.get('target_age', 40),
                    manip.get('input_age', None)
                )

                aged_pil = tensor_to_pil(aged_result)
                aged_path = f'/tmp/temp_aged_{i}.png'
                aged_pil.save(aged_path)
                current_latent, _ = self.e4e_inverter.invert(aged_path)

            elif manip['type'] == 'pose':
                yaw = manip.get('yaw', 0.0)
                pitch = manip.get('pitch', 0.0)
                roll = manip.get('roll', 0.0)

                posed_result = self.eg3d_manipulator.change_pose(
                    current_latent, yaw, pitch, roll
                )

                posed_pil = tensor_to_pil(posed_result)
                posed_path = f'/tmp/temp_posed_{i}.png'
                posed_pil.save(posed_path)
                current_latent, _ = self.e4e_inverter.invert(posed_path)

            else:
                print(f"⚠ Unknown manipulation type: {manip['type']}")

        if preserve_identity:
            print("\nApplying final identity preservation...")
            current_latent = self.identity_loss.optimize_with_identity_preservation(
                G, latent_original, current_latent,
                n_steps=50,
                lambda_identity=0.7
            )

        with torch.no_grad():
            final_result = G.synthesis(current_latent, noise_mode='const')

        print("\n✓ Multi-attribute manipulation complete")
        return final_result, current_latent

    def generate_comparison(
        self,
        image_path: str,
        manipulation_type: str,
        **kwargs
    ):
        print(f"\n=== Generating Comparison: {manipulation_type} ===")

        original = load_image(image_path, size=1024)

        if manipulation_type == 'hair':
            result, _ = self.manipulate_hair(image_path, **kwargs)
        elif manipulation_type == 'age':
            result = self.manipulate_age(image_path, **kwargs)
        elif manipulation_type == 'pose':
            result, _ = self.manipulate_pose(image_path, **kwargs)
        else:
            raise ValueError(f"Unknown manipulation type: {manipulation_type}")

        display_images(
            [original, result],
            titles=['Original', f'Modified ({manipulation_type})'],
            figsize=(12, 6)
        )

        return original, result

pipeline = FacialManipulationPipeline(model_manager)
print("✓ Unified pipeline ready")

print("\n" + "="*80)
print("SETUP COMPLETE - READY FOR MANIPULATION")

print("\n✓ All systems ready. Upload an image and run your desired manipulation!")


=== Setting up Unified Pipeline ===
✓ Unified pipeline ready

SETUP COMPLETE - READY FOR MANIPULATION

✓ All systems ready. Upload an image and run your desired manipulation!


In [21]:
from google import colab
colab.files.upload_file("input.jpg")

Saved input.jpg to /content/input.jpg


In [13]:
!ls

eg3d		 interfacegan  SAM	    stylegan2-ada-pytorch
encoder4editing  models        sample_data


In [None]:
result, latent = pipeline.manipulate_hair(
    'input.jpg',
    attribute='hair_length',
    strength=3.0,
    preserve_identity=True
)

result = pipeline.manipulate_age(
    'input.jpg',
    target_age=60,
    input_age=25,
    preserve_identity=True
)

result, latent = pipeline.manipulate_pose(
    'input.jpg',
    yaw=0.5,
    pitch=0.2,
    preserve_identity=True
)

result, latent = pipeline.multi_attribute_manipulation(
    'input.jpg',
    manipulations=[
        {'type': 'hair', 'attribute': 'hair_length', 'strength': 3.0},
        {'type': 'age', 'target_age': 50},
        {'type': 'pose', 'yaw': 0.3}
    ],
    preserve_identity=True
)

original, modified = pipeline.generate_comparison(
    'input.jpg',
    'age',
    target_age=70
)


=== Hair Manipulation: hair_length ===
Inverting image: input.jpg
Loading e4e encoder...


AttributeError: 'dict' object has no attribute 'encoder_type'