# LATENT DIFFUSION

---

In [None]:
import sys, warnings
sys.path.append('..')
warnings.filterwarnings('ignore')

import random
from dataclasses import dataclass

import torch
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

from arch import utils, diffusion, ops

In [None]:
DEVICE = utils.device_mapper()
# Report which device the notebook will run on (upper-cased for readability).
print("Device:", str(DEVICE).upper())

In [None]:
# Construct the latent-diffusion pipeline on DEVICE; the cells below use its
# generate/encode/decode methods and its scheduler.
pipe = diffusion.LatentDiffusion(DEVICE)

---

### TEXT2IMAGE

In [None]:
@dataclass
class HyperConfig:
    """Settings for the text-to-image (txt2img) run below.

    Every field carries a type annotation so ``@dataclass`` actually turns it
    into a dataclass field (generated ``__init__``/``__repr__``/``__eq__``).
    Without annotations the values were plain class attributes and the
    decorator generated no fields at all.
    """

    prompt: str = "a photograph of a barn on fire, cinematic, film grain, analog, 70mm, technicolor, 4K, IMAX"
    negative_prompt: str = "black and white"
    # NOTE(review): presumably output width/height in pixels (half of
    # 1280x720) — confirm against LatentDiffusion.generate.
    w: int = 1280 // 2
    h: int = 720 // 2
    # Number of scheduler steps (the IMG2IMG section passes the same field
    # to scheduler.set_timesteps).
    infer_steps: int = 50
    # NOTE(review): presumably the classifier-free guidance scale.
    cfg_scale: float = 7.0
    batch_size: int = 1


config = HyperConfig()

In [None]:
# Generate nine txt2img samples (tiled 3 x 3 in the next cell), each from a
# freshly drawn random value passed to pipe.generate.
images = []
for _ in tqdm(range(9)):
    # randint needs integer bounds: a float bound such as 1e6 is deprecated in
    # Python 3.10/3.11 (a warning silenced by filterwarnings('ignore') above)
    # and raises TypeError from Python 3.12.
    seed = random.randint(0, 10**6)
    latents = pipe.generate(config, seed, 'txt2img')
    img = pipe.decode(latents)
    images.append(utils.array2image(img.numpy()))

In [None]:
# Lay out the nine txt2img samples as a 3 x 3 grid (presumably rows x cols).
utils.contact_layer(images, 3, 3)

---

### ENCODING & DECODING

In [None]:
# Load the test image: `init_image` is kept for display, and `x` (its array
# form) is what gets encoded and compared against the reconstruction below.
init_image = utils.load_image('data/macaw.jpg')
x = utils.image2array(init_image)

In [None]:
# Encode the image into the pipeline's latent space; the trailing expression
# displays the latent tensor's shape in the notebook.
encoded = pipe.encode(x)
encoded.shape

In [None]:
# Show the first four latent channels of the first batch item side by side,
# one greyscale panel per channel.
fig, axs = plt.subplots(1, 4, figsize=(16, 4))
for c, ax in enumerate(axs):
    ax.imshow(encoded[0][c].cpu(), cmap='Greys')

In [None]:
# Decode the latents back to image space and convert to NumPy so the result
# can be compared with `x` below.
decoded = pipe.decode(encoded).numpy()

In [None]:
# Difference between the source array and its encode/decode reconstruction.
diff = ops.diff(x, decoded)

# Source, reconstruction and difference shown side by side in one row.
panels = [init_image, utils.array2image(decoded), utils.array2image(diff)]
utils.contact_layer(panels, 1, 3, ["source", "decoded", "diff"])

In [None]:
from skimage.metrics import peak_signal_noise_ratio

# Peak signal-to-noise ratio of the reconstruction against the source
# (higher means a closer reconstruction).
# NOTE(review): data_range is left to skimage's dtype-based default; confirm
# it matches the value range produced by utils.image2array.
psnr = peak_signal_noise_ratio(image_true=x, image_test=decoded)
print(f"PSNR: {psnr}")

---

### IMG2IMG

In [None]:
@dataclass
class HyperConfig:
    """Settings for the image-to-image (img2img) run below.

    Every field carries a type annotation so ``@dataclass`` actually turns it
    into a dataclass field (generated ``__init__``/``__repr__``/``__eq__``).
    Without annotations the values were plain class attributes and the
    decorator generated no fields at all.
    """

    prompt: str = "a colorful dancer, nat geo photo"
    negative_prompt: str = ""
    # NOTE(review): presumably output width/height in pixels — confirm
    # against LatentDiffusion.generate.
    w: int = 512
    h: int = 512
    # Number of scheduler steps, passed to scheduler.set_timesteps below.
    infer_steps: int = 50
    # Index into scheduler.timesteps at which the source latents are noised
    # (and marked on the noise-schedule plot).
    sampling_step: int = 10
    # NOTE(review): presumably the classifier-free guidance scale.
    cfg_scale: float = 8.0
    batch_size: int = 1


config = HyperConfig()

In [None]:
# Optional: uncomment to run img2img from a different source image.
# While commented out, the cells below reuse `encoded` (and `x`/`init_image`)
# from the ENCODING & DECODING section.
# init_image = utils.load_image('data/poison.png')
# x = utils.image2array(init_image)
# encoded = pipe.encode(x)

In [None]:
# Configure the scheduler for config.infer_steps steps; the next cell indexes
# scheduler.timesteps, so this must run first.
pipe.scheduler.set_timesteps(config.infer_steps)

# Plot sigma per step and mark config.sampling_step, the timestep index used
# to noise the source latents (presumably where img2img starts denoising).
plt.figure(figsize=(8, 4))
plt.plot(pipe.scheduler.sigmas)
plt.title('Noise Schedule')
plt.xlabel('Sampling Step')
plt.ylabel('Sigma')
plt.xlim([0, len(pipe.scheduler.timesteps)])
plt.axvline(x=config.sampling_step, color='red', lw=0.1)  # numeric linewidth
plt.grid(alpha=0.25)
# Separate statements rather than a tuple expression, so Jupyter does not
# echo "(None, None)" as the cell's output.
plt.tight_layout()

In [None]:
# Preview the source latents after adding noise at the sampling_step-th
# scheduler timestep (presumably the noise level img2img starts from).
# randn_like inherits encoded's dtype and device, so the noise matches the
# pipeline's precision instead of being forced to float16 (which can differ
# from the latents' dtype, e.g. on a float32 run).
noise = torch.randn_like(encoded)
start_timestep = torch.tensor([pipe.scheduler.timesteps[config.sampling_step]])
encoded_and_noised = pipe.scheduler.add_noise(encoded, noise, timesteps=start_timestep)

# Decode and display what the noised latents look like in image space.
img = pipe.decode(encoded_and_noised)
utils.array2image(img.numpy())

In [None]:
# img2img: generate from the source latents `encoded` using the img2img config.
# NOTE(review): 12 is presumably a fixed seed, mirroring the random value the
# txt2img loop passes in the same position — confirm against
# LatentDiffusion.generate.
latents = pipe.generate(config, 12, 'img2img', encoded)
img = pipe.decode(latents)

In [None]:
# Display the decoded img2img result as the cell's output.
utils.array2image(img.numpy())