In [None]:
ENV_TYPE = "TEST"

if(ENV_TYPE != "TEST"):
  !git clone "https://github.com/kk-digital/kcg-ml-sd1p4.git"
  %cd kcg-ml-sd1p4
  !pip3 install -r requirements.txt
  exit()
  base_directory = "./"
else:
  base_directory = "../"

# Magical check for fixing all of our directory issues
import subprocess
output = subprocess.check_output(["pwd"], universal_newlines=True)
if "notebooks" in output:
    %cd ..
del output

In [None]:
!python3 ./download_models.py

In [None]:
!python3 ./process_models.py

In [None]:
import os
import sys
import torch
import time
import shutil
import torchvision  # needed for `torchvision.utils.make_grid` in the grid cells
from torchvision.transforms import ToPILImage
from os.path import join

# The setup cell leaves the working directory at the repository root (the
# `./download_models.py` / `./process_models.py` cells rely on that), so resolve
# repo-relative paths against the current directory. Hard-coding "../" here
# would point one level *above* the repository.
base_directory = os.getcwd()
sys.path.insert(0, base_directory)

output_base_dir = join(base_directory, "./output/sd2-notebook/")
output_directory = join(output_base_dir, "demo/")

# Start every run from an empty output directory. Only a missing directory is
# expected here; any other failure (e.g. permissions) should surface instead of
# silently leaving stale images behind.
if os.path.isdir(output_directory):
    shutil.rmtree(output_directory)
os.makedirs(output_directory, exist_ok=True)

from stable_diffusion import StableDiffusion
from stable_diffusion.utils_backend import *
from stable_diffusion.utils_image import *
from stable_diffusion.utils_model import *
from stable_diffusion.model.clip_image_encoder import CLIPImageEncoder

from stable_diffusion.model_paths import *
from configs.model_config import ModelPathConfig

device = get_device()


def to_pil(image):
    """Convert an image tensor with values in [-1, 1] to a PIL image.

    Values are mapped to [0, 1] via (x + 1) / 2 and clamped before conversion.
    NOTE(review): assumes a single (C, H, W) image, as `ToPILImage` expects.
    """
    return ToPILImage()(torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0))

In [None]:
batch_size = 1

# Make modules in the current working directory importable.
base_dir = os.getcwd()
sys.path.insert(0, base_dir)

# Build the input/output path tree from the default model path configuration.
model_config = ModelPathConfig()
pt = IODirectoryTree(model_config)

# Expected to print `.../kcg-ml-sd1p4/input/model/autoencoder/autoencoder.ckpt`.
autoencoder_path = pt.autoencoder
print(autoencoder_path)

In [None]:
def logistic_distribution(loc, scale):
    """Build a logistic distribution with location `loc` and scale `scale`.

    Samples are drawn from Uniform(0, 1), passed through the inverse sigmoid
    (logit), and then shifted by `loc` and scaled by `scale`.

    Returns:
        torch.distributions.TransformedDistribution
    """
    dists = torch.distributions
    logit = dists.transforms.SigmoidTransform().inv
    shift_and_scale = dists.transforms.AffineTransform(loc=loc, scale=scale)
    uniform = dists.Uniform(0, 1)
    return dists.TransformedDistribution(uniform, [logit, shift_and_scale])


def noise_fn(shape, device=device):
    """Sample logistic(loc=0.0, scale=0.5) noise of `shape` and move it to `device`.

    The default `device` is the notebook-level `device` at the moment this
    cell is executed.
    """
    noise_dist = logistic_distribution(loc=0.0, scale=0.5)
    samples = noise_dist.sample(shape)
    return samples.to(device)

In [None]:
# Build the Stable Diffusion wrapper (25 sampling steps) and load its sub-models.
N_STEPS = 25
DEVICE = get_device()

sd = StableDiffusion(n_steps=N_STEPS, device=DEVICE)
sd.quick_initialize()
sd.model.load_submodel_tree()

In [None]:
# Pick a sampling temperature, generate one image with the current `ddim_eta`,
# save it and display it. Per the notebook's own note, a higher temperature
# generally means more diversity but less quality, and it only makes a
# difference when `ddim_eta` is non-zero.
temperature = 1.0

imgs = sd.generate_images(
    prompt="A woman with flowers in her hair in a courtyard, in the style of Frank Frazetta",
    negative_prompt="Ugly, weird",
    seed=2982,
    noise_fn=noise_fn,
    temperature=temperature,
)

sample_name = f"test_sample_temp{temperature:.3f}_eta{sd.ddim_eta:.3f}.png"
save_images(imgs, join(output_directory, sample_name))

to_pil(imgs[0])

In [None]:
# Same prompt and seed, but with ddim_eta = 0.1; save and display the result.
sd.ddim_eta = 0.1
temperature = 1.0

imgs = sd.generate_images(
    prompt="A woman with flowers in her hair in a courtyard, in the style of Frank Frazetta",
    negative_prompt="Ugly, weird",
    seed=2982,
    noise_fn=noise_fn,
    temperature=temperature,
)

sample_name = f"test_sample_temp{temperature:.3f}_eta{sd.ddim_eta:.3f}.png"
save_images(imgs, join(output_directory, sample_name))

to_pil(imgs[0])

In [None]:
# A higher `ddim_eta` (0.5 here) implies a higher noise level.
sd.ddim_eta = 0.5
temperature = 1.0

imgs = sd.generate_images(
    prompt="A woman with flowers in her hair in a courtyard, in the style of Frank Frazetta",
    negative_prompt="Ugly, weird",
    seed=2982,
    noise_fn=noise_fn,
    temperature=temperature,
)

sample_name = f"test_sample_temp{temperature:.3f}_eta{sd.ddim_eta:.3f}.png"
save_images(imgs, join(output_directory, sample_name))

to_pil(imgs[0])

In [None]:
# So does a higher temperature (1.8 here, with ddim_eta kept at 0.5).
sd.ddim_eta = 0.5
temperature = 1.8

imgs = sd.generate_images(
    prompt="A woman with flowers in her hair in a courtyard, in the style of Frank Frazetta",
    negative_prompt="Ugly, weird",
    seed=2982,
    noise_fn=noise_fn,
    temperature=temperature,
)

sample_name = f"test_sample_temp{temperature:.3f}_eta{sd.ddim_eta:.3f}.png"
save_images(imgs, join(output_directory, sample_name))

to_pil(imgs[0])

In [None]:
# We can check how the images change with the `ddim_eta` parameter alone:
# one image per eta value (0.0, 0.1, ..., 0.4) at a fixed temperature, laid
# out in a single row with eta increasing from left to right.
temperature = 1.0
images = []
eta_steps = 5
eta_0 = 0.0
for i in range(eta_steps):
    ddim_eta = eta_0 + i * 0.1
    sd.ddim_eta = ddim_eta
    imgs = sd.generate_images(
        prompt="A woman with flowers in her hair in a courtyard, in the style of Frank Frazetta",
        negative_prompt="Ugly, weird",
        seed=2982,
        noise_fn=noise_fn,
        temperature=temperature,
    )
    print(imgs.shape)
    images.append(imgs)
images = torch.cat(images, dim=0)
# `to_pil` maps [-1, 1] to [0, 1] itself, so make_grid must not normalise.
# The old `range=` keyword is not passed: newer torchvision releases removed it
# (renamed `value_range`), and with normalize=False it had no effect anyway.
grid = torchvision.utils.make_grid(
    images, nrow=eta_steps, normalize=False, scale_each=True, pad_value=0
)
grid_img = to_pil(grid)
grid_img.save(
    join(
        output_directory,
        f"test_grid_temp{temperature:.3f}_eta{eta_0:.3f}-{sd.ddim_eta:.3f}.png",
    )
)
grid_img

In [None]:
# Or we can check how the images change with the temperature alone (fixed
# ddim_eta = 0.1): one image per temperature, laid out in a single row.
temperature = 1.0
sd.ddim_eta = 0.1
images = []
temperatures = []  # temperatures actually sampled, in order
temp_steps = 5
for i in range(temp_steps):
    # The step is applied before sampling, so the first image uses 1.1, not 1.0.
    temperature += 0.1
    temperatures.append(temperature)
    imgs = sd.generate_images(
        prompt="A woman with flowers in her hair in a courtyard, in the style of Frank Frazetta",
        negative_prompt="Ugly, weird",
        seed=2982,
        noise_fn=noise_fn,
        temperature=temperature,
    )
    print(imgs.shape)
    images.append(imgs)
images = torch.cat(images, dim=0)
# `to_pil` maps [-1, 1] to [0, 1] itself, so make_grid must not normalise.
# The old `range=` keyword is not passed: newer torchvision releases removed it
# (renamed `value_range`), and with normalize=False it had no effect anyway.
grid = torchvision.utils.make_grid(
    images,
    normalize=False,
    nrow=temp_steps,
    scale_each=True,
    pad_value=0,
)
grid_img = to_pil(grid)
# Label the file with the first and last temperature actually sampled
# (the starting value 1.0 itself is never used).
grid_img.save(
    join(
        output_directory,
        f"test_grid_temp{temperatures[0]:.3f}-{temperatures[-1]:.3f}_eta{sd.ddim_eta:.3f}.png",
    )
)
grid_img

In [None]:
# Or we can vary both at once. Images are appended row by row, so each row uses
# one temperature (1.2, 1.4, 1.6 — increasing from top to bottom, y-axis) and
# each column one ddim_eta (0.1, 0.2, 0.3 — increasing from left to right, x-axis).
grid_side = 2
temperature = 1.0
ddim_eta = 0.1
grid = []
for i in range(grid_side + 1):
    temperature += 0.2
    for j in range(grid_side + 1):
        sd.ddim_eta = ddim_eta + j * 0.1
        imgs = sd.generate_images(
            prompt="A woman with flowers in her hair in a courtyard, in the style of Frank Frazetta",
            negative_prompt="Ugly, weird",
            seed=2982,
            noise_fn=noise_fn,
            temperature=temperature,
        )
        grid.append(imgs)

tensor_grid = torch.cat(grid, dim=0)
# `nrow = grid_side + 1` puts each temperature's eta sweep on its own row.
# `to_pil` maps [-1, 1] to [0, 1] itself, so make_grid must not normalise; the
# old `range=` keyword is not passed (removed in newer torchvision, and inert
# with normalize=False).
grid = torchvision.utils.make_grid(
    tensor_grid,
    nrow=grid_side + 1,
    normalize=False,
    scale_each=True,
    pad_value=0,
)
grid_img = to_pil(grid)
grid_img.save(
    join(
        output_directory,
        f"test_grid_temp{temperature-grid_side*0.2:.3f}-{temperature:.3f}_eta{ddim_eta:.3f}-{sd.ddim_eta:.3f}.png",
    )
)
grid_img