> dynamic Classifier-free Guidance across several Diffusion models.

# Introduction

This notebook is Part 7 in a [series](https://enzokro.dev/blog/posts/2022-11-15-guidance-expts-6/) on dynamic Classifier-free Guidance. It checks whether our proposed schedules and normalizations improve images across Diffusion models.

## Recap of Parts 1-6

In the first six parts, we found a good set of schedules and normalizations for dynamic Classifier-free Guidance. The best-performing schedule is used in this notebook.

## Part 7: Improvement across models

Part 7 takes our best schedule so far, `Inverse kDecay`, and tries it on a few different models:  

- Stable Diffusion v1-4
- Stable Diffusion v1-5
- Prompt Hero's openjourney
- Stable Diffusion 2-base

# Python imports

We start with a few python imports.

In [1]:
import os
import random
from functools import partial
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

## Seed for reproducibility

`seed_everything` makes sure that the results are reproducible across notebooks.

In [2]:
# seed shared by every random number generator in this notebook
SEED = 1337802893 # inca 977145576 # inca2, warrior


def seed_everything(seed):
    """Seed Python's `random`, NumPy and PyTorch with the same `seed`.

    Also exports `PYTHONHASHSEED`, sets cuDNN's `deterministic` flag to True
    and its `benchmark` flag to False.

    Returns the generator object handed back by `torch.manual_seed`.
    """
    # PyTorch: set the cuDNN flags, then seed the default generator
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch_generator = torch.manual_seed(seed)

    # NumPy and the Python standard library
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)

    return torch_generator


# for sampling the initial, noisy latents
generator = seed_everything(SEED)

# Cosine schedules with k-decay

We create the schedules with different $k$ values using the `cf_guidance` library.

In [3]:
# helpers to create cosine schedules
from cf_guidance.schedules  import get_cos_sched

# normalizations for classifier-free guidance
from cf_guidance.transforms import GuidanceTfm, BaseNormGuidance, TNormGuidance, FullNormGuidance

For the other schedule parameters, we keep the [same values](https://enzokro.dev/blog/posts/2022-11-20-guidance-expts-2/#default-schedule-parameters) from the rest of the series. The functions below are also shared with previous notebooks. 

In [8]:
# Default schedule parameters from the blog post
######################################
max_val           = 8     # guidance scaling value
min_val           = 1     # minimum guidance scaling
num_steps         = 50    # number of diffusion steps
num_warmup_steps  = 0     # number of warmup steps
warmup_init_val   = 0     # the initial warmup value
num_cycles        = 0.5   # number of cosine cycles
k_decay           = 1     # k-decay for cosine curve scaling

# smaller values for T-Norm and FullNorm
max_T = 0.15
min_T = 0.01
######################################

# baseline keyword arguments for the regular cosine schedules
DEFAULT_COS_PARAMS = dict(
    max_val=max_val,
    num_steps=num_steps,
    min_val=min_val,
    num_cycles=num_cycles,
    k_decay=k_decay,
    num_warmup_steps=num_warmup_steps,
    warmup_init_val=warmup_init_val,
)

# same baseline, with the smaller T-Norm / FullNorm range swapped in
DEFAULT_T_PARAMS = {
    **DEFAULT_COS_PARAMS,
    'max_val': max_T,
    'min_val': min_T,
}

def cos_harness(default_params, new_params):
    '''Creates a cosine schedule from `default_params` overridden by `new_params`.

    The two dictionaries are merged into a fresh one, so neither input is
    modified. Entries in `new_params` win over entries with the same key in
    `default_params`. The merged values are passed to `get_cos_sched` as
    keyword arguments and its result is returned.
    '''
    merged_params = {**default_params, **new_params}
    return get_cos_sched(**merged_params)


def create_expts(params: dict, schedule_func) -> list:
    '''Builds one experiment per (parameter name, value) pair in `params`.

    `params` maps a parameter name to the values to try for it. Names are
    visited in sorted order; values keep the order they have in `params`.
    `schedule_func` is called as `schedule_func(new_params={name: val})` and
    its result is stored under the `schedule` key.

    Returns a list of dictionaries with the keys `param_name`, `val`,
    `schedule`, and a `title` string for easy plotting.
    '''
    def make_expt(param_name, param_val):
        # a single experiment: name, value, schedule, and plot title
        return {
            'param_name': param_name,
            'val': param_val,
            'schedule': schedule_func(new_params={param_name: param_val}),
            'title': f'Param: "{param_name}", val={param_val}',
        }

    return [
        make_expt(param_name, param_val)
        for param_name in sorted(params)
        for param_val in params[param_name]
    ]

Next we create the best k-decay cosine schedules.

In [9]:
# the Inverse-k-decay cosine experiments try a single value of `k`
inv_k_params = {'k_decay': [0.15]}
inv_k_func = partial(cos_harness, default_params=DEFAULT_COS_PARAMS)
inv_k_expts = create_expts(inv_k_params, inv_k_func)

# invert in place: every schedule value `g` becomes `max_val - g + min_val`
for expt in inv_k_expts:
    inverted = [max_val - g + min_val for g in expt['schedule']]
    expt['schedule'] = inverted

# put all schedules together (only the inverse-k ones in this notebook)
all_k_expts = inv_k_expts

In [None]:
#| echo: false
#| output: true
# one Tableau color name per plotted schedule
colors=list(mcolors.TABLEAU_COLORS)

# figure with a single set of axes, labelled through the axes object
fig, ax = plt.subplots(figsize=(12, 8))
ax.set_title('Inverse Cosine Schedules with K-decay', fontsize='xx-large')
ax.set_xlabel('Diffusion timesteps', fontsize='x-large')
ax.set_ylabel('Guidance parameter', fontsize='x-large')

# draw one curve per `k` value
for idx, expt in enumerate(inv_k_expts):
    k_label = f'k: {expt["val"]:.2f}'
    ax.plot(expt['schedule'], c=colors[idx], label=k_label)

ax.legend()
fig.tight_layout();

We build the same inverted schedule for the `T` and `Full` Normalizations, using their smaller guidance range.

In [None]:
# Inverse-k-decay cosine experiments, built from the T-Norm defaults
T_inv_k_func = partial(cos_harness, default_params=DEFAULT_T_PARAMS)
T_inv_k_expts = create_expts(inv_k_params, T_inv_k_func)

# invert in place: every schedule value `g` becomes `max_T - g + min_T`
for expt in T_inv_k_expts:
    inverted = [max_T - g + min_T for g in expt['schedule']]
    expt['schedule'] = inverted

# all schedules for the T and Full normalizations
all_T_k_expts = T_inv_k_expts

In [None]:
#| echo: false
#| output: true
# one Tableau color name per plotted schedule
colors=list(mcolors.TABLEAU_COLORS)

# figure with a single set of axes, labelled through the axes object
fig, ax = plt.subplots(figsize=(12, 8))
ax.set_title('T Inverse Cosine Schedules with K-decay', fontsize='xx-large')
ax.set_xlabel('Diffusion timesteps', fontsize='x-large')
ax.set_ylabel('Guidance parameter', fontsize='x-large')

# draw one curve per `k` value
for idx, expt in enumerate(T_inv_k_expts):
    k_label = f'k: {expt["val"]:.2f}'
    ax.plot(expt['schedule'], c=colors[idx], label=k_label)

ax.legend()
fig.tight_layout();

# Loading different StableDiffusion models

We need to wrap our experiment pipeline in a single loop so we can easily run it with different models. To do this, we'll move the model loading code below into its own function, and free GPU memory after each model has been run.

In [None]:
%%capture
# to load Stable Diffusion pipelines
from min_diffusion.core import MinimalDiffusion

# to plot generated images
from min_diffusion.utils import show_image, image_grid, plot_grid

We wrap it in a small helper that loads any Stable Diffusion model with a given device and precision. Later in the notebook we call it with the `cuda` device and `torch.float16` precision.

In [None]:
def load_sd_model(model_name, device, dtype, model_kwargs=None, generator=None):
    """Creates a `MinimalDiffusion` pipeline for `model_name` and loads it.

    Args:
        model_name: model identifier passed to `MinimalDiffusion`.
        device: passed through to `MinimalDiffusion`.
        dtype: passed through to `MinimalDiffusion`.
        model_kwargs: optional dictionary of keyword arguments forwarded to
            `pipeline.load`. `None` (the default) forwards no extra arguments.
        generator: optional generator forwarded to `MinimalDiffusion`.

    Returns:
        The pipeline, after its `load` method has been called.
    """
    # a `None` default avoids sharing one mutable dict across calls
    load_kwargs = {} if model_kwargs is None else model_kwargs

    pipeline = MinimalDiffusion(model_name, device, dtype, generator=generator)
    pipeline.load(**load_kwargs)
    return pipeline

# Text prompt for image generations

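Instead of the series' usual running prompt ("a photograph of an astronaut riding a horse"), this notebook generates images from a more detailed prompt:  

> "digital painting of masked incan warrior, by filipe pagliuso and justin gerard, symmetric, fantasy, highly detailed, realistic, intricate, portrait, sharp focus, tarot card, face, handsome, peruvian, ax"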

In [None]:
# text prompt for image generations

# alternative prompts, kept for reference:
# prompt = "a beautiful painting of an elegant cat, highly detailed, 4K, 8K, trending on art station, Award winning"

# prompt = "digital painting of hanan pacha, the incan world above us where the sun and moon live, by filipe pagliuso and justin gerard, symmetric, fantasy, realistic, highly detailed, realistic, intricate, sharp focus, tarot card, portrait"

# the active prompt, split across lines for readability (adjacent literals are joined)
prompt = (
    "digital painting of masked incan warrior, "
    "by filipe pagliuso and justin gerard, "
    "symmetric, fantasy, highly detailed, realistic, intricate, "
    "portrait, sharp focus, tarot card, face, handsome, peruvian, ax"
)

## Image parameters

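Images will be generated over $50$ diffusion steps. For the models in this notebook they will have a height of `640` and a width of `512` pixels; a separate `768 x 768` setting is kept for the full Stable Diffusion 2 model.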

In [None]:
# the number of diffusion steps
num_steps = 50

# image dimensions, passed as generation keyword arguments
sd2_dims = dict(height=768, width=768)  # used for 'stabilityai/stable-diffusion-2'
sd_dims = dict(height=640, width=512)   # used for every other model; goddess prompt


# Running the experiments

We modify the `run` function to take an already-loaded Stable Diffusion pipeline as its first argument. This makes it easy to pass in and try different models. The model loading, and a bit of GPU cleanup to make sure there is enough memory for the next model, happen in the loop that calls `run`.

In [None]:
def run(pipeline, prompt, schedules,
        guide_tfm=None, generator=None, show_each=False, test_run=False, gen_kwargs=None):
    """Runs a dynamic Classifier-free Guidance experiment.

    Generates one image for the text `prompt` per entry in `schedules`.

    Args:
        pipeline: object with a `generate(prompt, guide_tfm, **kwargs)` method.
        prompt: text prompt handed to `pipeline.generate`.
        schedules: sequence of experiment dictionaries; each needs a `title`
            and a `schedule` entry.
        guide_tfm: required callable (in this notebook, a Guidance
            Transformation class from the `cf_guidance` library). It is called
            as `guide_tfm({'g': schedule})` once per experiment.
        generator: accepted for call compatibility; not used by this function.
        show_each: if true, shows each image as it is generated.
        test_run: if true, only the first schedule is run, for testing.
        gen_kwargs: optional dictionary of extra keyword arguments for
            `pipeline.generate`. `None` (the default) passes no extras.

    Returns:
        A dictionary with the generated `images` and their matching `titles`.

    Raises:
        ValueError: if `guide_tfm` is missing or not callable.
    """
    # validate explicitly: an `assert` would be stripped when running with -O
    if not callable(guide_tfm):
        raise ValueError('`guide_tfm` must be a callable Guidance Transformation class.')
    print(f'Using Guidance Transform: {guide_tfm}')

    # a `None` default avoids sharing one mutable dict across calls
    gen_kwargs = {} if gen_kwargs is None else gen_kwargs

    # optionally run a single test schedule
    if test_run:
        print('Running a single schedule for testing.')
        schedules = schedules[:1]

    # store generated images and their title (the experiment name)
    images, titles = [], []
    num_expts = len(schedules)

    # run all schedule experiments
    for expt_num, expt in enumerate(schedules, start=1):

        # parse out the title for the current run
        cur_title = expt['title']
        titles.append(cur_title)

        # create the guidance transformation for the current schedule
        gtfm = guide_tfm({'g': expt['schedule']})

        print(f'Running experiment [{expt_num} of {num_expts}]: {cur_title}...')
        img = pipeline.generate(prompt, gtfm, **gen_kwargs)
        images.append(img)

        # optionally plot the image
        if show_each:
            show_image(img, scale=1)

    print('Done.')
    return {'images': images,
            'titles': titles}

# Gathering models and arguments

Next we create the arguments and parameters to run different models.

In [None]:

# the models to run: each entry holds a model name and the
# keyword arguments used when loading that model
model_expts = [
    # SD v1-4
    dict(model_name='CompVis/stable-diffusion-v1-4',
         model_kwargs=dict(better_vae='mse')),
    # SD v1-5
    dict(model_name='runwayml/stable-diffusion-v1-5',
         model_kwargs=dict(better_vae='mse')),
    # openjourney
    dict(model_name="prompthero/openjourney",
         model_kwargs=dict()),
    # SD 2-base
    dict(model_name='stabilityai/stable-diffusion-2-base',
         model_kwargs=dict(unet_attn_slice=False)),
    # # SD 2
    # {'model_name': 'stabilityai/stable-diffusion-2',
    #  'model_kwargs': {'unet_attn_slice': False}},
]

## Creating the baseline image with $G = 7.5$

First we create the baseline image using a constant Classifier-free Guidance with $G = 7.5$. Since this is a constant schedule, $k$ does not come into play.  

In [None]:
# constant baseline schedule for the default guidance
baseline_g = 7.5
baseline_params = {'max_val': [baseline_g]}


def baseline_func(*args, **kwargs):
    """Returns `num_steps` copies of `baseline_g`, ignoring all arguments."""
    return [baseline_g] * num_steps


baseline_expts = create_expts(baseline_params, baseline_func)


# constant baseline schedule for the T and Full normalizations
T_baseline_g = 0.15
T_baseline_params = {'max_val': [T_baseline_g]}


def T_baseline_func(*args, **kwargs):
    """Returns `num_steps` copies of `T_baseline_g`, ignoring all arguments."""
    return [T_baseline_g] * num_steps


T_baseline_expts = create_expts(T_baseline_params, T_baseline_func)

## Improving the baseline with schedules and normalizations 

In [None]:
# results of every run, keyed by (model name, experiment type)
outputs = {}

device = 'cuda'
dtype = torch.float16

for mparams in model_expts:
    print(f'Running model: {mparams}')

    model_name, model_kwargs = mparams['model_name'], mparams['model_kwargs']

    # a bit of a manual patch, we need a keyword for openjourney model
    cur_prompt = ("mdjrny-v4 style " + prompt) if 'openjourney' in model_name else prompt
    print(f'Using prompt: {cur_prompt}')

    # only the full SD 2 model uses the larger dimensions
    gen_kwargs = sd2_dims if model_name == 'stabilityai/stable-diffusion-2' else sd_dims
    print(f'Generation kwargs: {gen_kwargs}')

    # load the current Diffusion model
    pipeline = load_sd_model(model_name, device, dtype,
                             model_kwargs=model_kwargs, generator=generator)

    # keyword arguments shared by every `run` call for this model
    run_kwargs = {'gen_kwargs': gen_kwargs, 'generator': generator}

    # make the baseline for this model
    baseline_res = run(pipeline, cur_prompt, baseline_expts,
                       guide_tfm=GuidanceTfm, **run_kwargs)
    outputs[(model_name, 'baseline')] = baseline_res

    # generate images with different normalizations and schedules
    base_norm_res = run(pipeline, cur_prompt, baseline_expts + all_k_expts,
                        guide_tfm=BaseNormGuidance, **run_kwargs)
    outputs[(model_name, 'baseNorm')] = base_norm_res

    T_res = run(pipeline, cur_prompt, T_baseline_expts + all_T_k_expts,
                guide_tfm=TNormGuidance, **run_kwargs)
    outputs[(model_name, 'TNorm')] = T_res

    full_res = run(pipeline, cur_prompt, T_baseline_expts + all_T_k_expts,
                   guide_tfm=FullNormGuidance, **run_kwargs)
    outputs[(model_name, 'FullNorm')] = full_res

    # cleanup the model for the next run: drop our reference to the pipeline,
    # then ask PyTorch to empty its CUDA cache
    pipeline = None
    torch.cuda.empty_cache()

# Results

Let's make some helpers to grab all output images for a given Stable Diffusion model.

In [None]:
# names of all the models we tried
model_names = [
    'CompVis/stable-diffusion-v1-4',
    'runwayml/stable-diffusion-v1-5',
    'prompthero/openjourney',
    'stabilityai/stable-diffusion-2-base',
]

# number of image grids to plot per model, and rows per grid
num_runs = 2
num_rows = 1

# plot dimensions
plot_height = 640
plot_width = 512

def get_results(model_name):
    """Returns the entries of `outputs` stored for `model_name`.

    The list is ordered as: baseline, baseNorm, TNorm, FullNorm.
    Raises a `KeyError` if any of the four entries is missing from `outputs`.
    """
    results = []
    for expt_type in ('baseline', 'baseNorm', 'TNorm', 'FullNorm'):
        results.append(outputs[(model_name, expt_type)])
    return results

def plot_all_results(model_name):
    """Plots `num_runs` image grids comparing the outputs for `model_name`.

    Every grid starts with the first baseline image. It is followed by the
    i-th image of each of the remaining result groups from `get_results`,
    with the matching titles.
    """
    baseline, *normed = get_results(model_name)
    for run_idx in range(num_runs):
        # the first baseline image leads every grid
        grid_images = [baseline['images'][0]]
        grid_titles = [baseline['titles'][0]]

        # then the `run_idx`-th output of each remaining group
        grid_images += [group['images'][run_idx] for group in normed]
        grid_titles += [group['titles'][run_idx] for group in normed]

        image_grid(
            grid_images,
            title=grid_titles,
            rows=num_rows, width=plot_width, height=plot_height
        )
        plt.suptitle(f'Model: {model_name} | Output #{run_idx}')

# SD 1-4

In [None]:
# image grids comparing the baseline and normalized outputs for SD v1-4
plot_all_results('CompVis/stable-diffusion-v1-4')

# SD 1-5

In [None]:
# image grids comparing the baseline and normalized outputs for SD v1-5
plot_all_results('runwayml/stable-diffusion-v1-5')

# openjourney

In [None]:
# image grids comparing the baseline and normalized outputs for openjourney
plot_all_results('prompthero/openjourney')

# SD 2-base

In [None]:
# image grids comparing the baseline and normalized outputs for SD 2-base
plot_all_results('stabilityai/stable-diffusion-2-base')

# Conclusion