In [None]:
import functools

import matplotlib.pyplot as plt
import numpy as np
import torch
from diffusers import StableDiffusionImg2ImgPipeline
from IPython.display import display
from lpips import LPIPS  # Requires installation of lpips
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as compare_ssim
from transformers import CLIPModel, CLIPProcessor

### Load Models

In [None]:
# One checkpoint name shared by the CLIP processor and model used for text-image scoring
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
clip_processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)
clip_model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT).to("cuda")

# LPIPS perceptual distance using the AlexNet backbone
lpips_model = LPIPS(net="alex").to("cuda")

### Functions for Pipeline

In [None]:
def truncate_prompt(prompt, max_length=77):
    """
    Shorten ``prompt`` so that its CLIP tokenization fits within ``max_length`` tokens.

    The prompt is tokenized with the ``openai/clip-vit-base-patch32`` tokenizer using
    ``truncation=True`` (the limit counts the tokenizer's special tokens). It is then
    decoded back to text with special tokens removed. The result is printed and returned.
    """
    from transformers import CLIPTokenizer

    clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    first_row_ids = clip_tokenizer(
        prompt, truncation=True, max_length=max_length, return_tensors="pt"
    )["input_ids"][0]
    shortened = clip_tokenizer.decode(first_row_ids, skip_special_tokens=True)
    print(f"Prompt truncated to: {shortened}")
    return shortened

def calculate_ssim(image1, image2):
    """
    Return the structural similarity (SSIM) of two PIL images, compared in grayscale.

    Both images are converted to single-channel ("L") arrays before comparison. They are
    expected to have the same size; the caller resizes the input image to match.
    """
    gray_a, gray_b = (np.array(img.convert("L")) for img in (image1, image2))
    score, _diff_map = compare_ssim(gray_a, gray_b, full=True)
    return score

def calculate_psnr(image1, image2):
    """
    Return the peak signal-to-noise ratio between two images.

    Both images are converted to arrays as-is (no grayscale conversion). ``data_range=255``
    assumes 8-bit pixel values.
    """
    reference, candidate = np.array(image1), np.array(image2)
    return compare_psnr(reference, candidate, data_range=255)

def calculate_clip_score(image, text, model=None, processor=None):
    """
    Return CLIP's image-text logit for a single image and a single text prompt.

    The value is ``logits_per_image`` from the CLIP model, i.e. image-text similarity
    scaled by the model's logit scale; higher means a better match.

    Args:
        image: image accepted by the CLIP processor (a PIL image in this notebook).
        text: prompt string. It is truncated to the tokenizer's maximum length, so
            prompts longer than CLIP's context do not raise.
        model: CLIP model to use; defaults to the notebook-level ``clip_model``.
        processor: matching CLIP processor; defaults to the notebook-level ``clip_processor``.

    Returns:
        Python float score.
    """
    model = clip_model if model is None else model
    processor = clip_processor if processor is None else processor

    # Move inputs to wherever the model lives instead of assuming "cuda".
    inputs = processor(
        text=[text], images=image, return_tensors="pt", padding=True, truncation=True
    ).to(model.device)

    # Inference only: skip autograd bookkeeping to save memory/time.
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.logits_per_image.item()

def preprocess_latents(pipeline, image, generator):
    """
    Encode an image batch into scaled VAE latents for use with a Stable Diffusion pipeline.

    Args:
        pipeline: pipeline whose ``vae`` encodes the image. The image is moved to
            ``pipeline.device`` and cast to the VAE's dtype.
        image: float image tensor with values in [0, 1]. This is how the notebook builds
            it (``/ 255.0``); presumably shaped (B, 3, H, W) — confirm for new callers.
            It is rescaled to [-1, 1], the range the Stable Diffusion VAE is trained on.
        generator: ``torch.Generator`` (or None) used when sampling the latent distribution.

    Returns:
        Sampled latents multiplied by the VAE's scaling factor.
    """
    vae = pipeline.vae
    # Match the VAE's precision rather than assuming fp16.
    image = image.to(device=pipeline.device, dtype=vae.dtype)
    # [0, 1] -> [-1, 1]; feeding [0, 1] directly produces shifted, washed-out latents.
    image = image * 2.0 - 1.0
    # Older configs may lack scaling_factor; 0.18215 is the SD v1 value.
    scaling_factor = getattr(vae.config, "scaling_factor", 0.18215)
    with torch.no_grad():
        latent_image = vae.encode(image).latent_dist.sample(generator)
    return latent_image * scaling_factor

def setup_img2img_pipeline(model_id="runwayml/stable-diffusion-v1-5", device="cuda", torch_dtype=torch.float16):
    """
    Load a Stable Diffusion img2img pipeline and move it to ``device``.

    Args:
        model_id: Hugging Face Hub id or local path of the diffusers checkpoint.
            NOTE(review): the ``runwayml/stable-diffusion-v1-5`` repo has reportedly been
            removed from the Hub. If loading fails, pass the community mirror
            ``"stable-diffusion-v1-5/stable-diffusion-v1-5"``.
        device: device the pipeline is moved to (defaults to "cuda", as before).
        torch_dtype: weight dtype to load (fp16 by default, as before).

    Returns:
        The loaded ``StableDiffusionImg2ImgPipeline`` on ``device``.
    """
    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
    return pipe.to(device)

@functools.lru_cache(maxsize=None)
def _cached_clip_tokenizer(checkpoint="openai/clip-vit-base-patch32"):
    """Load a CLIP tokenizer once per checkpoint and reuse it across calls."""
    from transformers import CLIPTokenizer
    return CLIPTokenizer.from_pretrained(checkpoint)


def truncate_prompt_with_addition(base_prompt, addition, max_length=77, tokenizer=None):
    """
    Append ``addition`` to ``base_prompt``, trimming the END of the base prompt so that
    the combined prompt fits CLIP's token limit.

    ``max_length`` counts the start and end special tokens, so at most
    ``max_length - 2`` content tokens are kept. The addition is kept whole unless it
    alone exceeds that budget. In that case the addition is itself truncated and the
    base prompt is dropped entirely.

    Args:
        base_prompt: existing prompt text.
        addition: text appended after the base prompt.
        max_length: token budget including the two special tokens.
        tokenizer: optional tokenizer exposing ``__call__(text, add_special_tokens=...)``
            and ``decode``. Defaults to a cached ``openai/clip-vit-base-patch32``
            ``CLIPTokenizer``, so repeated calls do not reload it from disk.

    Returns:
        The decoded combined prompt. It is at most ``max_length`` tokens, not exactly that.
    """
    if tokenizer is None:
        tokenizer = _cached_clip_tokenizer()

    # Reserve two slots for the start/end special tokens the text encoder adds back.
    content_budget = max(max_length - 2, 0)

    base_ids = list(tokenizer(base_prompt, add_special_tokens=False)["input_ids"])
    addition_ids = list(tokenizer(addition, add_special_tokens=False)["input_ids"])[:content_budget]
    # Non-negative because addition_ids was already clipped to the budget.
    base_ids = base_ids[: content_budget - len(addition_ids)]

    truncated_prompt = tokenizer.decode(base_ids + addition_ids, skip_special_tokens=True)
    print(f"Updated Prompt: {truncated_prompt}")
    return truncated_prompt

# Feature edits appended to the prompt, one per iteration. Iterations past the end of
# the list reuse the last entry (the same as the original if/elif/else chain).
_DEFAULT_FEATURE_ADDITIONS = (
    "with a thinner nose",
    "with thinner eyes",
    "with a more square jawline",
    "with a wider mouth and thinner lips",
    "with darker eyebrows",
)


def _pil_to_lpips_tensor(image, device):
    """Convert an RGB PIL image to a (1, 3, H, W) float tensor in [-1, 1], the range LPIPS expects."""
    pixels = torch.from_numpy(np.asarray(image, dtype=np.float32))
    return (pixels.permute(2, 0, 1).unsqueeze(0) / 127.5 - 1.0).to(device)


def _score_iteration(reference_image, generated_image, prompt, iteration):
    """Score one generated image against the (resized) input image and its prompt, printing each metric."""
    step = iteration + 1

    ssim = calculate_ssim(reference_image, generated_image)
    print(f"SSIM with input image at iteration {step}: {ssim}")

    psnr = calculate_psnr(reference_image, generated_image)
    print(f"PSNR with input image at iteration {step}: {psnr}")

    lpips_device = next(lpips_model.parameters()).device
    with torch.no_grad():
        lpips_value = lpips_model(
            _pil_to_lpips_tensor(reference_image, lpips_device),
            _pil_to_lpips_tensor(generated_image, lpips_device),
        ).item()
    print(f"LPIPS with input image at iteration {step}: {lpips_value}")

    clip_score = calculate_clip_score(generated_image, prompt)
    print(f"CLIP score for iteration {step}: {clip_score}")

    return ssim, psnr, lpips_value, clip_score


def process_enhanced_prompt(input_image_path, input_prompt, num_iterations=5, strength=0.3,
                            guidance_scale=7.5, additions=None, seed=42):
    """
    Iteratively edit a face sketch with Stable Diffusion img2img and track similarity metrics.

    Each iteration does three things:
    - appends one feature edit to the prompt;
    - runs img2img starting from the previous iteration's output (the input image on
      the first iteration);
    - scores the result against the original input (SSIM, PSNR, LPIPS) and against
      the prompt (CLIP).

    Args:
        input_image_path: path to the starting image (converted to RGB).
        input_prompt: description appended after the fixed sketch-style prefix.
        num_iterations: number of edit/generate rounds.
        strength: img2img strength (how far each round may move from its source image).
        guidance_scale: classifier-free guidance scale.
        additions: optional sequence of prompt fragments, one per iteration. Iterations
            beyond its length reuse the last fragment. Defaults to the five built-in
            facial-feature edits.
        seed: seed for the generator passed to the pipeline, for reproducible runs.

    Returns:
        Tuple ``(ssim_scores, psnr_scores, lpips_scores, clip_scores, generated_images)``,
        each a list with one entry per iteration.

    Raises:
        ValueError: if ``additions`` is given but empty.
    """
    feature_additions = tuple(additions) if additions is not None else _DEFAULT_FEATURE_ADDITIONS
    if not feature_additions:
        raise ValueError("additions must contain at least one prompt fragment")

    pipeline = setup_img2img_pipeline()

    prompt = (
        "highly detailed police sketch, black and white pencil drawing, "
        "professional forensic artist style " + input_prompt
    )

    # Load fully and release the file handle.
    with Image.open(input_image_path) as opened:
        input_image = opened.convert("RGB")

    # Passed to the pipeline so the noise (and therefore the outputs) are reproducible.
    generator = torch.Generator(device=pipeline.device).manual_seed(seed)

    ssim_scores, psnr_scores, lpips_scores, clip_scores, generated_images = [], [], [], [], []

    # The img2img pipeline has no `latents` argument (extra kwargs are ignored), so to
    # chain edits the previous output is fed back in as the next `image`.
    source_image = input_image

    for i in range(num_iterations):
        addition = feature_additions[min(i, len(feature_additions) - 1)]
        prompt = truncate_prompt_with_addition(prompt, addition, max_length=77)
        # Logged after the update so it shows the prompt actually used for this image.
        print(f"Iteration {i + 1}: Generating image with prompt: {prompt}")

        # NOTE(review): a negative prompt steers generation AWAY from the text it contains,
        # so "no changes to other features" likely does not enforce what it says.
        output = pipeline(
            prompt=prompt,
            negative_prompt="no changes to other features",
            image=source_image,
            strength=strength,
            guidance_scale=guidance_scale,
            generator=generator,
        )
        generated_image = output.images[0]
        generated_images.append(generated_image)
        display(generated_image)

        # The pipeline may change the resolution, so align the reference before comparing.
        resized_input_image = input_image.resize(generated_image.size)
        ssim, psnr, lpips_value, clip_score = _score_iteration(
            resized_input_image, generated_image, prompt, i
        )
        ssim_scores.append(ssim)
        psnr_scores.append(psnr)
        lpips_scores.append(lpips_value)
        clip_scores.append(clip_score)

        source_image = generated_image

    return ssim_scores, psnr_scores, lpips_scores, clip_scores, generated_images

### Generate Images

In [None]:

input_image_path = "../input_images/00002.jpg"
prompt = (
    "a male around 30-35 years old The hairstyle is medium-length, straight hair parted "
    "slightly off-center and almond-shaped eyes with a neutral expression mouth"
)

# Run the iterative edit loop; it returns one list per metric plus the generated images
(ssim_scores, psnr_scores, lpips_scores, clip_scores,
 generated_images) = process_enhanced_prompt(input_image_path, prompt)

# Summarize every metric series
print("\nFinal Metrics Across Iterations:")
for metric_name, metric_scores in (
    ("SSIM", ssim_scores),
    ("PSNR", psnr_scores),
    ("LPIPS", lpips_scores),
    ("CLIP", clip_scores),
):
    print(f"{metric_name} Scores: {metric_scores}")

### View Images

In [None]:
# Show every generated image side by side, one column per iteration.
if not generated_images:
    # plt.subplot(1, 0, ...) would raise; report instead of crashing the cell.
    print("No generated images to display.")
else:
    fig, axes = plt.subplots(1, len(generated_images), figsize=(20, 5), squeeze=False)
    for iteration, (ax, img) in enumerate(zip(axes[0], generated_images), start=1):
        ax.imshow(img)
        ax.axis("off")
        ax.set_title(f"Iteration {iteration}")
    fig.tight_layout()
    plt.show()

### Plot Metrics Over Iterations

In [None]:
iterations = list(range(1, len(ssim_scores) + 1))

# One panel per metric, left to right: (series, name used for legend/y-label/title, color)
metric_panels = [
    (ssim_scores, "SSIM", 'blue'),
    (psnr_scores, "PSNR", 'green'),
    (clip_scores, "CLIP Score", 'purple'),
    (lpips_scores, "LPIPS", 'orange'),
]

fig, axes = plt.subplots(1, len(metric_panels), figsize=(15, 5))
for ax, (values, name, color) in zip(axes, metric_panels):
    ax.plot(iterations, values, label=name, marker='o', color=color)
    ax.set_xlabel("Iteration")
    ax.set_ylabel(name)
    ax.set_title(f"{name} Over Iterations")
    ax.grid(True)

fig.tight_layout()
plt.show()