In [None]:
import torch
from PIL import Image
import numpy as np
from skimage.metrics import structural_similarity as compare_ssim
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
import torch.nn as nn
from lpips import LPIPS
import matplotlib.pyplot as plt
from diffusers import StableDiffusionImg2ImgPipeline
from transformers import CLIPProcessor, CLIPModel

### Functions for Pipeline

In [None]:
# Run on the GPU when one is available; otherwise fall back to the CPU so this cell
# does not fail outright on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Linear map from 512-d CLIP embeddings to the 768-d embeddings passed to Stable Diffusion.
# NOTE(review): this layer is randomly initialised and nothing in this notebook trains it
# or loads weights for it, so its output changes on every kernel restart — confirm whether
# trained projection weights should be loaded here.
projection_layer = nn.Linear(512, 768).to(device)

class LoRALinear(nn.Module):
    """Wrap a linear layer and add a low-rank (LoRA) correction to its output.

    The forward pass computes
    ``original_linear(x) + (dropout(x) @ lora_A @ lora_B) * (lora_alpha / r)``.
    With ``r == 0`` the low-rank path is skipped and the module returns exactly
    what ``original_linear`` returns.

    Args:
        original_linear: Layer to wrap; must expose ``in_features`` and ``out_features``.
        r: Rank of the low-rank update. ``0`` disables the update.
        lora_alpha: Numerator of the scaling factor ``lora_alpha / r``.
        lora_dropout: Dropout probability applied to the input of the low-rank path only.
    """

    def __init__(self, original_linear, r, lora_alpha, lora_dropout):
        super().__init__()
        self.original_linear = original_linear
        self.r = r
        self.lora_alpha = lora_alpha
        self.lora_dropout = nn.Dropout(lora_dropout)
        # Attribute names are part of the state_dict keys, so they must not change.
        self.lora_A = nn.Parameter(torch.zeros(original_linear.in_features, r))
        self.lora_B = nn.Parameter(torch.zeros(r, original_linear.out_features))
        # forward() supports r == 0, so the constructor must not divide by zero
        # for that configuration.
        self.scaling = lora_alpha / r if r > 0 else 0.0

        # Initialize parameters (nothing to initialise for the empty r == 0 factors).
        # NOTE(review): both factors are drawn from N(0, 0.02^2), so a freshly wrapped
        # layer already perturbs the base output; the LoRA paper zero-initialises one
        # factor. Left unchanged to match whatever produced existing checkpoints — confirm.
        if r > 0:
            nn.init.normal_(self.lora_A, std=0.02)
            nn.init.normal_(self.lora_B, std=0.02)

    def forward(self, x):
        """Return the wrapped layer's output plus the scaled low-rank update."""
        result = self.original_linear(x)
        if self.r > 0:
            lora_result = (self.lora_dropout(x) @ self.lora_A @ self.lora_B) * self.scaling
            return result + lora_result
        return result


def apply_lora_to_model(model, r=16, lora_alpha=32, lora_dropout=0.1):
    """Replace attention projection layers of ``model`` with LoRA-wrapped versions, in place.

    Every ``nn.Linear`` whose dotted module name contains ``q_proj``, ``k_proj`` or
    ``v_proj`` is swapped for ``LoRALinear(module, r, lora_alpha, lora_dropout)``.
    A linear layer that already sits directly inside a ``LoRALinear`` is left alone,
    so calling this function twice on the same model does not stack adapters.

    Args:
        model: Module tree to modify.
        r: LoRA rank forwarded to ``LoRALinear``.
        lora_alpha: LoRA scaling numerator forwarded to ``LoRALinear``.
        lora_dropout: Dropout probability forwarded to ``LoRALinear``.

    Returns:
        None. ``model`` is modified in place.
    """
    target_substrings = ('q_proj', 'v_proj', 'k_proj')

    # Snapshot the module list first: the loop rewires the module tree, and mutating
    # it while the named_modules() generator is still walking it is fragile.
    for name, module in list(model.named_modules()):
        if not isinstance(module, nn.Linear):
            continue
        if not any(substr in name for substr in target_substrings):
            continue

        # Walk down to the direct parent of the matched layer.
        *parent_parts, layer_name = name.split('.')
        parent = model
        for name_part in parent_parts:
            parent = getattr(parent, name_part)

        # `<...>.q_proj.original_linear` also contains "q_proj"; wrapping it again
        # would nest one adapter inside another.
        if isinstance(parent, LoRALinear):
            continue

        setattr(parent, layer_name, LoRALinear(module, r, lora_alpha, lora_dropout))

def setup_clip_model(checkpoint_path, device):
    """Build the LoRA-adapted CLIP model and its processor from a fine-tuned checkpoint.

    Args:
        checkpoint_path: Path to a ``torch.save`` file holding a ``'model_state_dict'`` entry.
        device: Device the checkpoint tensors are mapped to and the model is moved to.

    Returns:
        Tuple ``(model, processor)`` with the model in eval mode on ``device``.
    """
    pretrained_id = "openai/clip-vit-base-patch32"
    processor = CLIPProcessor.from_pretrained(pretrained_id)
    clip = CLIPModel.from_pretrained(pretrained_id)

    # The LoRA wrappers must be in place before loading: the checkpoint is loaded with
    # strict=True, so its keys have to match the wrapped module names.
    apply_lora_to_model(clip)

    # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
    saved_state = torch.load(checkpoint_path, map_location=device)
    clip.load_state_dict(saved_state['model_state_dict'], strict=True)

    clip = clip.to(device)
    clip.eval()
    return clip, processor


def generate_image_with_sketch_and_embeddings(
    input_image_path, prompt, clip_model, clip_processor, stable_diffusion, strength=0.3, guidance_scale=7.5
):
    """Run img2img on a sketch, conditioned on a fused CLIP image+text embedding.

    The image and prompt are encoded with ``clip_model``, each embedding is
    L2-normalised, the two are averaged and re-normalised, then mapped through the
    module-level ``projection_layer`` and passed to ``stable_diffusion`` as a
    length-1 ``prompt_embeds`` sequence. The negative prompt embedding is all zeros.

    Args:
        input_image_path: Path of the image used both for CLIP encoding and as the img2img input.
        prompt: Text prompt encoded by CLIP (tokens beyond the tokenizer limit are truncated).
        clip_model: CLIP model exposing ``get_image_features`` / ``get_text_features``.
        clip_processor: Processor producing ``pixel_values``, ``input_ids`` and ``attention_mask``.
        stable_diffusion: Callable img2img pipeline accepting ``prompt_embeds``.
        strength: Forwarded to the pipeline; controls deviation from the input image.
        guidance_scale: Forwarded to the pipeline; controls adherence to the prompt.

    Returns:
        The first image of the pipeline's ``"images"`` output.
    """
    # convert() returns an independent copy, so the file handle can be closed right away.
    with Image.open(input_image_path) as source_image:
        input_image = source_image.convert("RGB")

    # Follow the model's actual device rather than the notebook-level `device` global.
    clip_device = next(clip_model.parameters()).device

    # Preprocess the image and prompt for CLIP
    inputs = clip_processor(
        text=[prompt], images=input_image, return_tensors="pt", padding=True, truncation=True
    ).to(clip_device)

    # Everything up to the pipeline call is inference only, so keep all of it out of autograd
    # (the projection used to run outside no_grad and returned a tensor that tracked gradients).
    with torch.no_grad():
        # Generate CLIP embeddings
        image_features = clip_model.get_image_features(pixel_values=inputs["pixel_values"])
        text_features = clip_model.get_text_features(
            input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
        )

        # Normalize and combine embeddings
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        combined_embeddings = (image_features + text_features) / 2.0
        combined_embeddings = combined_embeddings / combined_embeddings.norm(dim=-1, keepdim=True)

        # Project embeddings to match Stable Diffusion's expected dimension; align device
        # and dtype with the projection layer so the matmul cannot mismatch.
        projection_weight = projection_layer.weight
        combined_embeddings = combined_embeddings.to(
            device=projection_weight.device, dtype=projection_weight.dtype
        )
        projected_embeddings = projection_layer(combined_embeddings)

        # Add seq_len dimension: (batch_size, embed_dim) -> (batch_size, 1, embed_dim)
        # NOTE(review): the pipeline receives a single-token sequence from an untrained
        # projection — confirm this is the intended conditioning.
        projected_embeddings = projected_embeddings.unsqueeze(1)

        # Unconditional branch for classifier-free guidance: all-zero embeddings.
        negative_prompt_embeds = torch.zeros_like(projected_embeddings)

    # Generate the image using the Stable Diffusion pipeline
    generated_images = stable_diffusion(
        prompt_embeds=projected_embeddings,
        negative_prompt_embeds=negative_prompt_embeds,
        image=input_image,
        strength=strength,
        guidance_scale=guidance_scale,
    )["images"]

    return generated_images[0]

def calculate_clip_score(clip_model, clip_processor, image, text):
    """Return CLIP's image-to-text logit for one image and one text.

    Args:
        clip_model: CLIP model; its parameters determine the device and ``dtype`` used.
        clip_processor: Processor that turns ``text``/``image`` into model inputs.
        image: Image to score.
        text: Text to score the image against (truncated to the tokenizer limit).

    Returns:
        ``outputs.logits_per_image`` as a Python float.
    """
    # Ensure the inputs land on the same device as the model
    device = next(clip_model.parameters()).device

    # Preprocess the image and text
    inputs = clip_processor(
        text=[text], images=image, return_tensors="pt", padding=True, truncation=True
    ).to(device)

    # Match the model's floating dtype for floating tensors only. Integer tensors
    # (input_ids, attention_mask) must stay integral; previously attention_mask was
    # cast to the model's float dtype along with pixel_values.
    model_inputs = {
        key: value.to(dtype=clip_model.dtype) if value.is_floating_point() else value
        for key, value in inputs.items()
    }

    # Run the model and calculate the score
    with torch.no_grad():
        outputs = clip_model(**model_inputs)

    return outputs.logits_per_image.item()


def truncate_prompt_with_addition(base_prompt, addition, max_length=77):
    """
    Append ``addition`` to ``base_prompt``, trimming the end of the base prompt so the
    combined prompt fits within ``max_length`` CLIP tokens.

    The budget counts the begin/end-of-text tokens the tokenizer adds, so the result is
    *at most* ``max_length`` tokens (it is not padded up to that length). The addition
    is kept in full whenever it fits; only if the addition alone exceeds the budget is
    it trimmed too, in which case none of the base prompt is kept.

    Args:
        base_prompt: Existing prompt text; trimmed from the end when space runs out.
        addition: Text appended after the (possibly trimmed) base prompt.
        max_length: Token limit including the two special tokens.

    Returns:
        The combined prompt decoded back to text. The text is also printed.
    """
    from transformers import CLIPTokenizer

    # Loading the tokenizer is expensive and this function is called once per
    # iteration, so keep a single instance on the function object.
    tokenizer = getattr(truncate_prompt_with_addition, "_tokenizer", None)
    if tokenizer is None:
        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
        truncate_prompt_with_addition._tokenizer = tokenizer

    # Tokenize without special tokens so only real content is counted and sliced
    base_ids = tokenizer(base_prompt, add_special_tokens=False, truncation=False)["input_ids"]
    addition_ids = tokenizer(addition, add_special_tokens=False, truncation=False)["input_ids"]

    # Reserve two slots for the begin/end-of-text tokens added on re-tokenization.
    content_budget = max(max_length - 2, 0)

    # Clamp explicitly: the old slice bound went negative when the addition alone
    # exceeded the limit, which sliced the base prompt from the wrong end.
    addition_ids = addition_ids[:content_budget]
    base_ids = base_ids[:content_budget - len(addition_ids)]

    # Decode back to prompt text
    # assumes decode -> re-tokenize reproduces the same token count — TODO confirm
    truncated_prompt = tokenizer.decode(base_ids + addition_ids, skip_special_tokens=True)
    print(f"Updated Prompt: {truncated_prompt}")
    return truncated_prompt

def calculate_ssim(image1, image2):
    """Return the mean structural similarity of two PIL images, compared in greyscale."""
    gray_first, gray_second = (np.array(picture.convert("L")) for picture in (image1, image2))
    # full=True yields (mean_ssim, ssim_map); only the mean is reported.
    mean_ssim = compare_ssim(gray_first, gray_second, full=True)[0]
    return mean_ssim

def calculate_psnr(image1, image2):
    """Return the peak signal-to-noise ratio between two images, assuming a 0-255 value range."""
    first_pixels, second_pixels = np.array(image1), np.array(image2)
    return compare_psnr(first_pixels, second_pixels, data_range=255)

def get_clip_embedding(clip_model, clip_processor, text_prompt, image_path=None, device="cuda"):
    """Return the L2-normalised CLIP text embedding for ``text_prompt``.

    Args:
        clip_model: CLIP model.
        clip_processor: Processor that tokenizes the text (and preprocesses the image).
        text_prompt: Text to embed.
        image_path: Optional image path. When given, the image is run through the full
            CLIP forward pass together with the text; the returned value is still the
            text embedding (use ``outputs.image_embeds`` for image embeddings if needed).
        device: Device the processor outputs are moved to.

    Returns:
        The text embedding tensor.
    """
    if image_path:
        with Image.open(image_path) as source_image:
            image = source_image.convert("RGB")
        inputs = clip_processor(text=[text_prompt], images=image, return_tensors="pt", padding=True).to(device)

        with torch.no_grad():
            outputs = clip_model(**inputs)
        return outputs.text_embeds

    # Text-only path: the joint forward pass requires pixel_values and raises without
    # them, so encode the text on its own.
    inputs = clip_processor(text=[text_prompt], return_tensors="pt", padding=True).to(device)

    with torch.no_grad():
        text_features = clip_model.get_text_features(
            input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
        )

    # Normalise so both branches return embeddings on the same scale as `text_embeds`
    # from the full forward pass.
    return text_features / text_features.norm(p=2, dim=-1, keepdim=True)


def project_clip_embedding(clip_embedding, target_dim=768, device="cuda"):
    """Map an embedding to ``target_dim`` features with a freshly created linear layer.

    NOTE(review): the linear layer is created with random weights on every call and is
    never trained, so two calls with the same embedding give different results —
    confirm whether a shared, trained projection should be used instead.

    Args:
        clip_embedding: Floating-point tensor whose last dimension is the input feature size.
        target_dim: Output feature size.
        device: Device on which the projection is created and applied.

    Returns:
        Tensor with the last dimension equal to ``target_dim``, on ``device``, in the
        same dtype as ``clip_embedding``.
    """
    # Create the layer in the embedding's dtype and move the embedding to the layer's
    # device; previously a non-float32 or off-device embedding failed inside the matmul.
    projection = torch.nn.Linear(clip_embedding.shape[-1], target_dim).to(
        device=device, dtype=clip_embedding.dtype
    )
    return projection(clip_embedding.to(device))


def _pil_to_lpips_tensor(image, device):
    """Convert an RGB PIL image to a (1, 3, H, W) float32 tensor scaled to [-1, 1]."""
    pixels = torch.tensor(np.array(image)).permute(2, 0, 1).unsqueeze(0)
    return pixels.to(device, dtype=torch.float32) / 127.5 - 1.0


def process_with_clip_conditioning(initial_prompt, image_path, clip_model, clip_processor, stable_diffusion,  device="cuda", num_iterations=5
):
    """Iteratively refine a sketch with img2img and score every iteration.

    Each iteration appends one fixed edit instruction to the running prompt, generates
    an image from the previous iteration's output (the original image on the first
    pass), displays and saves it as ``output_image_iteration_{i}.png`` in the working
    directory, and measures SSIM / PSNR / LPIPS against the original input image plus
    a CLIP score against the running prompt.

    Args:
        initial_prompt: Description of the subject; embedded in a fixed sketch-style prompt.
        image_path: Path of the starting image.
        clip_model: CLIP model used for conditioning and scoring.
        clip_processor: Processor matching ``clip_model``.
        stable_diffusion: img2img pipeline.
        device: Device for the LPIPS model and its inputs.
        num_iterations: Number of refinement passes.

    Returns:
        Tuple ``(ssim_scores, psnr_scores, lpips_scores, clip_scores, generated_images)``,
        each a list with one entry per iteration.
    """
    enhanced_prompt = (
        "highly detailed police sketch, black and white pencil drawing, "
        "professional forensic artist style, detailed shading, "
        f"portrait of {initial_prompt}"
    )
    lpips_model = LPIPS(net="alex").to(device)
    with Image.open(image_path) as source_image:
        input_image = source_image.convert("RGB")
    ssim_scores, psnr_scores, lpips_scores, clip_scores, generated_images = [], [], [], [], []

    # Edit instruction per iteration; the last one repeats for any further iterations.
    additions = (
        ", make the hair curlier",
        ", make eyes larger",
        ", make the eyebrows thinner",
        ", make the jawline more square",
    )

    for i in range(num_iterations):
        # Update the running prompt for this iteration
        addition = additions[min(i, len(additions) - 1)]
        enhanced_prompt = truncate_prompt_with_addition(enhanced_prompt, addition, max_length=77)

        # The first pass is conditioned on the full prompt with a stronger edit; later
        # passes are conditioned on the new instruction only and edit more gently.
        if i == 0:
            generation_prompt, strength = enhanced_prompt, 0.4
        else:
            generation_prompt, strength = addition, 0.3

        generated_image = generate_image_with_sketch_and_embeddings(
            image_path, generation_prompt, clip_model, clip_processor, stable_diffusion,
            strength=strength, guidance_scale=7.5
        )
        print(f"\nIteration {i + 1}: Generating image with updated prompt...{generation_prompt}")

        generated_images.append(generated_image)

        display(generated_image)
        # The saved output becomes the input image of the next iteration.
        output_path = f"output_image_iteration_{i}.png"
        generated_image.save(output_path)
        image_path = output_path

        # All similarity metrics compare against the ORIGINAL input, resized to match.
        resized_input_image = input_image.resize(generated_image.size)

        ssim = calculate_ssim(resized_input_image, generated_image)
        psnr = calculate_psnr(resized_input_image, generated_image)
        # LPIPS expects inputs in [-1, 1]; raw 0-255 pixel values give meaningless distances.
        with torch.no_grad():
            lpips_value = lpips_model(
                _pil_to_lpips_tensor(resized_input_image, device),
                _pil_to_lpips_tensor(generated_image, device),
            ).item()
        clip_score = calculate_clip_score(clip_model, clip_processor, generated_image, enhanced_prompt)

        clip_scores.append(clip_score)
        ssim_scores.append(ssim)
        psnr_scores.append(psnr)
        lpips_scores.append(lpips_value)

        print(f"SSIM: {ssim}, PSNR: {psnr}, LPIPS: {lpips_value}, CLIP: {clip_score}")

    return ssim_scores, psnr_scores, lpips_scores, clip_scores, generated_images

### Generate Images

In [None]:
# Inputs for this run: subject description, starting sketch, and fine-tuned CLIP checkpoint.
text_prompt = "The person is described as male around 30-35 years old with an oval face, defined cheekbones, medium-length straight hair, and light skin tone."
image_path = "/content/00002.jpg"

checkpoint_path = "/content/drive/My Drive/path_to_your_file/clip_checkpoint_epoch_20.pt"
clip_model, clip_processor = setup_clip_model(checkpoint_path, device)
# Place the pipeline on the same device as the CLIP model instead of a hard-coded "cuda".
stable_diffusion = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to(device)

# Generate images with CLIP embedding conditioning
ssim_scores, psnr_scores, lpips_scores, clip_scores, generated_images = process_with_clip_conditioning(
    text_prompt, image_path, clip_model, clip_processor, stable_diffusion
)

# Print metrics (CLIP scores are plotted further down, so report them here as well)
print("\nFinal Metrics Across Iterations:")
print(f"SSIM Scores: {ssim_scores}")
print(f"PSNR Scores: {psnr_scores}")
print(f"LPIPS Scores: {lpips_scores}")
print(f"CLIP Scores: {clip_scores}")

### View Images

In [None]:
# Show every generated image side by side in a single row, titled by iteration number.
image_count = len(generated_images)
plt.figure(figsize=(20, 5))
for position, picture in enumerate(generated_images, start=1):
    plt.subplot(1, image_count, position)
    plt.imshow(picture)
    plt.title(f"Iteration {position}")
    plt.axis("off")
plt.tight_layout()
plt.show()

### Plot Metrics Over Iterations

In [None]:
iterations = list(range(1, len(ssim_scores) + 1))
plt.figure(figsize=(15, 5))

# One panel per metric, left to right: (values, label, line colour).
metric_panels = [
    (ssim_scores, "SSIM", "blue"),
    (psnr_scores, "PSNR", "green"),
    (clip_scores, "CLIP Score", "purple"),
    (lpips_scores, "LPIPS", "orange"),
]

for panel_index, (scores, metric_name, line_color) in enumerate(metric_panels, start=1):
    plt.subplot(1, 4, panel_index)
    plt.plot(iterations, scores, label=metric_name, marker="o", color=line_color)
    plt.xlabel("Iteration")
    plt.ylabel(metric_name)
    plt.title(f"{metric_name} Over Iterations")
    plt.grid(True)

plt.tight_layout()
plt.show()
