In [None]:
# !pip install diffusers transformers accelerate safetensors opencv-python
# !pip install --upgrade diffusers transformers accelerate

In [None]:
import torch
from diffusers import StableDiffusionXLControlNetPipeline,ControlNetModel,AutoencoderKL
from diffusers.utils import load_image
from PIL import Image
import cv2
import numpy as np

## ControlNet + Stable Diffusion XL + refiner

In [None]:
# Load your ControlNet condition image (e.g., edge map or sketch)
def load_canny_image(path, low_threshold=100, high_threshold=200, size=(768, 768)):
    """Build a Canny edge map to use as the ControlNet conditioning image.

    Args:
        path: Path of the source image on disk.
        low_threshold: Lower hysteresis threshold passed to ``cv2.Canny``.
        high_threshold: Upper hysteresis threshold passed to ``cv2.Canny``.
        size: ``(width, height)`` the edge map is resized to.

    Returns:
        A PIL image in RGB mode holding the edge map (edges as white lines
        on black), resized to ``size``.

    Raises:
        FileNotFoundError: If OpenCV could not read an image from ``path``.
    """
    image = cv2.imread(path)
    # cv2.imread reports a missing/unreadable file by returning None rather
    # than raising; fail here with a clear message instead of deep inside Canny.
    if image is None:
        raise FileNotFoundError(f"Could not read an image from {path!r}")

    edges = cv2.Canny(image, low_threshold, high_threshold)
    # The edge map is single-channel; expand it to RGB before resizing.
    return Image.fromarray(edges).convert("RGB").resize(size)

# Step 1: Load ControlNet model (e.g., canny edges)
controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
)

# Step 2: Load SDXL base pipeline with ControlNet
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)

# from_pretrained leaves the weights on the CPU. The pipeline is loaded in
# float16, which is intended for GPU inference, so move it to the GPU when one
# is available (otherwise it stays on the CPU, as before).
if torch.cuda.is_available():
    pipe.to("cuda")


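# # Step 3: Load SDXL Refiner (currently disabled)
# # Note: the refiner checkpoint is loaded with StableDiffusionXLImg2ImgPipeline,
# # which must be added to the `from diffusers import ...` line before enabling this.
# refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
#     "stabilityai/stable-diffusion-xl-refiner-1.0",
#     torch_dtype=torch.float16,
#     variant="fp16",
#     use_safetensors=True,
# )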


# Step 4: Turn the input sketch into the ControlNet conditioning image
condition_source_path = "final_asteria.png"
control_image = load_canny_image(condition_source_path)

# Step 5: Text conditioning
# The positive prompt is assembled from three fragments (subject, structural
# layers, rendering style); each fragment carries its own trailing space.
_prompt_fragments = [
    "An isometric exploded diagram of a lotus leaf, with photorealistic top surface, ",
    "wax nanorods below, stippled microstructure, and hemispherical domes in a grid. ",
    "Neutral grey background, no text, minimal line art, scientific style.",
]
prompt = "".join(_prompt_fragments)

# Traits the sampler should steer away from.
negative_prompt = "blurry, distorted, shadows, artistic style, text, background scenery"

# Step 6: Generate the image with ControlNet + SDXL
# A fixed seed makes the result reproducible across kernel restarts.
generator = torch.Generator(device="cpu").manual_seed(0)
base_image = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    image=control_image,
    num_inference_steps=20,
    guidance_scale=8.0,
    controlnet_conditioning_scale=1.0,
    output_type="pil",
    generator=generator,
    # No denoising_end here: the refiner stage is commented out, so the base
    # model must run the full schedule. Stopping at 0.8 would return an image
    # that still contains residual noise. Re-add denoising_end=0.8 only when
    # the refiner step (denoising_start=0.8) is re-enabled.
).images[0]

# # Step 7: Refine the image using SDXL Refiner
# refined_image = refiner(
#     prompt=prompt,
#     image=base_image,
#     num_inference_steps=20,
#     guidance_scale=5.0,
#     denoising_start=0.8,
# ).images[0]

# Step 8: Write the generated image to disk
output_path = "refined_lotus_leaf_diagram.png"
base_image.save(output_path)
