# Stable diffusion pipeline

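Here is a basic example of the [Stable Diffusion XL pipeline from Hugging Face](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0): the SDXL base model generates latents, and the SDXL refiner then finishes them into images.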


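When you are done, close the notebook and stop the kernel: this notebook keeps models loaded on multiple GPUs (`cuda:1` and `cuda:2`) for image generation, and that GPU memory stays allocated until the kernel is stopped.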

In [None]:
# pip install diffusers invisible_watermark transformers accelerate safetensors

In [None]:
from diffusers import DiffusionPipeline
import torch

In [None]:
# GPUs for the two stages: the base model runs on DEVICE1 and the refiner on
# DEVICE2. Using index 2 assumes at least three visible CUDA devices.
DEVICE1, DEVICE2 = "cuda:1", "cuda:2"

# How many images each pass (base and refiner) produces for the prompt.
num_images_per_prompt = 2

In [None]:
from diffusers import DiffusionPipeline
import torch

# Load the SDXL base model in half precision from its fp16 safetensors weights.
base = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    use_safetensors=True,
    variant="fp16",
    torch_dtype=torch.float16,
)


# Denoising schedule shared by both experts: `n_steps` steps in total. The
# first `high_noise_frac` of them (80%) run on the base model and the
# remaining 20% run on the refiner.
n_steps = 40
high_noise_frac = 0.8

In [None]:
# Candidate prompts; the first entry is the one used below. Change the index
# to try one of the others.
_example_prompts = [
    "An astronaut riding a green horse",
    "Photograph of an astronaut riding an orange unicorn",
    "A majestic lion jumping from a big stone at night",
    "Three ML engineers discussing about stable diffusion",
    "A table lamp in the shape of an aeroplane with dim blue light",
]
prompt = _example_prompts[0]

In [None]:
torch.cuda.empty_cache()

# Stage 1, base model only: run the first `high_noise_frac` of the `n_steps`
# schedule on DEVICE1. Return latents (output_type="latent") instead of
# decoded images so the refiner can continue from the same point.
# Despite its name, `base_image` therefore holds latents, not pictures.
base = base.to(DEVICE1)
base_image = base(
    prompt=prompt,
    height=1024,
    width=1024,
    num_images_per_prompt=num_images_per_prompt,
    num_inference_steps=n_steps,
    denoising_end=high_noise_frac,
    output_type="latent",
).images

In [None]:
# Load the SDXL refiner. Instead of loading its own copies, it is given the
# base pipeline's `text_encoder_2` and `vae` objects, so both pipelines share
# those modules.
refiner = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    vae=base.vae,
    text_encoder_2=base.text_encoder_2,
    variant="fp16",
    use_safetensors=True,
    torch_dtype=torch.float16,
)

In [None]:
# Stage 2, refiner: finish the remaining part of the schedule (from
# `high_noise_frac` onward) on DEVICE2, starting from the base latents.
# NOTE: `refiner` shares the same `text_encoder_2` and `vae` objects as
# `base`, so this .to(DEVICE2) also moves those modules off DEVICE1. The base
# run cell calls `base.to(DEVICE1)` before generating, which moves them back.
refiner = refiner.to(DEVICE2)
image = refiner(
    prompt=prompt,
    image=base_image.to(DEVICE2),
    num_images_per_prompt=num_images_per_prompt,
    num_inference_steps=n_steps,
    denoising_start=high_noise_frac,
).images

In [None]:
# Release cached-but-unused GPU memory held by PyTorch's caching allocator.
# Tensors that are still referenced (the pipelines' weights, `base_image`,
# etc.) are not freed by this call.
torch.cuda.empty_cache()

In [None]:
# Show the first refined image as a 512x512 preview. `resize` returns a new
# image, so `image[0]` itself keeps its original size.
image[0].resize(size=(512, 512))

In [None]:
# Show the second refined image as a 512x512 preview. `resize` returns a new
# image, so `image[1]` itself keeps its original size.
image[1].resize(size=(512, 512))