In [3]:
import os
import torch
import numpy as np
from PIL import Image
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import requests
from io import BytesIO
import shutil

from diffusers import (
    StableDiffusionPipeline, 
    DPMSolverMultistepScheduler,
    StableDiffusionImg2ImgPipeline
)

# Define constants
DEFAULT_BASE_MODEL = "runwayml/stable-diffusion-v1-5"
DEFAULT_OUTPUT_DIR = "./face_aging_model"
DEFAULT_DATASET_DIR = "./datasets"

In [4]:
def download_image(url, output_path=None):
    """Download an image from a URL"""
    try:
        response = requests.get(url, timeout=10)
        if response.status_code == 200:
            if output_path:
                with open(output_path, 'wb') as f:
                    f.write(response.content)
                return output_path
            else:
                return Image.open(BytesIO(response.content))
        else:
            print(f"Failed to download image: {response.status_code}")
            return None
    except Exception as e:
        print(f"Error downloading image: {e}")
        return None

def download_sample_images(output_dir=DEFAULT_DATASET_DIR):
    """Download sample images for different age groups"""
    print("Downloading sample images...")
    os.makedirs(output_dir, exist_ok=True)
    
    # Define age groups and sample image URLs (using FFHQ dataset samples)
    age_groups = {
        "child": [
            "https://raw.githubusercontent.com/NVlabs/ffhq-dataset/master/thumbnails128x128/00000/00000.png",
            "https://raw.githubusercontent.com/NVlabs/ffhq-dataset/master/thumbnails128x128/00001/00001.png",
        ],
        "young_adult": [
            "https://raw.githubusercontent.com/NVlabs/ffhq-dataset/master/thumbnails128x128/00002/00002.png",
            "https://raw.githubusercontent.com/NVlabs/ffhq-dataset/master/thumbnails128x128/00003/00003.png",
        ],
        "middle_aged": [
            "https://raw.githubusercontent.com/NVlabs/ffhq-dataset/master/thumbnails128x128/00004/00004.png",
            "https://raw.githubusercontent.com/NVlabs/ffhq-dataset/master/thumbnails128x128/00005/00005.png",
        ],
        "elderly": [
            "https://raw.githubusercontent.com/NVlabs/ffhq-dataset/master/thumbnails128x128/00006/00006.png",
            "https://raw.githubusercontent.com/NVlabs/ffhq-dataset/master/thumbnails128x128/00007/00007.png",
        ]
    }
    
    # Download images for each age group
    for age_group, urls in age_groups.items():
        group_dir = os.path.join(output_dir, age_group)
        os.makedirs(group_dir, exist_ok=True)
        
        for i, url in enumerate(urls):
            try:
                response = requests.get(url, timeout=10)
                if response.status_code == 200:
                    img = Image.open(BytesIO(response.content))
                    img_path = os.path.join(group_dir, f"{i}.png")
                    img.save(img_path)
                    
                    # Create a caption file for the image
                    caption_path = os.path.join(group_dir, f"{i}.txt")
                    with open(caption_path, 'w') as f:
                        age = 5 if age_group == "child" else \
                              25 if age_group == "young_adult" else \
                              45 if age_group == "middle_aged" else 75
                        f.write(f"a detailed photograph of a {age_group} person, {age} years old")
            except Exception as e:
                print(f"Error downloading {url}: {e}")
    
    print(f"Sample images created at {output_dir}")
    return output_dir


In [5]:
def age_face(
    image_path,
    base_model_path=DEFAULT_BASE_MODEL,
    target_age="elderly",
    target_age_value=75,
    strength=0.75,
    num_inference_steps=30,
    guidance_scale=7.5,
    output_path=None
):
    """Apply age transformation to a face image using text-to-image prompting"""
    # Load models
    print(f"Loading model {base_model_path}...")
    
    # Choose pipeline based on whether we want to use an existing image
    if os.path.exists(image_path) or image_path.startswith("http"):
        pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            base_model_path,
            torch_dtype=torch.float16,
            safety_checker=None
        ).to("cuda" if torch.cuda.is_available() else "cpu")
        img2img = True
    else:
        pipe = StableDiffusionPipeline.from_pretrained(
            base_model_path,
            torch_dtype=torch.float16,
            safety_checker=None
        ).to("cuda" if torch.cuda.is_available() else "cpu")
        img2img = False
    
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    
    # Load input image if we're doing img2img
    if img2img:
        if isinstance(image_path, str):
            if image_path.startswith("http"):
                init_image = download_image(image_path)
            else:
                init_image = Image.open(image_path).convert("RGB")
        else:
            # Assume it's already a PIL image
            init_image = image_path
        
        # Resize image
        width, height = init_image.size
        if width > 768 or height > 768:
            # Maintain aspect ratio
            if width > height:
                new_width = 768
                new_height = int(height * (768 / width))
            else:
                new_height = 768
                new_width = int(width * (768 / height))
            init_image = init_image.resize((new_width, new_height), Image.LANCZOS)
    
    # Create prompt for different ages
    age_descriptions = {
        "child": "a young child, about 5-8 years old",
        "teen": "a teenager, about 15-18 years old",
        "young_adult": "a young adult, about 25-30 years old",
        "middle_aged": "a middle-aged person, about 45-50 years old",
        "elderly": "an elderly person, about 70-80 years old",
    }
    
    # Get description or use target_age as is if not found
    age_desc = age_descriptions.get(target_age, f"a {target_age} person, {target_age_value} years old")
    
    # Craft prompt
    prompt = f"a highly detailed realistic photograph of {age_desc}, same person, same identity, detailed face, high quality, detailed skin"
    
    # Add age-specific details
    if "elderly" in target_age or target_age_value >= 65:
        prompt += ", wrinkles, gray hair, aged skin"
    elif "middle" in target_age or 40 <= target_age_value < 65:
        prompt += ", slight wrinkles, mature face"
    elif "young" in target_age or 20 <= target_age_value < 40:
        prompt += ", youthful appearance"
    elif "teen" in target_age or 13 <= target_age_value < 20:
        prompt += ", teenage appearance, young face"
    elif "child" in target_age or target_age_value < 13:
        prompt += ", childlike features, young face, smooth skin"
    
    # Negative prompt to avoid distortion
    negative_prompt = "deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, blurry"
    
    print(f"Generating {target_age} version with prompt: {prompt}")
    
    # Generate image
    if img2img:
        result = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=init_image,
            strength=strength,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale
        )
    else:
        result = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale
        )
    
    aged_image = result.images[0]
    
    # Save output if specified
    if output_path:
        os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
        aged_image.save(output_path)
        print(f"Aged image saved to {output_path}")
    
    # Display the image in the notebook
    plt.figure(figsize=(10, 10))
    plt.imshow(np.array(aged_image))
    plt.axis('off')
    plt.title(f"{target_age} ({target_age_value} years)")
    plt.show()
    
    return aged_image

In [6]:
def generate_age_progression(
    image_path,
    base_model_path=DEFAULT_BASE_MODEL,
    output_dir="./age_progressions",
    age_steps=[("child", 7), ("young_adult", 25), ("middle_aged", 45), ("elderly", 75)],
    strength=0.75,
    guidance_scale=7.5
):
    """Generate a series of age progressions for a face"""
    os.makedirs(output_dir, exist_ok=True)
    
    # Create a subplot for all ages
    fig, axes = plt.subplots(1, len(age_steps) + 1, figsize=(4 * (len(age_steps) + 1), 4))
    
    # Load and display original image
    if isinstance(image_path, str):
        if image_path.startswith("http"):
            original_image = download_image(image_path)
        else:
            original_image = Image.open(image_path).convert("RGB")
    else:
        original_image = image_path
    
    # Display original image
    axes[0].imshow(np.array(original_image))
    axes[0].set_title("Original")
    axes[0].axis("off")
    
    # Save original image
    original_path = os.path.join(output_dir, "original.png")
    original_image.save(original_path)
    
    # Generate each age progression
    results = []
    for i, (age_label, age_value) in enumerate(age_steps):
        output_path = os.path.join(output_dir, f"{age_label}_{age_value}.png")
        print(f"Generating {age_label} ({age_value} years) version...")
        
        # Age the face
        aged_image = age_face(
            image_path=image_path,
            base_model_path=base_model_path,
            target_age=age_label,
            target_age_value=age_value,
            strength=strength,
            guidance_scale=guidance_scale,
            output_path=output_path
        )
        
        # Display in the subplot
        axes[i+1].imshow(np.array(aged_image))
        axes[i+1].set_title(f"{age_label} ({age_value} years)")
        axes[i+1].axis("off")
        
        results.append((age_label, age_value, output_path))
    
    # Save and display the composite image
    plt.tight_layout()
    composite_path = os.path.join(output_dir, "age_progression.png")
    plt.savefig(composite_path)
    plt.show()
    
    # Print summary
    print("\nAge progression complete! Results saved to:")
    print(f"Original: {original_path}")
    for age_label, age_value, path in results:
        print(f"{age_label} ({age_value} years): {path}")
    print(f"Composite: {composite_path}")
    
    return results

In [7]:
url = "https://raw.githubusercontent.com/NVlabs/ffhq-dataset/master/thumbnails128x128/00000/00000.png"
aged_face = age_face(
    image_path=url,
    target_age="elderly",
    target_age_value=75,
    output_path="aged_face.png"
)

Loading model runwayml/stable-diffusion-v1-5...


model_index.json:   0%|          | 0.00/541 [00:00<?, ?B/s]

Fetching 13 files:   0%|          | 0/13 [00:00<?, ?it/s]

preprocessor_config.json:   0%|          | 0.00/342 [00:00<?, ?B/s]

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .


Failed to download image: 404


AttributeError: 'NoneType' object has no attribute 'size'