# Notebook 1: Enhancing Images with ControlNet and Stable Diffusion

### 1. Install Required Libraries

In [None]:
# Run this cell to install necessary packages
!pip install --upgrade diffusers transformers torch torchvision ipywidgets
!pip install opencv-python matplotlib

#### 1.1 Optional: Install `accelerate` for faster, less memory-intensive model loading.

In [None]:
!pip install accelerate

### 2. Import Libraries

In [None]:
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from PIL import Image

### 3. Load Models

In [None]:
# ControlNet checkpoint conditioned on Canny edge maps, loaded in half precision.
CONTROLNET_CHECKPOINT = "lllyasviel/sd-controlnet-canny"
controlnet = ControlNetModel.from_pretrained(CONTROLNET_CHECKPOINT, torch_dtype=torch.float16)

In [None]:
# Load the Stable Diffusion v1.5 pipeline (half precision) and attach the
# Canny ControlNet loaded above.
# The former "runwayml/stable-diffusion-v1-5" repository is no longer available
# on the Hugging Face Hub; the same weights are hosted under the repo id below.
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    controlnet=controlnet,
    torch_dtype=torch.float16,
)

In [None]:
# Move the pipeline to the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    pipe = pipe.to(device)
else:
    # The weights were loaded as float16, which many CPU kernels do not
    # support; upcast to float32 so the notebook still runs (slowly) on CPU.
    pipe = pipe.to(device, torch.float32)
# Assigning (instead of leaving `pipe.to(...)` as the last expression) keeps
# the cell from printing the whole pipeline repr.

In [None]:
# Swap in the UniPC multistep scheduler, reusing the configuration of the
# scheduler the pipeline was loaded with.
current_scheduler_config = pipe.scheduler.config
pipe.scheduler = UniPCMultistepScheduler.from_config(current_scheduler_config)

### 4. Prepare Input Image and Control Image

In [None]:
# Load the photo, force 3-channel RGB, and resize it to 512x512.
INPUT_IMAGE_PATH = "test-images/20240929_102048-EDIT.jpg"
input_image = Image.open(INPUT_IMAGE_PATH).convert("RGB").resize((512, 512))

fig, preview_ax = plt.subplots()
preview_ax.imshow(input_image)
preview_ax.axis("off")  # no ticks or frame around a photo
plt.show()

In [None]:
# OpenCV operates on NumPy arrays: copy the RGB PIL image into a uint8 array.
image_np = np.array(input_image, dtype=np.uint8)

In [None]:
# Canny edge detection; the two values are the lower and upper thresholds
# passed to cv2.Canny (threshold1, threshold2).
low_threshold, high_threshold = 100, 200
edges = cv2.Canny(image_np, low_threshold, high_threshold)

In [None]:
# Build the ControlNet conditioning image from the edge map.
# The edge map has a single channel; stacking it into three identical channels
# matches the diffusers Canny ControlNet example and makes matplotlib render it
# in grayscale (a single-channel array would be drawn with a false-colour map).
edges_rgb = np.stack([edges, edges, edges], axis=-1)
control_image = Image.fromarray(edges_rgb)

plt.imshow(control_image)
plt.title("Canny control image")
plt.axis("off")  # Hide the axes
plt.show()

### 5. Define Parameters

In [None]:
# Sampling settings read by the generation cell.
guidance_scale = 7.5          # classifier-free guidance strength
num_inference_steps = 50      # number of denoising steps

# Text conditioning: an empty positive prompt plus a list of traits to avoid.
negative_prompt = "low quality, blurry, deformed, bad anatomy"
prompt = ""

### 6. Generate the Enhanced Image

In [None]:
# Generate the image from the Canny edge map.
# StableDiffusionControlNetPipeline receives its conditioning image through
# `image`, so the edge map goes there. This pipeline has no separate init-image
# input: the original photo influences the result only through its edges.
# The pipeline already runs in the dtype its weights were loaded with, so no
# torch.autocast context is used (diffusers advises against autocast).
generator = torch.Generator(device="cpu").manual_seed(0)  # fixed seed -> reproducible output

output = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=num_inference_steps,
    guidance_scale=guidance_scale,
    image=control_image,
    generator=generator,
)

enhanced_image = output.images[0]

### 7. Display and Save the Results

In [None]:
# Save the enhanced image. Create the output folder first so a fresh checkout
# does not fail with FileNotFoundError.
output_path = Path("output-images") / "enhanced_controlnet.jpg"
output_path.parent.mkdir(parents=True, exist_ok=True)
enhanced_image.save(output_path)

In [None]:
# Show the original photo and the generated result side by side.
panels = [
    ("Original Image", input_image),
    ("Enhanced Image", enhanced_image),
]

fig, ax = plt.subplots(1, 2, figsize=(12, 6))
for panel_ax, (panel_title, panel_image) in zip(ax, panels):
    panel_ax.imshow(panel_image)
    panel_ax.set_title(panel_title)
    panel_ax.axis("off")

plt.show()