In [None]:
# First run only: uncomment to clone the style-aligned repo that the next cell adds to sys.path.
# !git clone https://github.com/google/style-aligned

In [None]:
import sys
from pathlib import Path

# Put the local "style-aligned" checkout on sys.path so that `import sa_handler`
# further down can resolve (assumes sa_handler lives at the repo root -- TODO confirm).
repo_dir = Path("style-aligned")
sys.path.append(str(repo_dir.resolve()))

import torch
import math
import itertools
import mediapy as media

from diffusers import StableDiffusionXLPipeline, DDIMScheduler
import sa_handler

In [None]:
# DDIM scheduler: scaled-linear beta schedule between the two beta bounds,
# with sample clipping and set_alpha_to_one both switched off.
scheduler = DDIMScheduler(
    beta_schedule="scaled_linear",
    beta_start=0.00085,
    beta_end=0.012,
    set_alpha_to_one=False,
    clip_sample=False,
)

# Load the fp16 safetensors variant of SDXL base 1.0 with the scheduler above.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    scheduler=scheduler,
    variant="fp16",
    torch_dtype=torch.float16,
    use_safetensors=True,
)
# Move the whole pipeline onto the GPU.
pipe = pipe.to("cuda")

pipe.enable_vae_slicing()


In [None]:
# Style-aligned handler bound to the pipeline; run_sa() registers a set of
# StyleAlignedArgs on it before each generation and removes them afterwards.
handler = sa_handler.Handler(pipe)

def run_sa(prompts, sa_args, seed=0, num_inference_steps=30, guidance_scale=7.5):
    """Run the pipeline on ``prompts`` with ``sa_args`` registered on ``handler``.

    The style-aligned arguments are registered on the module-level ``handler``
    only for the duration of the pipeline call. They are removed again even if
    generation raises, so a failed run does not leave the shared pipeline
    patched for subsequent cells.

    Args:
        prompts: Prompts passed positionally to ``pipe``.
        sa_args: Object handed to ``handler.register`` (a
            ``sa_handler.StyleAlignedArgs`` in this notebook).
        seed: Seed for the CPU ``torch.Generator``; coerced with ``int``.
        num_inference_steps: Forwarded to ``pipe`` after ``int`` coercion.
        guidance_scale: Forwarded to ``pipe`` after ``float`` coercion.

    Returns:
        The ``images`` attribute of the pipeline output.
    """
    generator = torch.Generator(device="cpu")
    generator.manual_seed(int(seed))

    handler.register(sa_args)
    try:
        return pipe(
            prompts,
            generator=generator,
            num_inference_steps=int(num_inference_steps),
            guidance_scale=float(guidance_scale),
        ).images
    finally:
        # Always undo the registration, including on errors or interrupts.
        handler.remove()


In [None]:
# Subjects to render; each one gets the same style suffix appended below.
base_prompts = [
    "a toy train",
    "a toy airplane",
    "a toy bicycle",
    "a toy car",
    "a toy boat",
]

# Style description shared by every prompt.
style_prompt = "macro photo, 3d game asset"

# "<subject>, <style>" for each subject, in the same order as base_prompts.
prompts = [", ".join((subject, style_prompt)) for subject in base_prompts]
prompts


In [None]:
# Candidate values for every StyleAlignedArgs field. Single-element lists keep
# that field fixed; the two-element lists are swept (2 * 2 * 2 * 2 = 16 runs).
grid = {
    "share_group_norm": [False, True],
    "share_layer_norm": [False, True],
    "share_attention": [True],
    "adain_queries": [True],
    "adain_keys": [True],
    "adain_values": [False],
    "only_self_level": [0.0, 0.5],
    "shared_score_shift": [0.0, math.log(2)],
    "shared_score_scale": [1.0],
    "full_attention_share": [False],
}

# Field names and their candidate lists, both in the dict's insertion order.
keys = list(grid)
values = list(grid.values())

# One (cfg, images) pair per point of the Cartesian product of the grid.
results = []
for combo in itertools.product(*values):
    cfg = dict(zip(keys, combo))
    # The grid keys are exactly the keyword names StyleAlignedArgs is given.
    sa_args = sa_handler.StyleAlignedArgs(**cfg)
    images = run_sa(prompts, sa_args, seed=10, num_inference_steps=30, guidance_scale=7.5)
    results.append((cfg, images))


In [None]:
def cfg_to_title(cfg):
    """Build a short one-line label from the swept fields of a grid config.

    Args:
        cfg: Mapping with the keys ``share_group_norm``, ``share_layer_norm``,
            ``only_self_level`` and ``shared_score_shift``.

    Returns:
        A string of the form ``"GN=<0|1> LN=<0|1> only_self=<v> shift=<v>"``,
        where the shift is rounded to three decimals.
    """
    group_norm_flag = int(cfg["share_group_norm"])
    layer_norm_flag = int(cfg["share_layer_norm"])
    only_self = cfg["only_self_level"]
    shift = round(cfg["shared_score_shift"], 3)
    fields = [
        f"GN={group_norm_flag}",
        f"LN={layer_norm_flag}",
        f"only_self={only_self}",
        f"shift={shift}",
    ]
    return " ".join(fields)

# One row of images per grid point; only the first image carries the config label.
for cfg, images in results:
    titles = [cfg_to_title(cfg)]
    titles.extend("" for _ in range(len(images) - 1))
    media.show_images(images, titles=titles, height=192)
