In [None]:
# Waifu Diffusion
#!wget https://thisanimedoesnotexist.ai/downloads/wd-v1-2-full-ema.ckpt

#Stable Diffusion
#!wget https://r2-public-worker.drysys.workers.dev/sd-v1-4-full-ema.ckpt

#Poke Diffusion
!wget https://sd-finetune.vonk.workers.dev/pokediffusion_ckpts/pokediffusion_epoch_10_pruned.ckpt

!git clone https://github.com/CompVis/stable-diffusion.git
%cd stable-diffusion
!wget https://raw.githubusercontent.com/justinpinkney/stable-diffusion/main/requirements.txt
!pip install -r requirements.txt
!pip install --upgrade pytorch-lightning
!apt-get update -y && apt-get install libgl1 -y && apt-get install libglib2.0-0 -y

In [None]:
# !! Restart the notebook kernel here so the freshly installed packages are picked up, then continue with the next cell !!

In [1]:
%cd stable-diffusion
# Waifu Diffusion
#ckpt_file = "wd-v1-2-full-ema.ckpt"

#Stable Diffusion
#ckpt_file = "sd-v1-4-full-ema.ckpt"

#Poke Diffusion
ckpt_file = "pokediffusion_epoch_10_pruned.ckpt"

/workspace/stable-diffusion


In [2]:
from io import BytesIO
import os
from contextlib import nullcontext

import fire
import numpy as np
import torch
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image
from torch import autocast
from torchvision import transforms
import requests

from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.util import instantiate_from_config
from pytorch_lightning import seed_everything

In [3]:
def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate the model described by ``config`` and load weights from ``ckpt``.

    Args:
        config: Configuration object whose ``model`` attribute is handed to
            ``instantiate_from_config``.
        ckpt: Path of the checkpoint file. It is read with ``torch.load`` and
            must contain a ``"state_dict"`` entry.
        verbose: When true, print the missing and unexpected keys reported by
            ``load_state_dict``.

    Returns:
        The model in eval mode. It is moved to CUDA when CUDA is available and
        left on the CPU otherwise.

    Raises:
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
    """
    print(f"Loading model from {ckpt}")
    # NOTE(review): torch.load is pickle-based; depending on the installed
    # PyTorch version it may execute code embedded in the file, so only load
    # checkpoints from sources you trust.
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)

    # strict=False tolerates key mismatches; they are only reported on request.
    missing_keys, unexpected_keys = model.load_state_dict(sd, strict=False)
    if verbose and len(missing_keys) > 0:
        print("missing keys:")
        print(missing_keys)
    if verbose and len(unexpected_keys) > 0:
        print("unexpected keys:")
        print(unexpected_keys)

    # Only move to the GPU when one exists; an unconditional .cuda() raises on
    # CPU-only machines before the caller can pick a fallback device.
    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    return model

In [None]:
# Build the model from the v1 inference config and load the selected
# checkpoint, which sits one directory above the current working directory.
config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml")
checkpoint_path = f"../{ckpt_file}"
model = load_model_from_config(config, checkpoint_path)

# Prefer the GPU when one is available, otherwise use the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

sampler = DDIMSampler(model)

# Directory that receives the generated images (created if missing).
sample_path = "outs"
os.makedirs(sample_path, exist_ok=True)

# Handed to the sampler as x_T; None means no fixed starting code is supplied.
start_code = None

In [5]:
import random

In [7]:
import time

In [None]:
# --- Sampling settings (edit these to change what is generated) -------------
num_seeds = 8       # number of random seeds; one sampling run per seed
batch_size = 1      # images generated per seed
prompt = "official art of a pokemon Abomasnow, Ice, Grass"
scale = 7.5         # passed to the sampler as unconditional_guidance_scale
ddim_steps = 50     # passed to the sampler as S
C = 4
H = 512
W = 512
f = 8
# Per-sample shape handed to the sampler.
shape = [C, H // f, W // f]
prompts = batch_size * [prompt]

seeds = [random.randint(1, 2000000) for _ in range(num_seeds)]

tic = time.time()
# The contexts do not depend on the seed, so they are entered once instead of
# once per seed.
with torch.no_grad(), autocast("cuda"), model.ema_scope():
    for seed in seeds:
        seed_everything(seed)

        # The empty-prompt conditioning is only needed when guidance is active.
        uc = None
        if scale != 1.0:
            uc = model.get_learned_conditioning(batch_size * [""])
        c = model.get_learned_conditioning(prompts)

        # batch_size must match the number of prompts used for the conditioning.
        samples_ddim, _ = sampler.sample(S=ddim_steps,
                                         conditioning=c,
                                         batch_size=batch_size,
                                         shape=shape,
                                         verbose=False,
                                         unconditional_guidance_scale=scale,
                                         unconditional_conditioning=uc,
                                         eta=0.0,
                                         x_T=start_code)

        x_samples_ddim = model.decode_first_stage(samples_ddim)
        # Map the decoder output from [-1, 1] to [0, 1] before scaling to 8-bit.
        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

        for index, x_sample in enumerate(x_samples_ddim):
            x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
            img = Image.fromarray(x_sample.astype(np.uint8))
            # With several images per seed, add the index so they do not
            # overwrite each other; single-image names stay as before.
            suffix = f"_{index}" if batch_size > 1 else ""
            img.save(os.path.join(sample_path, f"{prompt[:25]}_{seed}{suffix}.png"))
toc = time.time()

print(f"Generated {len(seeds) * batch_size} image(s) in {toc - tic:.1f}s")

Global seed set to 1290026


Data shape for DDIM sampling is (1, 4, 64, 64), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [00:07<00:00,  6.83it/s]
Global seed set to 937160


Data shape for DDIM sampling is (1, 4, 64, 64), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [00:07<00:00,  6.77it/s]
Global seed set to 979059


Data shape for DDIM sampling is (1, 4, 64, 64), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [00:07<00:00,  6.82it/s]
Global seed set to 184899


Data shape for DDIM sampling is (1, 4, 64, 64), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [00:07<00:00,  6.78it/s]
Global seed set to 591307


Data shape for DDIM sampling is (1, 4, 64, 64), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [00:07<00:00,  6.81it/s]
Global seed set to 1527272


Data shape for DDIM sampling is (1, 4, 64, 64), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [00:07<00:00,  6.80it/s]
Global seed set to 871330


Data shape for DDIM sampling is (1, 4, 64, 64), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [00:07<00:00,  6.80it/s]
Global seed set to 1401714


Data shape for DDIM sampling is (1, 4, 64, 64), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [00:07<00:00,  6.73it/s]
