In [9]:
import numpy as np
from scipy import stats
from lol import lol_iid

# The object shape of the data and latents in Stable Diffusion 3.
# Note that the "data" in this model are not images, but VAE latents.
object_shape = [16, 96, 96]
# The latent distribution (of all elements) in Stable Diffusion 3.
latent_distribution = stats.norm(loc=0, scale=1)

num_latent_dims = int(np.prod(object_shape))
num_interpolation_points = 10
# Fixed RNG seed so the two endpoint latents -- and therefore every image
# generated from them -- are reproducible on Restart & Run All.
random_seed = 0
rng = np.random.default_rng(random_seed)

# Interpolation
# Two i.i.d. endpoint latents, one per row: shape (2, num_latent_dims).
seeds = latent_distribution.rvs(size=[2, num_latent_dims], random_state=rng)
interpolation_weight = np.linspace(0, 1, num=num_interpolation_points)
# Row k holds the mixing weights [1 - t_k, t_k] for the two endpoints,
# giving shape (num_interpolation_points, 2).
weights = np.stack([1 - interpolation_weight, interpolation_weight], axis=1)
latents = lol_iid(
  w=weights,
  X=seeds,
  cdf=latent_distribution.cdf,
  inverse_cdf=latent_distribution.ppf
)
# The generation cell reshapes each row to (1, *object_shape), so we expect
# one flat latent per interpolation point. NOTE(review): this encodes the
# assumed lol_iid output contract -- confirm against its documentation.
assert np.shape(latents) == (num_interpolation_points, num_latent_dims)

In [10]:
from diffusers import StableDiffusion3Pipeline, DDIMInverseScheduler, AutoencoderKL, DDIMScheduler
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_id = "stabilityai/stable-diffusion-3-medium-diffusers"
# Half precision only on GPU: many CPU kernels are slow or unimplemented for
# float16, so fall back to float32 when CUDA is unavailable.
dtype = torch.float16 if device.type == "cuda" else torch.float32

pipeline = StableDiffusion3Pipeline.from_pretrained(model_id, torch_dtype=dtype)
pipeline = pipeline.to(device)
num_inference_steps = 50
guidance_scale = 7.0
prompt = "A duck in a pond with a reflection, photographed with a high-resolution DSLR camera."
# Derive the pixel resolution from the latent spatial shape instead of
# hard-coding it, so height/width can never drift out of sync with
# `object_shape` (96 * 8 = 768 for SD3 medium).
height, width = (side * pipeline.vae_scale_factor for side in object_shape[1:])

image_per_latent = []
for i, latent in enumerate(latents):
  # Each flat latent becomes a batch of one: (1, channels, h, w).
  latent_object = np.reshape(latent, (1, *object_shape))
  with torch.no_grad():
    image = pipeline(
      prompt=prompt,
      latents=torch.tensor(latent_object, device=device, dtype=dtype),
      height=height,
      width=width,
      guidance_scale=guidance_scale,
      num_inference_steps=num_inference_steps,
      num_images_per_prompt=1,
    ).images[0]
  if device.type == "cuda":
    # Hand unused cached GPU memory back between runs.
    torch.cuda.empty_cache()
  image.save(f"/tmp/interpolation_image_{i}.png")
  image_per_latent.append(image)

Loading pipeline components...:  11%|█         | 1/9 [00:00<00:01,  4.22it/s]
Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A
Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  3.86it/s][A
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  3.89it/s][A
Loading pipeline components...: 100%|██████████| 9/9 [00:01<00:00,  5.65it/s]
100%|██████████| 50/50 [00:09<00:00,  5.53it/s]
100%|██████████| 50/50 [00:09<00:00,  5.41it/s]
100%|██████████| 50/50 [00:08<00:00,  5.84it/s]
100%|██████████| 50/50 [00:08<00:00,  6.16it/s]
100%|██████████| 50/50 [00:08<00:00,  6.07it/s]
100%|██████████| 50/50 [00:08<00:00,  5.91it/s]
100%|██████████| 50/50 [00:08<00:00,  5.93it/s]
100%|██████████| 50/50 [00:08<00:00,  6.13it/s]
100%|██████████| 50/50 [00:08<00:00,  6.00it/s]
100%|██████████| 50/50 [00:08<00:00,  6.11it/s]


In [11]:
from utilities import plot_image_grid

# Convert every generated image to a NumPy array and pass the sequence to the
# project's grid plotter, which writes the figure to the given path.
plot_image_grid(
  filepath="/tmp/interpolation.png",
  images=list(map(np.array, image_per_latent)),
)