In [None]:
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
import numpy as np
from skimage.metrics import structural_similarity as compare_ssim
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from transformers import CLIPProcessor, CLIPModel
from lpips import LPIPS  # Requires installation of lpips
from IPython.display import display
import matplotlib.pyplot as plt

# Both the CLIP model and its processor come from the same checkpoint.
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"

# CLIP model (on the GPU) used to score image/text agreement.
clip_model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT).to("cuda")
# Matching processor: tokenizes text and preprocesses images for clip_model.
clip_processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)

# LPIPS perceptual-distance network with an AlexNet backbone, placed on the GPU.
lpips_model = LPIPS(net="alex").to("cuda")

# Lazily-loaded CLIP tokenizer, shared by every truncate_prompt call so the
# hub checkpoint is not re-loaded on each invocation.
_clip_tokenizer = None


def _get_clip_tokenizer():
    """Return the CLIP tokenizer, loading it on first use only."""
    global _clip_tokenizer
    if _clip_tokenizer is None:
        from transformers import CLIPTokenizer

        _clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    return _clip_tokenizer


def truncate_prompt(prompt, max_length=77):
    """
    Truncate the prompt to fit within the token limit for the CLIP model, considering tokenized length.

    Args:
        prompt: Text prompt destined for a CLIP text encoder.
        max_length: Token budget *including* the start- and end-of-text
            tokens that the tokenizer adds around the prompt.

    Returns:
        ``prompt`` unchanged when it already fits. Otherwise, the longest
        prefix that fits, decoded back to text.

    Why prompts that fit are not round-tripped: CLIP's tokenizer lowercases
    text and splits digits and punctuation into separate tokens. A
    decode round-trip therefore rewrites the wording (e.g. "30-35" comes back
    as "3 0 - 3 5"). Applying it on every call would progressively mangle
    prompts that never needed truncating.
    """
    tokenizer = _get_clip_tokenizer()

    # tokenize() returns content tokens only; two more slots go to BOS/EOS.
    if len(tokenizer.tokenize(prompt)) <= max_length - 2:
        return prompt

    # Over budget: keep the leading tokens and decode them back to text.
    tokens = tokenizer(prompt, truncation=True, max_length=max_length, return_tensors="pt")
    truncated_prompt = tokenizer.decode(tokens["input_ids"][0], skip_special_tokens=True)
    print(f"Prompt truncated to: {truncated_prompt}")
    return truncated_prompt


def calculate_ssim(image1, image2):
    """Return the SSIM score of two PIL images after converting both to grayscale ("L")."""
    gray_first, gray_second = (np.array(img.convert("L")) for img in (image1, image2))
    # With full=True the call also returns the per-pixel SSIM map; only the score is kept.
    score = compare_ssim(gray_first, gray_second, full=True)[0]
    return score

def calculate_psnr(image1, image2):
    """Return the PSNR (dB) of ``image2`` against reference ``image1``, assuming 8-bit pixels."""
    reference, candidate = np.array(image1), np.array(image2)
    # data_range=255 fixes the peak value for 8-bit images instead of inferring it from dtype.
    return compare_psnr(reference, candidate, data_range=255)

def calculate_clip_score(image, text):
    """Return CLIP's image-text logit for one image and one caption.

    The value is ``logits_per_image`` from ``clip_model``. That is the similarity
    of the normalized image and text embeddings scaled by CLIP's learned logit
    scale, so it is not limited to [-1, 1]; higher means better agreement.

    Args:
        image: A PIL image (anything ``clip_processor`` accepts as ``images``).
        text: The caption to compare against. Text longer than CLIP's
            77-token context is truncated rather than raising.

    Returns:
        The score as a Python float.
    """
    inputs = clip_processor(
        text=[text],
        images=image,
        return_tensors="pt",
        padding=True,
        truncation=True,  # the text encoder cannot accept more than 77 tokens
    ).to(clip_model.device)
    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        outputs = clip_model(**inputs)
    return outputs.logits_per_image.item()

def setup_stable_diffusion(seed=42, model_id="runwayml/stable-diffusion-v1-5"):
    """Load a Stable Diffusion pipeline ready for repeated, quiet inference.

    The pipeline is moved to CUDA when available (in float16), otherwise to
    the CPU (in float32). The safety checker and the progress bar are
    disabled, and attention slicing is enabled to reduce peak memory.

    Args:
        seed: Unused; accepted for backward compatibility. The previous
            version built a seeded generator here that was never used or
            returned, and it hard-coded a CUDA device, so it crashed on
            CPU-only machines. Callers control randomness via the latents or
            generator they pass to the pipeline.
        model_id: Hugging Face hub id (or local path) of the checkpoint.
            NOTE(review): the ``runwayml`` repo has reportedly been taken
            down; a mirror such as
            ``stable-diffusion-v1-5/stable-diffusion-v1-5`` may be needed.

    Returns:
        The configured ``StableDiffusionPipeline``.
    """
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    # Half precision is a GPU optimisation; many fp16 ops are unsupported or very slow on CPU.
    dtype = torch.float16 if use_cuda else torch.float32

    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype)
    pipe = pipe.to(device)
    pipe.safety_checker = None
    pipe.enable_attention_slicing()
    pipe.set_progress_bar_config(disable=True)
    return pipe


# Text appended to the prompt after iteration i (0-based) finishes; later
# iterations reuse the last prompt unchanged.
_PROMPT_REFINEMENTS = (
    ", add more detail to the hairstyle",
    ", emphasize the eyes and eyebrows",
    ", add sharper jawline definition",
)


def _lpips_distance(image_a, image_b):
    """Return the LPIPS distance between two same-size RGB images.

    LPIPS expects NCHW float tensors scaled to [-1, 1]. Passing raw 0-255
    pixel values feeds the network inputs far outside the range it was
    calibrated on, which makes the scores meaningless.
    """
    device = next(lpips_model.parameters()).device

    def to_tensor(image):
        # HWC uint8 -> HWC float in [-1, 1] -> 1xCxHxW on the LPIPS device.
        pixels = np.asarray(image, dtype=np.float32) / 127.5 - 1.0
        return torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0).to(device)

    with torch.no_grad():
        return lpips_model(to_tensor(image_a), to_tensor(image_b)).item()


def process_enhanced_prompt(input_image_path, prompt, num_iterations=5, seed=42):
    """Iteratively generate police-sketch images from a description and score them.

    Each iteration runs Stable Diffusion from the same initial latents. The
    prompt is truncated to CLIP's 77-token limit. The generated image is shown
    and compared with the input photo (SSIM, PSNR, LPIPS) and with the prompt
    (CLIP score). Then the next entry of ``_PROMPT_REFINEMENTS``, if any, is
    appended to the prompt.

    Args:
        input_image_path: Path to the reference photo.
        prompt: Free-text description of the person; it is prefixed with a
            fixed sketch-style instruction.
        num_iterations: Number of images to generate.
        seed: Seed for the initial latent noise shared by all iterations.

    Returns:
        ``(ssim_scores, psnr_scores, lpips_scores, clip_scores)``: four lists
        with one value per iteration.
    """
    pipeline = setup_stable_diffusion()

    # Style prefix steering the model towards a forensic-sketch rendering.
    prompt = (
        "highly detailed police sketch, black and white pencil drawing, "
        "professional forensic artist style, detailed shading, portrait of " + prompt
    )

    input_image = Image.open(input_image_path).convert("RGB")

    ssim_scores, psnr_scores, lpips_scores, clip_scores = [], [], [], []

    # Draw the initial noise once, so iterations differ only by their prompt.
    # The unet's sample_size is already the latent resolution.
    generator = torch.Generator(device=pipeline.device).manual_seed(seed)
    latent_size = pipeline.unet.config.sample_size
    latents = torch.randn(
        (1, pipeline.unet.config.in_channels, latent_size, latent_size),
        generator=generator,
        device=pipeline.device,
        dtype=pipeline.unet.dtype,
    )

    for i in range(num_iterations):
        # Truncate first so the log line and the CLIP score both reflect the
        # text actually given to the model. NOTE: truncation keeps the start
        # of the prompt, so refinements appended past the token limit are dropped.
        prompt = truncate_prompt(prompt, max_length=77)
        print(f"Iteration {i + 1}: Generating image with prompt: {prompt}")

        output = pipeline(prompt, latents=latents, return_dict=True, guidance_scale=7.5)
        generated_image = output.images[0]

        # Pixel-wise metrics need both images at the same resolution.
        resized_input_image = input_image.resize(generated_image.size)

        display(generated_image)

        ssim = calculate_ssim(resized_input_image, generated_image)
        ssim_scores.append(ssim)
        print(f"SSIM with input image at iteration {i + 1}: {ssim}")

        psnr = calculate_psnr(resized_input_image, generated_image)
        psnr_scores.append(psnr)
        print(f"PSNR with input image at iteration {i + 1}: {psnr}")

        lpips_value = _lpips_distance(resized_input_image, generated_image)
        lpips_scores.append(lpips_value)
        print(f"LPIPS with input image at iteration {i + 1}: {lpips_value}")

        clip_score = calculate_clip_score(generated_image, prompt)
        clip_scores.append(clip_score)
        print(f"CLIP score for iteration {i + 1}: {clip_score}")

        if i < len(_PROMPT_REFINEMENTS):
            prompt += _PROMPT_REFINEMENTS[i]

    return ssim_scores, psnr_scores, lpips_scores, clip_scores



# Example run: replace the path and description with your own data.
input_image_path = "/content/00002.jpg"  # Replace with the path to your input image
prompt = "The person is described as male around 30-35 years old with an oval face, defined cheekbones, medium-length straight hair, and light skin tone."

# Generate the sketches and collect one score per iteration for every metric.
ssim_scores, psnr_scores, lpips_scores, clip_scores = process_enhanced_prompt(input_image_path, prompt)

# Summarise all four metric series.
print("\nFinal Metrics Across Iterations:")
print(
    "\n".join(
        f"{metric} Scores: {values}"
        for metric, values in (
            ("SSIM", ssim_scores),
            ("PSNR", psnr_scores),
            ("LPIPS", lpips_scores),
            ("CLIP", clip_scores),
        )
    )
)

In [None]:
# Metrics Explanation:
# 1. Structural Similarity Index (SSIM):
#    - Measures structural similarity (luminance, contrast, and structure) between two images.
#    - Range: -1 to 1, where 1 indicates identical images; most natural image pairs score between 0 and 1.
#    - Here each generated sketch is compared with the (resized, grayscale) input photo, so higher SSIM means closer structural alignment with that photo.

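# 2. Peak Signal-to-Noise Ratio (PSNR):
#    - Measures pixel-level distortion between two images, in decibels (dB); identical images give an infinite PSNR.
#    - Higher PSNR means the generated image is closer, pixel for pixel, to the input photo (typical values lie roughly between 20 and 50 dB).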

# 3. Cosine Similarity (not currently computed in this notebook):
#    - Compares the semantic similarity of image embeddings (high-level features).
#    - Range: -1 to 1, where 1 means the embeddings point in the same direction, 0 means they are unrelated (orthogonal), and -1 means they point in opposite directions.
#    - Higher values indicate better semantic consistency between the compared images.

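# 4. CLIP Score:
#    - Measures how well a generated image matches a text prompt, using the CLIP model.
#    - Here it is CLIP's `logits_per_image`: the similarity of the normalized image and text embeddings multiplied by CLIP's learned logit scale (roughly 100 for the OpenAI checkpoints), so values are not limited to [-1, 1].
#    - It is computed against the full enhanced prompt (sketch-style prefix plus any refinements, truncated to 77 tokens), not only the original description.
#    - Higher scores indicate that the generated image aligns better with that prompt.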

# 5. Learned Perceptual Image Patch Similarity (LPIPS):
#    - Evaluates the perceptual distance between two images using deep features from a pretrained network (AlexNet here).
#    - 0 means perceptually identical, and lower scores indicate higher perceptual similarity; values are not strictly bounded but usually fall between 0 and about 1.
#    - Useful for assessing how close two images are in terms of human perception.

# 6. Fréchet Inception Distance (FID) (not currently computed in this notebook):
#    - Compares the distribution of generated images with a distribution of reference images, using Inception-v3 features.
#    - Lower FID scores indicate that the generated images are closer to the reference distribution.
#    - It needs many images to be meaningful, so it is typically used to evaluate a generative model overall rather than a single image.

In [None]:
# Plot every metric against its 1-based iteration number (one point per generated image).
iterations = list(range(1, len(ssim_scores) + 1))

# One panel per metric returned by process_enhanced_prompt: (scores, axis/legend label, colour).
metric_panels = [
    (ssim_scores, "SSIM", "blue"),
    (psnr_scores, "PSNR", "green"),
    (clip_scores, "CLIP Score", "purple"),
    (lpips_scores, "LPIPS", "orange"),
]

fig, axes = plt.subplots(1, len(metric_panels), figsize=(15, 5))
for ax, (scores, label, color) in zip(axes, metric_panels):
    ax.plot(iterations, scores, label=label, marker="o", color=color)
    ax.set_xlabel("Iteration")
    ax.set_ylabel(label)
    ax.set_title(f"{label} Over Iterations")
    ax.grid(True)

plt.tight_layout()
plt.show()