In [7]:
import numpy as np
from scipy import stats
from lol import lol_iid

# Shape of the objects the diffusion process operates on in Stable Diffusion 3.
# Note that the "data" in this model are not images, but VAE latents.
object_shape = [16, 96, 96]
# Distribution of every element of the initial noise latents in Stable Diffusion 3.
latent_distribution = stats.norm(loc=0, scale=1)

num_latent_dims = int(np.prod(object_shape))

# Fix the RNG so the seed latents -- and therefore every generated image -- are
# reproducible when the notebook is re-run from a fresh kernel.
SEED = 42
rng = np.random.default_rng(SEED)

# We will create a 6D subspace spanned by 6 seed latents (one per row of X).
num_seed_latents = 6
seed_latents = latent_distribution.rvs(
  size=[num_seed_latents, num_latent_dims], random_state=rng
)

X = seed_latents
# Reduced QR of X^T: the columns of U are an orthonormal basis of the span of the
# seed latents, shape (num_latent_dims, num_seed_latents).
U, _ = np.linalg.qr(X.transpose())
X_pseudo_inverse = np.linalg.pinv(X)

# Subspace coordinates h = x @ U of each seed latent; their per-dimension
# min/max bound the region swept by the grid below.
h_lower_limits = np.min(X @ U, axis=0)
h_upper_limits = np.max(X @ U, axis=0)

# Sweep a 2D slice of the subspace: coordinates along dimensions a and b vary on
# a regular grid, all other coordinates are held fixed.
grid_dimension_a = 2
grid_dimension_b = 4
num_grid_points_per_dim = 9

# np.meshgrid uses the default "xy" indexing, so after the reshape each row is
# (value along a, value along b) and the dimension-a index varies fastest.
h_grid_in_dimension_a_and_b = np.stack(
  np.meshgrid(
    np.linspace(h_lower_limits[grid_dimension_a], h_upper_limits[grid_dimension_a], num=num_grid_points_per_dim),
    np.linspace(h_lower_limits[grid_dimension_b], h_upper_limits[grid_dimension_b], num=num_grid_points_per_dim),
  ),
  axis=-1
).reshape(-1, 2)
num_grid_points = len(h_grid_in_dimension_a_and_b)

# The fixed (non-swept) coordinates are taken from the first seed latent.
default_latent_for_other_dimensions = seed_latents[0, :]
default_h = default_latent_for_other_dimensions @ U
default_h_grid = np.repeat(default_h[None, :], num_grid_points, axis=0)
h_grid = default_h_grid.copy()
h_grid[:, [grid_dimension_a, grid_dimension_b]] = h_grid_in_dimension_a_and_b

# Convert subspace coordinates into weights over the seed latents: the target
# latent is U @ h, and w = pinv(X)^T @ U @ h satisfies w @ X == (U @ h)^T for
# full-rank X (almost surely the case for Gaussian seeds). One row per grid point.
w_grid = (X_pseudo_inverse.T @ U @ h_grid.T).T
# NOTE(review): lol_iid is defined outside this notebook; presumably it returns one
# latent per row of w_grid (the next cell reshapes each to object_shape) -- confirm.
latents = lol_iid(
  w=w_grid,
  X=seed_latents,
  cdf=latent_distribution.cdf,
  inverse_cdf=latent_distribution.ppf
)

In [8]:
from diffusers import StableDiffusion3Pipeline, DDIMInverseScheduler, AutoencoderKL, DDIMScheduler
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_id = "stabilityai/stable-diffusion-3-medium-diffusers"
# Half precision is only reliably supported (and fast) on GPU; many CPU kernels
# have no float16 implementation, so fall back to float32 on CPU.
dtype = torch.float16 if device.type == "cuda" else torch.float32

pipeline = StableDiffusion3Pipeline.from_pretrained(model_id, torch_dtype=dtype)
pipeline = pipeline.to(device)
num_inference_steps = 20
guidance_scale = 7.0
# Derive the pixel resolution from the latent shape so the two cannot drift apart:
# the VAE maps each latent spatial side to side * vae_scale_factor pixels
# (768 px for the 96x96 latents used here).
height, width = (side * pipeline.vae_scale_factor for side in object_shape[1:])

prompt = "A duck in a pond with a reflection, photographed with a high-resolution DSLR camera."
image_output_dir = "/tmp"

# Generate one image per grid latent; the same prompt is used for every image so
# only the starting latent varies across the grid.
image_per_latent = []
for i, latent in enumerate(latents):
  latent_object = np.reshape(latent, (1, *object_shape))
  with torch.no_grad():
    image = pipeline(
      prompt=prompt,
      latents=torch.tensor(latent_object, device=device, dtype=dtype),
      height=height,
      width=width,
      guidance_scale=guidance_scale,
      num_inference_steps=num_inference_steps,
      num_images_per_prompt=1,
    ).images[0]
  if device.type == "cuda":
    # Release cached, unused GPU memory between generations.
    torch.cuda.empty_cache()
  image.save(f"{image_output_dir}/subspace_image_{i}.png")
  image_per_latent.append(image)

Loading pipeline components...:  89%|████████▉ | 8/9 [00:01<00:00,  7.02it/s]
Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A
Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  3.65it/s][A
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  3.76it/s][A
Loading pipeline components...: 100%|██████████| 9/9 [00:01<00:00,  4.84it/s]
100%|██████████| 20/20 [00:03<00:00,  5.09it/s]
100%|██████████| 20/20 [00:03<00:00,  5.45it/s]
100%|██████████| 20/20 [00:03<00:00,  5.89it/s]
100%|██████████| 20/20 [00:03<00:00,  6.32it/s]
100%|██████████| 20/20 [00:03<00:00,  6.24it/s]
100%|██████████| 20/20 [00:03<00:00,  6.17it/s]
100%|██████████| 20/20 [00:03<00:00,  5.46it/s]
100%|██████████| 20/20 [00:03<00:00,  6.24it/s]
100%|██████████| 20/20 [00:03<00:00,  6.14it/s]
100%|██████████| 20/20 [00:03<00:00,  6.22it/s]
100%|██████████| 20/20 [00:03<00:00,  6.11it/s]
100%|██████████| 20/20 [00:03<00:00,  6.11it/s]
100%|██████████| 20/20 [00:03<00:00,  6.02it/s]


In [9]:
from utilities import plot_image_grid

# Integer grid index of every generated image as (index along dimension a,
# index along dimension b). The dimension-a index varies fastest, matching the
# row order of h_grid_in_dimension_a_and_b from the subspace cell.
grid_positions = np.array(
  [
    (index_a, index_b)
    for index_b in range(num_grid_points_per_dim)
    for index_a in range(num_grid_points_per_dim)
  ],
  dtype=int,
).reshape(-1, 2)

images_as_arrays = [np.array(image) for image in image_per_latent]
grid_figure_path = f"/tmp/subspace_slice_dim_{grid_dimension_a}_and_dim_{grid_dimension_b}.png"
plot_image_grid(
  images=images_as_arrays,
  filepath=grid_figure_path,
  grid_positions=grid_positions
)