SDE Drag pipeline

This pipeline provides drag-and-drop image editing using stochastic differential equations. It edits an image given a prompt, image, mask_image, source_points, and target_points. See the [paper](https://arxiv.org/abs/2311.01410), [paper page](https://ml-gsai.github.io/SDE-Drag-demo/), and [original repo](https://github.com/ML-GSAI/SDE-Drag) for more information. This script was contributed by [Fengqi Zhu](https://github.com/MarkRich) and [NieShen](https://github.com/NieShenRuc). The notebook was contributed by [Parag Ekbote](https://github.com/ParagEkbote).
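
Since source_points and target_points are passed as two parallel lists, it can help to sanity-check them before running the pipeline. The following is a minimal sketch, assuming each point is an [x, y] pixel coordinate on the resized image and that source and target points are paired by index. The helper name `validate_drag_points` is hypothetical and is not part of diffusers.

```python
def validate_drag_points(source_points, target_points, width, height):
    """Check that drag points are paired and lie inside the image.

    Assumption: each point is [x, y] in pixel coordinates.
    Returns a list of (source, target) tuples.
    """
    if len(source_points) != len(target_points):
        raise ValueError(
            f"Got {len(source_points)} source points "
            f"but {len(target_points)} target points"
        )
    for name, points in (("source", source_points), ("target", target_points)):
        for x, y in points:
            if not (0 <= x < width and 0 <= y < height):
                raise ValueError(
                    f"{name} point ({x}, {y}) is outside a {width}x{height} image"
                )
    # Pair each source point with the target it should be dragged to
    return list(zip(map(tuple, source_points), map(tuple, target_points)))


pairs = validate_drag_points([[32, 32]], [[64, 64]], width=512, height=512)
print(pairs)  # [((32, 32), (64, 64))]
```

Failing early on a mismatched or out-of-range point is cheaper than discovering it after the model has loaded and the editing loop has started.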

In [11]:
pip install diffusers torch pillow requests torchvision

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.3.1[0m[39;49m -> [0m[32;49m25.0[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [12]:
import torch
from diffusers import DDIMScheduler, DiffusionPipeline
from PIL import Image
import requests
from io import BytesIO
import numpy as np

# Load the pipeline
model_path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
scheduler = DDIMScheduler.from_pretrained(model_path, subfolder="scheduler")
pipe = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, custom_pipeline="sde_drag")

# Move the pipeline to the GPU if one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe.to(device)

# Function to load image from URL
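def load_image_from_url(url):
    # Use a timeout so a stalled download does not hang the notebook
    response = requests.get(url, timeout=30)
    # Fail with a clear HTTP error instead of an image decoding error
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")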

# Function to prepare mask
def prepare_mask(mask_image):
    # Convert to grayscale
    mask = mask_image.convert("L")
    return mask

# Function to convert numpy array to PIL Image
def array_to_pil(array):
    # Ensure the array is in uint8 format
    if array.dtype != np.uint8:
        if array.max() <= 1.0:
            # Floats in [0, 1] are scaled to [0, 255]
            array = (array * 255).astype(np.uint8)
        else:
            array = array.astype(np.uint8)

    # Drop the batch dimension if present, keeping the first image
    if array.ndim == 4:
        array = array[0]
    if array.ndim != 3:
        raise ValueError(f"Unexpected array shape: {array.shape}")

    # Convert channels-first (C, H, W) to channels-last (H, W, C)
    if array.shape[0] == 3:
        array = array.transpose(1, 2, 0)
    return Image.fromarray(array)

# Image and mask URLs
image_url = 'https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png'
mask_url = 'https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png'

# Load the images
image = load_image_from_url(image_url)
mask_image = load_image_from_url(mask_url)

# Resize to 512x512, the resolution Stable Diffusion v1.5 was trained at
# (width and height must be divisible by 8 to map onto the VAE latent grid)
image = image.resize((512, 512))
mask_image = mask_image.resize((512, 512))

# Prepare the mask (keep as PIL Image)
mask = prepare_mask(mask_image)

# Provide the prompt and points for drag editing
prompt = "A cute dog"
source_points = [[32, 32]]  # Adjusted for 512x512 image
target_points = [[64, 64]]  # Adjusted for 512x512 image

# Generate the output image
output_array = pipe(
    prompt=prompt,
    image=image,
    mask_image=mask,
    source_points=source_points,
    target_points=target_points
)

# Convert output array to PIL Image and save
output_image = array_to_pil(output_array)
output_image.save("./output.png")
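print("Output image saved as './output.png'")

# Inspect the raw pipeline output
print(f"Output type: {type(output_array)}")
print(f"Output shape: {output_array.shape}")
print(f"Output dtype: {output_array.dtype}")
print(f"Output min/max values: {output_array.min()}, {output_array.max()}")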



Output image saved as './output.png'
Output type: <class 'numpy.ndarray'>
Output shape: (512, 512, 3)
Output dtype: uint8
Output min/max values: 0, 255
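
The example resizes the image to 512x512 and uses points chosen for that size. If the points were picked on an image of a different size, they need the same scaling as the image. The following is a minimal sketch, assuming points are [x, y] pixel coordinates and sizes are (width, height) tuples as in PIL. The helper `scale_points` and the 1024x1024 original size are hypothetical and used only for illustration.

```python
def scale_points(points, original_size, new_size):
    """Rescale [x, y] points from original_size to new_size.

    Both sizes are (width, height) tuples, matching PIL's Image.size.
    """
    orig_w, orig_h = original_size
    new_w, new_h = new_size
    scaled = []
    for x, y in points:
        # Scale each axis independently, then clamp to the last valid pixel index
        sx = min(new_w - 1, round(x * new_w / orig_w))
        sy = min(new_h - 1, round(y * new_h / orig_h))
        scaled.append([sx, sy])
    return scaled


# A point picked at (64, 64) on a hypothetical 1024x1024 original
print(scale_points([[64, 64]], original_size=(1024, 1024), new_size=(512, 512)))
# [[32, 32]]
```

Applying the same function to both source_points and target_points keeps each drag pair consistent with the resized image.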
