In [None]:
from transformers import CLIPProcessor, CLIPModel
from diffusers import StableDiffusionImg2ImgPipeline
from PIL import Image
import torch
import torch.nn as nn

### Functions for Pipeline

In [None]:
class LoRALinear(nn.Module):
    """Wrap an ``nn.Linear`` with an additive low-rank (LoRA) branch.

    The forward pass returns ``original_linear(x)`` plus
    ``dropout(x) @ lora_A @ lora_B`` multiplied by ``lora_alpha / r``.

    Args:
        original_linear: Linear layer to wrap; kept as the ``original_linear``
            submodule and always evaluated.
        r: Rank of the low-rank branch. ``r == 0`` disables the branch, making
            the module equivalent to ``original_linear``.
        lora_alpha: Numerator of the scaling factor ``lora_alpha / r``.
        lora_dropout: Dropout probability applied to the input of the low-rank
            branch only (the wrapped layer sees the undropped input).

    NOTE(review): both ``lora_A`` and ``lora_B`` are randomly initialised, so a
    freshly constructed adapter already changes the layer's output until
    trained/loaded weights replace them. The usual LoRA recipe zero-initialises
    one of the two matrices — confirm this is intended.
    """

    def __init__(self, original_linear, r, lora_alpha, lora_dropout):
        super().__init__()
        self.original_linear = original_linear
        self.r = r
        self.lora_alpha = lora_alpha
        self.lora_dropout = nn.Dropout(lora_dropout)
        # Attribute names double as state_dict keys; keep them stable so
        # existing checkpoints continue to load.
        self.lora_A = nn.Parameter(torch.zeros(original_linear.in_features, r))
        self.lora_B = nn.Parameter(torch.zeros(r, original_linear.out_features))
        # forward() treats r == 0 as "adapter disabled", so the constructor must
        # not divide by zero for that case.
        self.scaling = lora_alpha / r if r > 0 else 0.0

        # Nothing to initialise when the matrices are empty (r == 0).
        if r > 0:
            nn.init.normal_(self.lora_A, std=0.02)
            nn.init.normal_(self.lora_B, std=0.02)

    def forward(self, x):
        """Return the wrapped layer's output plus the scaled low-rank update."""
        result = self.original_linear(x)
        if self.r <= 0:
            return result
        lora_result = (self.lora_dropout(x) @ self.lora_A @ self.lora_B) * self.scaling
        return result + lora_result


def apply_lora_to_model(model, r=16, lora_alpha=32, lora_dropout=0.1):
    """Swap attention projection layers of ``model`` for ``LoRALinear`` wrappers, in place.

    Every ``nn.Linear`` whose dotted module name contains ``q_proj``, ``k_proj``
    or ``v_proj`` is replaced on its parent module by
    ``LoRALinear(module, r, lora_alpha, lora_dropout)``.

    Calling the function again on an already-adapted model is a no-op: the
    ``original_linear`` held inside an existing ``LoRALinear`` is never wrapped
    a second time, so parameter names stay stable.

    Args:
        model: Root module whose submodules are scanned and modified.
        r: LoRA rank passed to each wrapper.
        lora_alpha: LoRA scaling numerator passed to each wrapper.
        lora_dropout: Dropout probability passed to each wrapper.

    Returns:
        None. ``model`` is modified in place.
    """
    target_substrings = ('q_proj', 'v_proj', 'k_proj')

    # Snapshot the traversal first: the loop rewrites the module tree, and
    # mutating it while the named_modules() generator is live is fragile.
    for name, module in list(model.named_modules()):
        if not isinstance(module, nn.Linear):
            continue
        if not any(substr in name for substr in target_substrings):
            continue

        parent_path, _, layer_name = name.rpartition('.')
        parent = model
        if parent_path:
            for name_part in parent_path.split('.'):
                parent = getattr(parent, name_part)

        # "<...>.q_proj.original_linear" also matches the substring test; it is
        # the layer already held by a wrapper, so leave it alone.
        if isinstance(parent, LoRALinear):
            continue

        setattr(parent, layer_name, LoRALinear(module, r, lora_alpha, lora_dropout))


def setup_clip_model(checkpoint_path, device, model_name="openai/clip-vit-base-patch32"):
    """Load a pretrained CLIP model, add LoRA wrappers and restore fine-tuned weights.

    Args:
        checkpoint_path: Path to a ``torch.save`` checkpoint containing a
            ``'model_state_dict'`` entry.
        device: Device the returned model is moved to.
        model_name: Hugging Face model id used for both the model and the
            processor. Defaults to the id this notebook was written for.

    Returns:
        ``(model, processor)`` where ``model`` is on ``device`` in eval mode.

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
        RuntimeError: From ``load_state_dict(strict=True)`` if the checkpoint
            keys do not exactly match the LoRA-wrapped model.
    """
    model = CLIPModel.from_pretrained(model_name)
    processor = CLIPProcessor.from_pretrained(model_name)

    # Wrap with the default LoRA hyper-parameters *before* loading: strict=True
    # below requires the module structure (and thus parameter names) to match
    # the checkpoint exactly.
    apply_lora_to_model(model)

    # SECURITY: torch.load unpickles the file; only load checkpoints you trust.
    # Tensors are mapped to CPU because the model is moved to `device` only
    # after its weights are restored; mapping straight to `device` would hold a
    # redundant copy of every tensor there.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint['model_state_dict'], strict=True)

    model = model.to(device).eval()
    return model, processor

# Linear map from 512-d combined CLIP embeddings to the 768-d prompt embeddings
# passed to the Stable Diffusion pipeline.
# NOTE(review): nothing visible in this notebook trains this layer or loads
# weights into it, so it runs with its random initialisation — confirm whether
# trained weights are supposed to be restored here.
# Falls back to CPU so that defining the layer does not crash on hosts without CUDA.
projection_layer = nn.Linear(512, 768).to("cuda" if torch.cuda.is_available() else "cpu")

def generate_image_with_sketch_and_embeddings(
    input_image_path, prompt, clip_model, clip_processor, stable_diffusion, device="cuda", strength=0.3, guidance_scale=7.5
):
    """Run img2img on a sketch, conditioned on a fused CLIP image+text embedding.

    The sketch and the prompt are encoded with CLIP, L2-normalised, averaged,
    re-normalised, projected with the module-level ``projection_layer`` and fed
    to the pipeline as a length-1 ``prompt_embeds`` sequence. An all-zero tensor
    of the same shape is used as ``negative_prompt_embeds``.

    Args:
        input_image_path: Path of the input sketch; it is converted to RGB.
        prompt: Text description encoded by CLIP.
        clip_model: Model exposing ``get_image_features`` / ``get_text_features``.
        clip_processor: Processor producing ``pixel_values``, ``input_ids`` and
            ``attention_mask`` for ``clip_model``.
        stable_diffusion: Img2img pipeline accepting ``prompt_embeds``,
            ``negative_prompt_embeds``, ``image``, ``strength`` and
            ``guidance_scale``.
        device: Device the CLIP inputs and the resulting embeddings live on.
        strength: Forwarded to the pipeline (deviation from the input image).
        guidance_scale: Forwarded to the pipeline (adherence to the prompt).

    Returns:
        The first image of the pipeline's ``"images"`` output.
    """
    # Image.open is lazy and keeps the file handle open; convert() returns a
    # fully loaded copy, so the handle can be released right away.
    with Image.open(input_image_path) as source_image:
        input_image = source_image.convert("RGB")

    # truncation=True keeps over-long prompts within the tokenizer's maximum
    # length instead of failing inside the text encoder.
    inputs = clip_processor(
        text=[prompt], images=input_image, return_tensors="pt", padding=True, truncation=True
    ).to(device)

    # Pure inference: keep the whole embedding path (including the projection)
    # out of autograd.
    with torch.no_grad():
        image_features = clip_model.get_image_features(pixel_values=inputs["pixel_values"])
        text_features = clip_model.get_text_features(
            input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
        )

        # Normalize and combine embeddings
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        combined_embeddings = (image_features + text_features) / 2.0
        combined_embeddings = combined_embeddings / combined_embeddings.norm(dim=-1, keepdim=True)

        # The projection layer is a module-level global that may sit on a
        # different device/dtype than `device`; align the input with it and
        # bring the result back to `device`.
        projection_weight = projection_layer.weight
        projected_embeddings = projection_layer(
            combined_embeddings.to(device=projection_weight.device, dtype=projection_weight.dtype)
        ).to(device)

        # Add a seq_len dimension: (batch_size, embed_dim) -> (batch_size, 1, embed_dim)
        projected_embeddings = projected_embeddings.unsqueeze(1)

        # All-zero negative conditioning with the same shape as the positive one
        negative_prompt_embeds = torch.zeros_like(projected_embeddings)

    generated_images = stable_diffusion(
        prompt_embeds=projected_embeddings,
        negative_prompt_embeds=negative_prompt_embeds,
        image=input_image,  # PIL.Image.Image in RGB mode
        strength=strength,
        guidance_scale=guidance_scale,
    )["images"]

    return generated_images[0]

### Generate and Display Image

In [None]:
device = "cuda"
checkpoint_path = "/content/drive/My Drive/clip_checkpoint_epoch_20.pt"
clip_model, clip_processor = setup_clip_model(checkpoint_path, device)

# Move the pipeline to the same `device` as the CLIP model instead of a
# separate hard-coded literal, so changing `device` above affects both.
stable_diffusion = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to(device)

# Example usage
text_prompt = "The person is described as male around 30-35 years old with an oval face, defined cheekbones, medium-length straight hair, and light skin tone."
image_path = "/content/00002.jpg"
generated_image = generate_image_with_sketch_and_embeddings(
    image_path, text_prompt, clip_model, clip_processor, stable_diffusion, device
)

# Single source of truth for the output file name (used for saving and in the message).
output_path = "output_with_sketch_and_embeddings.png"
generated_image.save(output_path)
display(generated_image)
print(f"Image saved as '{output_path}'")