# Modifying frequency features for Latent diffusion models

- Youngseok Yoon

## Goals.

- Understand how different hyperparameters affect image generation.
- Modify the diffusion generation process based on frequency characteristics.

## Diffusion models and Stable diffusion

- Diffusion models generate images by gradually denoising Gaussian noise.

![diffusion](figs/diffusion.png)

## Diffusion models and Stable diffusion

- Stable Diffusion is one of the most famous latent diffusion models.

![ldm](figs/ldm.png)

## Classifier-free guidance

- A technique for conditioning (i.e., controlling) image generation.
- Used for text-to-image generation.

![cfg_equ](figs/cfg_equation.png)

![cfg_fig](figs/cfg_figure.png)

In [1]:

from PIL import Image, ImageDraw
import numpy as np
import torch
from torch import autocast
from pytorch_lightning import seed_everything

from scripts.txt2img import OmegaConf, load_model_from_config
from ldm.models.diffusion.dpm_solver import DPMSolverSampler

import os

class Args:
    """Bare attribute container; settings are attached after construction."""

    def __init__(self):
        # Intentionally empty: every field is assigned by the caller.
        return None
    
# Collect every run setting on a single namespace object.
args = Args()

# Model checkpoint and inference config.
args.ckpt_name = "sd-v2-512-base-ema"
args.ckpt = f"checkpoints/stable-diffusion-v2/{args.ckpt_name}.ckpt"
args.config = "configs/stable-diffusion/v2-inference.yaml"

# Sampling settings.
args.steps = 30
args.n_iter = 1
args.n_samples = 5
args.repeat = 1
args.seed = 100

# Image size (H, W), latent channel count (C) and the divisor (f) applied to
# H and W when sizing the start code below.
args.H = 512
args.W = 512
args.C = 4
args.f = 8

args.device = "cuda"

# Seed first so that the start code drawn below is reproducible.
seed_everything(args.seed)
precision_scope = autocast

# Anything other than "cuda" falls back to the CPU.
device = torch.device("cuda" if args.device == "cuda" else "cpu")

# One fixed initial noise tensor, passed to the sampler as x_T in every run.
start_code = torch.randn(
    args.n_samples,
    args.C,
    args.H // args.f,
    args.W // args.f,
    device=device,
)

config = OmegaConf.load(args.config)
model = load_model_from_config(config, args.ckpt, device, verbose=True)

sampler = DPMSolverSampler(model, device=device)

Global seed set to 100


Loading model from checkpoints/stable-diffusion-v2/sd-v2-512-base-ema.ckpt
Global Step: 875000
No module 'xformers'. Proceeding without it.
LatentDiffusion: Running in eps-prediction mode
DiffusionWrapper has 865.91 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels




unexpected keys:
['model_ema.decay', 'model_ema.num_updates']


In [2]:
def post_process(imgs_prompts):
    """Stack nested tensors and convert them from [-1, 1] to uint8 [0, 255].

    Args:
        imgs_prompts: Outer sequence (one entry per prompt) of inner sequences
            of equally shaped tensors. The stacked result must have at least
            four dimensions.

    Returns:
        A ``numpy`` ``uint8`` array: the inputs stacked along two new leading
        axes, truncated to at most three entries along axis 3, with values
        mapped via ``(x + 1) / 2`` and clamped to ``[0, 1]`` before scaling
        to ``[0, 255]``.
    """
    stacked = torch.stack(
        [torch.stack(step_tensors, 0) for step_tensors in imgs_prompts], 0
    )
    unit_range = torch.clamp(
        (stacked.to(torch.float32) + 1.0) / 2.0, min=0.0, max=1.0
    )
    # detach()/cpu() make the NumPy conversion work for tensors that track
    # gradients or live on an accelerator; they are no-ops otherwise.
    array = unit_range.detach().cpu().numpy()
    return (array[:, :, :, :3] * 255).round().astype("uint8")

def post_process_fft(imgs_prompts):
    """Stack nested tensors and return their centered 2-D FFT magnitude as uint8.

    Args:
        imgs_prompts: Outer sequence (one entry per prompt) of inner sequences
            of equally shaped tensors. The stacked result must have at least
            four dimensions.

    Returns:
        A ``numpy`` ``uint8`` array holding the magnitude of the 2-D FFT
        (taken over the last two axes of ``(x + 1) / 2``), truncated to at
        most three entries along axis 3, shifted so the zero-frequency bin is
        centered, and clipped to ``[0, 255]`` before rounding.
    """
    stacked = torch.stack(
        [torch.stack(step_tensors, 0) for step_tensors in imgs_prompts], 0
    )
    spectrum = torch.fft.fft2((stacked.to(torch.float32) + 1.0) / 2.0)
    centered = torch.fft.fftshift(spectrum[:, :, :, :3], dim=[-2, -1])
    # detach()/cpu() make the NumPy conversion work for tensors that track
    # gradients or live on an accelerator; they are no-ops otherwise.
    magnitude = np.abs(centered.detach().cpu().numpy())
    # Clipping saturates large magnitudes at 255 so the result fits in uint8.
    return magnitude.clip(0, 255).round().astype("uint8")

def fourier(patch):
    """Return the centered 2-D spectrum of ``patch / 2 + 0.5``.

    The FFT runs over the last two axes, and the zero-frequency bin is then
    shifted to the middle of those axes. The result is a complex tensor.
    """
    rescaled = patch / 2 + 0.5
    return torch.fft.fftshift(torch.fft.fft2(rescaled), dim=[-2, -1])

In [3]:
# Prompts used in earlier runs, kept for reference (they were previously
# assigned to `prompt` and immediately overwritten):
#   "A photo of a ece graduate student"
#   "A photo of a pikachu"
prompt = "A photo of a classroom for Digital Image Processing course."
prompts = [prompt]

In [6]:
# Sweep the guidance scale and the low-pass setting. For every combination,
# sample each prompt, convert the collected tensors to uint8 arrays, and write
# the final images plus one per-step diagnostic grid per sample.

# Latent shape handed to the sampler; it does not change between runs.
shape = [args.C, args.H // args.f, args.W // args.f]

for scale in [9.0, 20.0]:  # previously: [3.0, 9.0, 20.0]
    for low_pass in [8]:  # previously: [False, 0.3, 0.5, 0.7]
        images_prompts, latents_prompts, cond_noises_prompts, uncond_noises_prompts, diffs_prompts = [], [], [], [], []
        for prompt in prompts:
            with torch.no_grad(), precision_scope(args.device), model.ema_scope():
                if isinstance(prompt, list):
                    # A list-valued prompt yields one conditioning per entry.
                    uc = model.get_learned_conditioning(len(prompt) * [""]).repeat(args.n_samples, 1, 1, 1)
                    c = model.get_learned_conditioning(prompt).repeat(args.n_samples, 1, 1, 1)
                    uc = [ten.squeeze(1) for ten in torch.split(uc, 1, 1)]
                    c = [ten.squeeze(1) for ten in torch.split(c, 1, 1)]
                else:
                    uc = model.get_learned_conditioning(args.n_samples * [""])
                    c = model.get_learned_conditioning(args.n_samples * [prompt])

                # NOTE(review): with need_all=True the sampler presumably
                # returns one entry per recorded step (args.steps + 1 in
                # total, judging by the np.split below) -- confirm against
                # the sampler. `preds` and `noises` are not used here.
                latents, (preds, noises, cond_noises, uncond_noises) = sampler.sample(
                    S=args.steps,
                    conditioning=c,
                    batch_size=args.n_samples,
                    shape=shape,
                    unconditional_conditioning=uc,
                    unconditional_guidance_scale=scale,
                    need_all=True,
                    low_pass=low_pass,
                    x_T=start_code,
                )

                images = [model.decode_first_stage(latent) for latent in latents]

            # Move everything to CPU float32 once; the difference reuses the
            # already converted conditional / unconditional tensors.
            cond_cpu = [cond_noise.cpu().to(torch.float32) for cond_noise in cond_noises]
            uncond_cpu = [uncond_noise.cpu().to(torch.float32) for uncond_noise in uncond_noises]

            images_prompts.append([image.cpu().to(torch.float32) for image in images])
            latents_prompts.append([latent.cpu().to(torch.float32) for latent in latents])
            cond_noises_prompts.append(cond_cpu)
            uncond_noises_prompts.append(uncond_cpu)
            diffs_prompts.append([cond_noise - uncond_noise for cond_noise, uncond_noise in zip(cond_cpu, uncond_cpu)])

        images = post_process(images_prompts)

        latents = post_process(latents_prompts)
        cond_noises = post_process(cond_noises_prompts)
        uncond_noises = post_process(uncond_noises_prompts)
        diff_noises = post_process(diffs_prompts)

        latents_fft = post_process_fft(latents_prompts)
        cond_noises_fft = post_process_fft(cond_noises_prompts)
        uncond_noises_fft = post_process_fft(uncond_noises_prompts)
        diff_noises_fft = post_process_fft(diffs_prompts)

        # `prompt` is the last entry of `prompts` here, so one folder per
        # (scale, low_pass) combination is named after that prompt.
        root = os.path.join("outputs", prompt, f"CFG_{scale}")
        if low_pass is not False:
            root += f"_gaussian_{low_pass}"
        # Create the output folder once instead of once per sample.
        os.makedirs(root, exist_ok=True)

        # Rows of the diagnostic grid, top to bottom.
        # NOTE(review): uncond_noises / uncond_noises_fft are computed above
        # but are not part of the grid -- confirm whether that is intended.
        dics = {
            "image": images, "latent": latents, "latent_fft": latents_fft, "cond": cond_noises, "cond_fft": cond_noises_fft, "diff": diff_noises, "diff_fft": diff_noises_fft
        }

        # Every grid row is resized to 64 px per prompt in height.
        row_height = 64 * len(prompts)

        for i in range(args.n_samples):
            # Final-step image of the first prompt for sample i (CHW -> HWC).
            img = images[0, -1, i]
            img = Image.fromarray(img.transpose(1, 2, 0))
            img.save(os.path.join(root, f"img_{i}.png"))

            downsampled = []
            for stack in dics.values():
                im = stack[:, :, i]
                # Lay the steps out left to right, then the prompts top to bottom.
                im = np.concatenate(np.split(im, args.steps + 1, 1), -1).squeeze(1)
                im = np.concatenate(np.split(im, len(prompts), 0), -2).squeeze(0)
                im = np.array(Image.fromarray(im.transpose(1, 2, 0)).resize((64 * (args.steps + 1), 64 * len(prompts))))

                downsampled.append(im)
            ims = Image.fromarray(np.concatenate(downsampled, 0))
            ims_d = ImageDraw.Draw(ims)
            # Red separators between rows; using the row height keeps them on
            # the row boundaries when there is more than one prompt.
            for j in range(1, len(dics)):
                ims_d.line([(0, j * row_height), (ims.size[0], j * row_height)], fill="red", width=3)

            ims.save(os.path.join(root, f"{i}.png"))

Data shape for DPM-Solver sampling is (5, 4, 64, 64), sampling steps 30


DPM init order: 100%|██████████| 1/1 [00:00<00:00,  4.77it/s]
DPM multistep: 100%|██████████| 29/29 [00:06<00:00,  4.77it/s]


Data shape for DPM-Solver sampling is (5, 4, 64, 64), sampling steps 30


DPM init order: 100%|██████████| 1/1 [00:00<00:00,  4.77it/s]
DPM multistep: 100%|██████████| 29/29 [00:06<00:00,  4.74it/s]


## Frequency behavior of latent during diffusion steps

- CFG 3.0.

<img src="outputs/A photo of a pikachu/CFG_3.0/img_0.png" width="150" height="150">
<img src="outputs/A photo of a pikachu/CFG_3.0/img_1.png" width="150" height="150">
<img src="outputs/A photo of a pikachu/CFG_3.0/img_2.png" width="150" height="150">
<img src="outputs/A photo of a pikachu/CFG_3.0/img_3.png" width="150" height="150">
<img src="outputs/A photo of a pikachu/CFG_3.0/img_4.png" width="150" height="150">

- CFG 9.0.

<img src="outputs/A photo of a pikachu/CFG_9.0/img_0.png" width="150" height="150">
<img src="outputs/A photo of a pikachu/CFG_9.0/img_1.png" width="150" height="150">
<img src="outputs/A photo of a pikachu/CFG_9.0/img_2.png" width="150" height="150">
<img src="outputs/A photo of a pikachu/CFG_9.0/img_3.png" width="150" height="150">
<img src="outputs/A photo of a pikachu/CFG_9.0/img_4.png" width="150" height="150">

- CFG 20.0

<img src="outputs/A photo of a pikachu/CFG_20.0/img_0.png" width="150" height="150">
<img src="outputs/A photo of a pikachu/CFG_20.0/img_1.png" width="150" height="150">
<img src="outputs/A photo of a pikachu/CFG_20.0/img_2.png" width="150" height="150">
<img src="outputs/A photo of a pikachu/CFG_20.0/img_3.png" width="150" height="150">
<img src="outputs/A photo of a pikachu/CFG_20.0/img_4.png" width="150" height="150">

- CFG 3.0.

<img src="outputs/A photo of a ece graduate student/CFG_3.0/img_0.png" width="150" height="150">
<img src="outputs/A photo of a ece graduate student/CFG_3.0/img_1.png" width="150" height="150">
<img src="outputs/A photo of a ece graduate student/CFG_3.0/img_2.png" width="150" height="150">
<img src="outputs/A photo of a ece graduate student/CFG_3.0/img_3.png" width="150" height="150">
<img src="outputs/A photo of a ece graduate student/CFG_3.0/img_4.png" width="150" height="150">

- CFG 9.0.

<img src="outputs/A photo of a ece graduate student/CFG_9.0/img_0.png" width="150" height="150">
<img src="outputs/A photo of a ece graduate student/CFG_9.0/img_1.png" width="150" height="150">
<img src="outputs/A photo of a ece graduate student/CFG_9.0/img_2.png" width="150" height="150">
<img src="outputs/A photo of a ece graduate student/CFG_9.0/img_3.png" width="150" height="150">
<img src="outputs/A photo of a ece graduate student/CFG_9.0/img_4.png" width="150" height="150">

- CFG 20.0

<img src="outputs/A photo of a ece graduate student/CFG_20.0/img_0.png" width="150" height="150">
<img src="outputs/A photo of a ece graduate student/CFG_20.0/img_1.png" width="150" height="150">
<img src="outputs/A photo of a ece graduate student/CFG_20.0/img_2.png" width="150" height="150">
<img src="outputs/A photo of a ece graduate student/CFG_20.0/img_3.png" width="150" height="150">
<img src="outputs/A photo of a ece graduate student/CFG_20.0/img_4.png" width="150" height="150">

## Degradation with self-consuming chain of diffusion

<img src="gc_data/POKEMON/COND/COMBINED_BLIP_MIN_50_MAX_75_WAIFU/SDXL-BASE-1.0/240601_LR_SEARCH/VAE_SDXL_RES_512_IMG_PER_PRO_1_GEN_CFG_7.5_LR_UNET_1e-4_TEXT_5e-5/ITR_1/025_0.png" width="150" height="150">
<img src="gc_data/POKEMON/COND/COMBINED_BLIP_MIN_50_MAX_75_WAIFU/SDXL-BASE-1.0/240601_LR_SEARCH/VAE_SDXL_RES_512_IMG_PER_PRO_1_GEN_CFG_7.5_LR_UNET_1e-4_TEXT_5e-5/ITR_2/025_0.png" width="150" height="150">
<img src="gc_data/POKEMON/COND/COMBINED_BLIP_MIN_50_MAX_75_WAIFU/SDXL-BASE-1.0/240601_LR_SEARCH/VAE_SDXL_RES_512_IMG_PER_PRO_1_GEN_CFG_7.5_LR_UNET_1e-4_TEXT_5e-5/ITR_3/025_0.png" width="150" height="150">
<img src="gc_data/POKEMON/COND/COMBINED_BLIP_MIN_50_MAX_75_WAIFU/SDXL-BASE-1.0/240601_LR_SEARCH/VAE_SDXL_RES_512_IMG_PER_PRO_1_GEN_CFG_7.5_LR_UNET_1e-4_TEXT_5e-5/ITR_4/025_0.png" width="150" height="150">
<img src="gc_data/POKEMON/COND/COMBINED_BLIP_MIN_50_MAX_75_WAIFU/SDXL-BASE-1.0/240601_LR_SEARCH/VAE_SDXL_RES_512_IMG_PER_PRO_1_GEN_CFG_7.5_LR_UNET_1e-4_TEXT_5e-5/ITR_5/025_0.png" width="150" height="150">

- CFG 3.0.

![grid_1](outputs/A%20photo%20of%20a%20pikachu/CFG_3.0/2.png)

- CFG 9.0.

![grid_2](outputs/A%20photo%20of%20a%20pikachu/CFG_9.0/2.png)

- CFG 20.0

![grid_3](outputs/A%20photo%20of%20a%20pikachu/CFG_20.0/2.png)

- CFG 20.0

<img src="outputs/A photo of a pikachu/CFG_20.0/img_0.png" width="150" height="150">
<img src="outputs/A photo of a pikachu/CFG_20.0/img_1.png" width="150" height="150">
<img src="outputs/A photo of a pikachu/CFG_20.0/img_2.png" width="150" height="150">
<img src="outputs/A photo of a pikachu/CFG_20.0/img_3.png" width="150" height="150">
<img src="outputs/A photo of a pikachu/CFG_20.0/img_4.png" width="150" height="150">

- CFG 20.0, low pass filter 0.5.

<img src="outputs/A photo of a pikachu/CFG_20.0_low_pass_0.5/img_0.png" width="150" height="150">
<img src="outputs/A photo of a pikachu/CFG_20.0_low_pass_0.5/img_1.png" width="150" height="150">
<img src="outputs/A photo of a pikachu/CFG_20.0_low_pass_0.5/img_2.png" width="150" height="150">
<img src="outputs/A photo of a pikachu/CFG_20.0_low_pass_0.5/img_3.png" width="150" height="150">
<img src="outputs/A photo of a pikachu/CFG_20.0_low_pass_0.5/img_4.png" width="150" height="150">

- CFG 20.0, Gaussian low pass filter 8.

<img src="outputs/A photo of a pikachu/CFG_20.0_gaussian_8/img_0.png" width="150" height="150">
<img src="outputs/A photo of a pikachu/CFG_20.0_gaussian_8/img_1.png" width="150" height="150">
<img src="outputs/A photo of a pikachu/CFG_20.0_gaussian_8/img_2.png" width="150" height="150">
<img src="outputs/A photo of a pikachu/CFG_20.0_gaussian_8/img_3.png" width="150" height="150">
<img src="outputs/A photo of a pikachu/CFG_20.0_gaussian_8/img_4.png" width="150" height="150">

- CFG 9.0

<img src="outputs/A photo of a pikachu/CFG_9.0/img_0.png" width="150" height="150">
<img src="outputs/A photo of a pikachu/CFG_9.0/img_1.png" width="150" height="150">
<img src="outputs/A photo of a pikachu/CFG_9.0/img_2.png" width="150" height="150">
<img src="outputs/A photo of a pikachu/CFG_9.0/img_3.png" width="150" height="150">
<img src="outputs/A photo of a pikachu/CFG_9.0/img_4.png" width="150" height="150">

- CFG 9.0, low pass filter 0.5.

<img src="outputs/A photo of a pikachu/CFG_9.0_low_pass_0.5/img_0.png" width="150" height="150">
<img src="outputs/A photo of a pikachu/CFG_9.0_low_pass_0.5/img_1.png" width="150" height="150">
<img src="outputs/A photo of a pikachu/CFG_9.0_low_pass_0.5/img_2.png" width="150" height="150">
<img src="outputs/A photo of a pikachu/CFG_9.0_low_pass_0.5/img_3.png" width="150" height="150">
<img src="outputs/A photo of a pikachu/CFG_9.0_low_pass_0.5/img_4.png" width="150" height="150">

- CFG 9.0, Gaussian low pass filter 8.

<img src="outputs/A photo of a pikachu/CFG_9.0_gaussian_8/img_0.png" width="150" height="150">
<img src="outputs/A photo of a pikachu/CFG_9.0_gaussian_8/img_1.png" width="150" height="150">
<img src="outputs/A photo of a pikachu/CFG_9.0_gaussian_8/img_2.png" width="150" height="150">
<img src="outputs/A photo of a pikachu/CFG_9.0_gaussian_8/img_3.png" width="150" height="150">
<img src="outputs/A photo of a pikachu/CFG_9.0_gaussian_8/img_4.png" width="150" height="150">

- CFG 20.0

<img src="outputs/A photo of a ece graduate student/CFG_20.0/img_0.png" width="150" height="150">
<img src="outputs/A photo of a ece graduate student/CFG_20.0/img_1.png" width="150" height="150">
<img src="outputs/A photo of a ece graduate student/CFG_20.0/img_2.png" width="150" height="150">
<img src="outputs/A photo of a ece graduate student/CFG_20.0/img_3.png" width="150" height="150">
<img src="outputs/A photo of a ece graduate student/CFG_20.0/img_4.png" width="150" height="150">

- CFG 20.0, low pass filter 0.5.

<img src="outputs/A photo of a ece graduate student/CFG_20.0_low_pass_0.5/img_0.png" width="150" height="150">
<img src="outputs/A photo of a ece graduate student/CFG_20.0_low_pass_0.5/img_1.png" width="150" height="150">
<img src="outputs/A photo of a ece graduate student/CFG_20.0_low_pass_0.5/img_2.png" width="150" height="150">
<img src="outputs/A photo of a ece graduate student/CFG_20.0_low_pass_0.5/img_3.png" width="150" height="150">
<img src="outputs/A photo of a ece graduate student/CFG_20.0_low_pass_0.5/img_4.png" width="150" height="150">

- CFG 20.0, Gaussian low pass filter 8.

<img src="outputs/A photo of a ece graduate student/CFG_20.0_gaussian_8/img_0.png" width="150" height="150">
<img src="outputs/A photo of a ece graduate student/CFG_20.0_gaussian_8/img_1.png" width="150" height="150">
<img src="outputs/A photo of a ece graduate student/CFG_20.0_gaussian_8/img_2.png" width="150" height="150">
<img src="outputs/A photo of a ece graduate student/CFG_20.0_gaussian_8/img_3.png" width="150" height="150">
<img src="outputs/A photo of a ece graduate student/CFG_20.0_gaussian_8/img_4.png" width="150" height="150">

- CFG 9.0

<img src="outputs/A photo of a ece graduate student/CFG_9.0/img_0.png" width="150" height="150">
<img src="outputs/A photo of a ece graduate student/CFG_9.0/img_1.png" width="150" height="150">
<img src="outputs/A photo of a ece graduate student/CFG_9.0/img_2.png" width="150" height="150">
<img src="outputs/A photo of a ece graduate student/CFG_9.0/img_3.png" width="150" height="150">
<img src="outputs/A photo of a ece graduate student/CFG_9.0/img_4.png" width="150" height="150">

- CFG 9.0, low pass filter 0.5.

<img src="outputs/A photo of a ece graduate student/CFG_9.0_low_pass_0.5/img_0.png" width="150" height="150">
<img src="outputs/A photo of a ece graduate student/CFG_9.0_low_pass_0.5/img_1.png" width="150" height="150">
<img src="outputs/A photo of a ece graduate student/CFG_9.0_low_pass_0.5/img_2.png" width="150" height="150">
<img src="outputs/A photo of a ece graduate student/CFG_9.0_low_pass_0.5/img_3.png" width="150" height="150">
<img src="outputs/A photo of a ece graduate student/CFG_9.0_low_pass_0.5/img_4.png" width="150" height="150">

- CFG 9.0, Gaussian low pass filter 8.

<img src="outputs/A photo of a ece graduate student/CFG_9.0_gaussian_8/img_0.png" width="150" height="150">
<img src="outputs/A photo of a ece graduate student/CFG_9.0_gaussian_8/img_1.png" width="150" height="150">
<img src="outputs/A photo of a ece graduate student/CFG_9.0_gaussian_8/img_2.png" width="150" height="150">
<img src="outputs/A photo of a ece graduate student/CFG_9.0_gaussian_8/img_3.png" width="150" height="150">
<img src="outputs/A photo of a ece graduate student/CFG_9.0_gaussian_8/img_4.png" width="150" height="150">

- CFG 20.0

![grid_3](outputs/A%20photo%20of%20a%20pikachu/CFG_20.0/2.png)

- CFG 20.0, low pass filter 0.5.

![grid_3](outputs/A%20photo%20of%20a%20pikachu/CFG_20.0_low_pass_0.5/2.png)

- CFG 20.0, Gaussian low pass filter 8.

![grid_3](outputs/A%20photo%20of%20a%20pikachu/CFG_20.0_gaussian_8/2.png)