In [None]:
import torch.nn as nn
import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Hugging Face Hub checkpoints used throughout the notebook.
# NOTE(review): the runwayml/stable-diffusion-v1-5 repo was reportedly removed from the Hub
# in 2024; if the download fails, the same weights are mirrored at
# "stable-diffusion-v1-5/stable-diffusion-v1-5" — confirm before switching.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
SD_CHECKPOINT = "runwayml/stable-diffusion-v1-5"

# CLIP encoder and its matching preprocessor (tokenizer + image transforms).
# No .to(...) call here, so the CLIP model stays on the CPU.
clip_processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)
clip_model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)

# Stable Diffusion img2img pipeline: weights loaded in float16, then moved to the GPU.
stable_diffusion = StableDiffusionImg2ImgPipeline.from_pretrained(
    SD_CHECKPOINT,
    torch_dtype=torch.float16,
)
stable_diffusion = stable_diffusion.to("cuda")

### Functions for Pipeline

In [None]:
# Linear map from CLIP's joint embedding width (projection_dim, 512 for ViT-B/32) to the
# width the SD UNet's cross-attention expects for `prompt_embeds` (768 for SD v1.5).
# Both sizes are read from the loaded models so a checkpoint swap cannot silently mismatch.
# NOTE(review): this layer is freshly initialised and no training or weight loading is
# visible here, so its output is not aligned with SD's text-encoder space — confirm intended.
PROJECTION_SEED = 0  # fixed so Restart & Run All reproduces the same projection weights

# fork_rng (CPU only, devices=[]) restores the global RNG afterwards, so seeding the
# layer's init does not change random draws made by later cells.
with torch.random.fork_rng(devices=[]):
    torch.manual_seed(PROJECTION_SEED)
    projection_layer = nn.Linear(
        clip_model.config.projection_dim,
        stable_diffusion.unet.config.cross_attention_dim,
    )
projection_layer = projection_layer.to("cuda")

def generate_image_with_sketch_and_embeddings(
    input_image_path, prompt, strength=0.3, guidance_scale=7.5, generator=None
):
    """Run SD img2img on a sketch, conditioned on a blend of CLIP image and text embeddings.

    The sketch and prompt are encoded with CLIP. Each embedding is L2-normalised and the
    two are averaged. The result is mapped through the module-level ``projection_layer``
    and passed to the pipeline as a length-1 ``prompt_embeds`` sequence, with an all-zero
    ``negative_prompt_embeds`` of the same shape.

    NOTE(review): ``projection_layer`` is a plain ``nn.Linear`` with no visible training,
    so the conditioning is not semantically aligned with SD's own text encoder.

    Args:
        input_image_path: Path to the input sketch/image; it is converted to RGB.
        prompt: Text description. Tokens beyond CLIP's context length are truncated.
        strength: img2img strength; lower values stay closer to the input image.
        guidance_scale: Classifier-free guidance scale forwarded to the pipeline.
        generator: Optional ``torch.Generator`` forwarded to the pipeline for
            reproducible sampling. ``None`` keeps the pipeline's default (unseeded) behaviour.

    Returns:
        The first image in the pipeline's ``"images"`` output.
    """
    # Load the sketch. convert() yields an in-memory RGB copy, so the file can be
    # closed as soon as the `with` block exits.
    with Image.open(input_image_path) as raw_image:
        input_image = raw_image.convert("RGB")

    # Preprocess image and prompt for CLIP. truncation=True keeps long prompts within
    # CLIP's token limit instead of failing inside the text encoder.
    inputs = clip_processor(
        text=[prompt],
        images=input_image,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )

    # Feed CLIP on whatever device it lives on. Move its outputs to the projection
    # layer's device, where they are consumed.
    clip_device = clip_model.device
    target_device = projection_layer.weight.device

    # Everything up to the pipeline call is inference-only. Keeping the projection
    # inside no_grad avoids building an autograd graph on every call.
    with torch.no_grad():
        image_features = clip_model.get_image_features(
            pixel_values=inputs["pixel_values"].to(clip_device)
        ).to(target_device)
        text_features = clip_model.get_text_features(
            input_ids=inputs["input_ids"].to(clip_device),
            attention_mask=inputs["attention_mask"].to(clip_device),
        ).to(target_device)

        # Unit-normalise both so the average weights image and text equally.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        combined_embeddings = (image_features + text_features) / 2.0

        # Project to SD's embedding width, then add a seq_len axis:
        # (batch, out_features) -> (batch, 1, out_features).
        projected_embeddings = projection_layer(combined_embeddings).unsqueeze(1)

    # All-zero embeddings serve as the negative/unconditional input for guidance.
    negative_prompt_embeds = torch.zeros_like(projected_embeddings)

    generated_images = stable_diffusion(
        prompt_embeds=projected_embeddings,
        negative_prompt_embeds=negative_prompt_embeds,
        image=input_image,  # PIL.Image.Image
        strength=strength,
        guidance_scale=guidance_scale,
        generator=generator,
    )["images"]

    return generated_images[0]


### Generate and Display Image

In [None]:
# Inputs for this run.
input_image_path = "../input_images/00002.jpg"
prompt = "The person is described as male around 30-35 years old with an oval face, defined cheekbones, medium-length straight hair, and light skin tone."
# Single source of truth for the output filename, used by both save() and the log message.
output_path = "output_with_sketch_and_embeddings.png"

# Generate the image
generated_image = generate_image_with_sketch_and_embeddings(
    input_image_path, prompt, strength=0.3, guidance_scale=7.5
)

# Save the generated image, then render it inline
generated_image.save(output_path)
display(generated_image)
print(f"Image saved as '{output_path}'")