# LATENT DIFFUSION

---

In [None]:
import sys
sys.path.append('..')

import random
from dataclasses import dataclass

import torch
from PIL import Image
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

from arch import diffusion, utils

In [None]:
# Resolve the compute device through the project helper and print its name in upper case.
# NOTE(review): presumably picks an accelerator when one is available — confirm in arch.utils.
DEVICE = utils.device_mapper()
print(f"Device: {str(DEVICE).upper()}")

In [None]:
# Build the latent-diffusion pipeline for the selected device; the cells below reuse `pipe`.
pipe = diffusion.LatentDiffusion(DEVICE)

---

### TEXT2IMAGE

In [None]:
@dataclass
class HyperConfig:
    """Settings object passed to ``pipe.generate`` for the text-to-image run.

    Every attribute carries a type annotation so that ``@dataclass`` registers
    it as a field. Without annotations the decorator sees no fields at all:
    the generated ``__init__`` takes no arguments and ``repr``/``==`` ignore
    every value. The defaults are unchanged, so ``HyperConfig()`` behaves as
    before, and individual values can now be overridden by keyword.
    """

    prompt: str = "a photograph of a barn on fire, cinematic, film grain, analog, 70mm, technicolor, 4K, IMAX"
    negative_prompt: str = "black and white"
    # Half of 1280x720, i.e. 640x360.
    w: int = 1280 // 2
    h: int = 720 // 2
    infer_steps: int = 50
    # NOTE(review): presumably the classifier-free guidance scale — confirm in arch.diffusion.
    cfg_scale: float = 7.0
    batch_size: int = 1


config = HyperConfig()

In [None]:
# Generate nine text-to-image samples and collect the decoded results as arrays.
images = []
for i in tqdm(range(9)):
    # The bound must be an int: random.randint rejects float arguments such as
    # 1e6 (DeprecationWarning on Python 3.10/3.11, TypeError on 3.12+).
    # NOTE(review): the second argument of pipe.generate is presumably the RNG seed — confirm.
    seed = random.randint(0, 1_000_000)
    latents = pipe.generate(config, seed, 'txt2img')
    img = pipe.decode(latents)
    images.append(img.numpy())

In [None]:
# Hand the collected images to the project layout helper; its result is the cell output.
# NOTE(review): the two trailing 3s are presumably rows and columns of a 3x3 sheet — confirm in arch.utils.
utils.contact_layer(images, 3, 3)

---

### ENCODING & DECODING

In [None]:
# Load the source image with the project helper and preview it as the cell output.
# Alternative kept for reference (PIL, resized to 512x512):
# input_image = Image.open('data/macaw.jpg').resize((512, 512))
input_image = utils.load_image('data/macaw.jpg')
utils.array2image(input_image)

In [None]:
# Encode the input image to the latent space, then show the shape of the result.
encoded = pipe.encode(input_image)
encoded.shape

In [None]:
# Draw channels 0-3 of the first encoded item side by side, one subplot each.
fig, axs = plt.subplots(1, 4, figsize=(16, 4))
for idx, ax in enumerate(axs):
    channel = encoded[0][idx].cpu()
    ax.imshow(channel, cmap='Greys')

In [None]:
# Decode the latents back through the pipeline and preview the result as the cell output.
decoded = pipe.decode(encoded)
utils.array2image(decoded.numpy())

---

### IMG2IMG

In [None]:
@dataclass
class HyperConfig:
    """Settings object passed to ``pipe.generate`` for the image-to-image run.

    Every attribute carries a type annotation so that ``@dataclass`` registers
    it as a field. Without annotations the decorator sees no fields at all:
    the generated ``__init__`` takes no arguments and ``repr``/``==`` ignore
    every value. The defaults are unchanged, so ``HyperConfig()`` behaves as
    before, and individual values can now be overridden by keyword.
    """

    prompt: str = "A colorful dancer, nat geo photo"
    negative_prompt: str = ""
    w: int = 512
    h: int = 512
    # Passed to pipe.scheduler.set_timesteps in the noise-schedule cell.
    infer_steps: int = 50
    # Index into pipe.scheduler.timesteps used when noising the encoded image.
    sampling_step: int = 10
    # NOTE(review): presumably the classifier-free guidance scale — confirm in arch.diffusion.
    cfg_scale: float = 8.0
    batch_size: int = 1


config = HyperConfig()

In [None]:
# Configure the scheduler with the configured number of inference steps, then
# plot its sigmas and mark config.sampling_step with a vertical red line.
pipe.scheduler.set_timesteps(config.infer_steps)

plt.figure(figsize=(8, 4))
plt.plot(pipe.scheduler.sigmas)
# One call per statement: chaining calls in a tuple expression made the cell
# echo a stray "(None, None)" as its output.
plt.title('Noise Schedule')
plt.xlabel('Sampling Step')
plt.ylabel('Sigma')
plt.xlim([0, len(pipe.scheduler.timesteps)])
# Line width is given as a number rather than the string '0.1'.
plt.axvline(x=config.sampling_step, color='red', lw=0.1)
plt.grid(alpha=0.25)
plt.tight_layout()

In [None]:
# Draw float16 noise shaped like the encoded latents.
noise = torch.randn_like(encoded, dtype=torch.float16)

# Add that noise to the latents at the scheduler timestep selected by
# config.sampling_step.
start_timestep = pipe.scheduler.timesteps[config.sampling_step]
encoded_and_noised = pipe.scheduler.add_noise(
    encoded,
    noise,
    timesteps=torch.tensor([start_timestep]),
)

# Decode the noised latents and preview the result as the cell output.
img = pipe.decode(encoded_and_noised)
utils.array2image(img.numpy())

In [None]:
# Run the pipeline in 'img2img' mode, passing the encoded latents, then decode the result.
# NOTE(review): the fixed 12 is presumably the RNG seed (the txt2img cell passes a random integer in that position) — confirm.
latents = pipe.generate(config, 12, 'img2img', encoded)
img = pipe.decode(latents)

In [None]:
# Preview the most recently decoded `img` as the cell output.
utils.array2image(img.numpy())