In [1]:
# Show the GPU model, driver/CUDA version and current memory usage before loading the model.
!nvidia-smi

Fri Nov 25 08:28:47 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 526.98       Driver Version: 526.98       CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ... WDDM  | 00000000:01:00.0  On |                  N/A |
|  0%   48C    P8    44W / 350W |   2327MiB / 24576MiB |     45%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
# Hugging Face Hub id of the checkpoint loaded by the pipeline cell below.
model_id = "CompVis/stable-diffusion-v1-4"
# Torch device that the pipeline and all sampled latents are placed on.
device = "cuda"

In [3]:
import torch
from diffusers import StableDiffusionPipeline

# Load the pretrained pipeline for `model_id` (files are fetched from the Hub or the
# local cache) and move it onto `device`.
# `use_auth_token=True` asks diffusers to use the locally stored Hugging Face token
# (presumably needed because the checkpoint was gated at the time — confirm).
pipe = (
    StableDiffusionPipeline
    .from_pretrained(model_id, use_auth_token=True)
    .to(device)
)

Fetching 16 files:   0%|          | 0/16 [00:00<?, ?it/s]

In [4]:
from PIL import Image

def image_grid(imgs, rows, cols):
    """Tile equally sized images into a single ``rows`` x ``cols`` RGB mosaic.

    Images are placed in row-major order: the first ``cols`` images fill the
    top row, the next ``cols`` the second row, and so on. Every tile is
    positioned using the size of ``imgs[0]``, so all images are expected to
    share that size.

    Args:
        imgs: Sequence of PIL images; its length must equal ``rows * cols``.
        rows: Number of rows in the grid.
        cols: Number of columns in the grid.

    Returns:
        A new RGB ``PIL.Image`` of size ``(cols * w, rows * h)``, where
        ``(w, h)`` is ``imgs[0].size``.

    Raises:
        AssertionError: if ``len(imgs) != rows * cols``.
    """
    assert len(imgs) == rows * cols

    tile_w, tile_h = imgs[0].size
    canvas = Image.new('RGB', size=(cols * tile_w, rows * tile_h))

    for idx, tile in enumerate(imgs):
        row, col = divmod(idx, cols)
        canvas.paste(tile, box=(col * tile_w, row * tile_h))
    return canvas


In [5]:
def gen_latents(num_latents, seeds=None, height=512, width=512):
    """Draw per-image seeded initial latents so each image can be reproduced.

    Every latent is drawn from a generator freshly re-seeded with its own seed,
    so passing a returned seed back in regenerates exactly that latent. Seeds in
    ``seeds`` are used in order; once they run out, new random seeds are drawn
    from the generator and recorded.

    Args:
        num_latents: Number of latent tensors to generate.
        seeds: Optional sequence of integer seeds to reuse. It is NOT modified;
            a new list (the given seeds plus any newly drawn ones) is returned.
        height: Target image height in pixels; latents are ``height // 8`` tall.
        width: Target image width in pixels; latents are ``width // 8`` wide.

    Returns:
        ``[seeds, latents]``, where ``seeds`` is a new list of seeds and
        ``latents`` is a tensor on ``device`` with shape
        ``(num_latents, pipe.unet.config.in_channels, height // 8, width // 8)``,
        or ``None`` when ``num_latents == 0``.
    """
    # The generator must live on the same device the tensors are drawn on.
    generator = torch.Generator(device=device)
    # Copy so that appending freshly drawn seeds never mutates the caller's list.
    seeds = [] if seeds is None else list(seeds)
    # 8 is presumably the VAE's spatial downsampling factor for SD v1 — confirm for other models.
    shape = (1, pipe.unet.config.in_channels, height // 8, width // 8)

    per_image = []
    for i in range(num_latents):
        if i < len(seeds):
            seed = seeds[i]
        else:
            # Draw a fresh non-deterministic seed and record it for reproducibility.
            seed = generator.seed()
            seeds.append(seed)
        generator.manual_seed(seed)
        per_image.append(torch.randn(shape, generator=generator, device=device))

    # Concatenate once instead of re-copying a growing tensor on every iteration.
    latents = torch.cat(per_image) if per_image else None
    return [seeds, latents]

In [7]:
def gen_images(pipe, prompt, num_images, seeds=None, height=512, width=512,
               guidance_scale=7.5, **pipe_kwargs):
    """Generate ``num_images`` images for ``prompt`` from seeded latents.

    Latents come from ``gen_latents``, so each image can be regenerated later
    from its returned seed.

    Args:
        pipe: Pipeline to call; it is invoked with a list of prompts plus
            ``height``, ``width``, ``guidance_scale`` and ``latents`` keywords.
        prompt: Text prompt, repeated once per image.
        num_images: Number of images to generate.
        seeds: Optional seeds forwarded to ``gen_latents``.
        height: Output height in pixels. It is forwarded to both ``gen_latents``
            and ``pipe`` so the latent shape matches what the pipeline expects.
        width: Output width in pixels; forwarded to both, like ``height``.
        guidance_scale: Guidance scale passed to ``pipe`` (default 7.5).
        **pipe_kwargs: Extra keyword arguments forwarded to ``pipe``
            (e.g. ``num_inference_steps``).

    Returns:
        ``[seeds, images]``, where ``seeds`` is the list returned by
        ``gen_latents`` and ``images`` is ``pipe(...)['images']``
        (presumably a list of PIL images with the default output type).

    NOTE(review): ``gen_latents`` reads the latent channel count from the
    module-level ``pipe``, not from this ``pipe`` argument.
    """
    seeds, latents = gen_latents(num_images, seeds, height, width)

    images = pipe(
        [prompt] * num_images,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        latents=latents,
        **pipe_kwargs
    )['images']
    return [seeds, images]