# CTRLorALTer SD-XL

This notebook uses a sample batch to visualize the optimization approach in the CTRLorALTer space of SD-XL.

## Setup

In [None]:
import matplotlib.pyplot as plt
import copy
from src.ctrloralter.model import SDXL
import torch

In [None]:
# Set half-precision (bfloat16) for the model.
# The same dtype is reused below as the default for `add_adapter` and for casting the batch.
dtype = torch.bfloat16

In [None]:
from torch import Generator

# Fixed seed shared by every sampling call so that runs are reproducible
SEED = 42


def get_generator(seed=SEED, device="cuda"):
    """Create a ``torch.Generator`` on ``device`` and seed it with ``seed``.

    Args:
        seed: Seed handed to ``manual_seed``; defaults to the notebook-wide ``SEED``.
        device: Device the generator lives on; defaults to ``"cuda"``.

    Returns:
        The freshly seeded generator.
    """
    rng = Generator(device=device)
    rng.manual_seed(seed)
    return rng

### Load Evaluation Batch

In [None]:
# Load the pre-saved evaluation batch directly onto the GPU.
# NOTE(review): torch.load unpickles the file, so only load files from a trusted source.
# presumably an image tensor of shape (N, C, H, W) scaled to [-1, 1]: later cells index it
# per image, call permute(1, 2, 0) and rescale with * 0.5 + 0.5. TODO confirm
batch = torch.load("../data/ffhq/eval/batch_1024.pt", map_location="cuda")

### Load Adapter

#### Full B-LoRA (Style+Content)

In [None]:
from src.ctrloralter.annotators.openclip import VisionModel
from src.ctrloralter.mapper_network import SimpleMapper

# Base configuration for the full B-LoRA adapter (style + content, adaption_mode "b-lora").
# The ablation cells deep-copy this dict and override config.lora_scale.
full_cfg_base = dict(
    ckpt_path="ctrloralter/checkpoints",
    ignore_check=False,
    lora=dict(
        style=dict(
            enable="always",
            optimize=False,
            ckpt_path="ctrloralter/checkpoints/sdxl_b-lora_256",
            cfg=True,
            transforms=[],
            # LoRA hyper-parameters
            config=dict(
                lora_scale=1.0,
                rank=256,
                c_dim=1024,
                adaption_mode="b-lora",
                lora_cls="SimpleLoraLinear",
            ),
            # CLIP ViT-H/14 vision encoder and the mapper network used with it
            encoder=VisionModel(clip_model="laion/CLIP-ViT-H-14-laion2B-s32B-b79K", local_files_only=False),
            mapper_network=SimpleMapper(d_model=1024, c_dim=1024),
        ),
    ),
)

#### B-LoRA Style

In [None]:
from src.ctrloralter.annotators.openclip import VisionModel
from src.ctrloralter.mapper_network import SimpleMapper

# Base configuration for the style-only B-LoRA adapter (adaption_mode "b-lora_style").
# Unlike the full configuration it sets ignore_check=True.
style_cfg_base = dict(
    ckpt_path="ctrloralter/checkpoints",
    ignore_check=True,
    lora=dict(
        style=dict(
            enable="always",
            optimize=False,
            ckpt_path="ctrloralter/checkpoints/sdxl_b-lora_256",
            cfg=True,
            transforms=[],
            # LoRA hyper-parameters
            config=dict(
                lora_scale=1.0,
                rank=256,
                c_dim=1024,
                adaption_mode="b-lora_style",
                lora_cls="SimpleLoraLinear",
            ),
            # CLIP ViT-H/14 vision encoder and the mapper network used with it
            encoder=VisionModel(clip_model="laion/CLIP-ViT-H-14-laion2B-s32B-b79K", local_files_only=False),
            mapper_network=SimpleMapper(d_model=1024, c_dim=1024),
        ),
    ),
)

#### B-LoRA Content

In [None]:
from src.ctrloralter.annotators.openclip import VisionModel
from src.ctrloralter.mapper_network import SimpleMapper

# Base configuration for the content-only B-LoRA adapter (adaption_mode "b-lora_content").
# Unlike the full configuration it sets ignore_check=True.
content_cfg_base = dict(
    ckpt_path="ctrloralter/checkpoints",
    ignore_check=True,
    lora=dict(
        # NOTE: the adapter entry is keyed "style" here as well, same as in the other configs
        style=dict(
            enable="always",
            optimize=False,
            ckpt_path="ctrloralter/checkpoints/sdxl_b-lora_256",
            cfg=True,
            transforms=[],
            # LoRA hyper-parameters
            config=dict(
                lora_scale=1.0,
                rank=256,
                c_dim=1024,
                adaption_mode="b-lora_content",
                lora_cls="SimpleLoraLinear",
            ),
            # CLIP ViT-H/14 vision encoder and the mapper network used with it
            encoder=VisionModel(clip_model="laion/CLIP-ViT-H-14-laion2B-s32B-b79K", local_files_only=False),
            mapper_network=SimpleMapper(d_model=1024, c_dim=1024),
        ),
    ),
)

#### Add Adapters to Model

In [None]:
from omegaconf import OmegaConf
from src.ctrloralter.utils import add_lora_from_config

def add_adapter(model, cfg, device="cuda", dtype=dtype):
    """Attach the LoRA adapter(s) described by ``cfg`` to ``model``.

    Args:
        model: Model that receives the adapters.
        cfg: Plain nested dict such as ``full_cfg_base``; it may hold live
            objects (encoder, mapper network).
        device: Forwarded to ``add_lora_from_config``.
        dtype: Forwarded to ``add_lora_from_config``; defaults to the
            notebook-wide ``dtype``.

    Returns:
        Whatever ``add_lora_from_config`` returns; this notebook keeps it as
        ``cfg_mask`` and passes it on to ``sample_custom``.
    """
    return add_lora_from_config(
        model,
        # Wrap the dict in a DictConfig; allow_objects lets it hold non-primitive values
        OmegaConf.create(cfg, flags={"allow_objects": True}),
        device=device,
        dtype=dtype,
    )

## Workflow

### Load Model

In [None]:
# Build the SD-XL base pipeline ...
sdxl = SDXL(
    pipeline_type="diffusers.StableDiffusionXLPipeline",
    model_name="stabilityai/stable-diffusion-xl-base-1.0",
    local_files_only=False,
    guidance_scale=10,
)

# ... then move it to the GPU, switch to eval mode and cast to the working dtype
sdxl = sdxl.cuda().eval().to(dtype)

### Add adapters to model

In [None]:
# Attach the full (style + content) B-LoRA adapter to the model.
# The returned value is handed to `sample_custom` as `cfg_mask` below.
cfg_mask = add_adapter(sdxl, full_cfg_base)

### Predict phi

In [None]:
# Predict the conditioning phi for the evaluation batch (cast to the model dtype first).
# NOTE(review): branch_idx=0 presumably selects the first (and only) adapter branch; confirm against SDXL.predict_phi
phi = sdxl.predict_phi(batch.to(dtype), branch_idx=0)

### Sample Images

Sample images from the model using the obtained $\varphi$ as the condition. Note that these $\varphi$ have not been optimized; they are the direct output of the global mapper of the style adapter. The sampled images should therefore not be seen as optimized images, but rather as some form of reconstruction of the input images.

In [None]:
# Conditioning-related arguments: phi is already encoded and mapped, so both steps are skipped
sampling_kwargs = dict(
    num_images_per_prompt=batch.shape[0],  # one sample per input image
    cs=[phi],
    generator=get_generator(),  # fixed seed for reproducibility
    cfg_mask=cfg_mask,  # use classifier-free guidance mask
    skip_encode=True,  # skip encoding conditioning
    skip_mapping=True,  # skip mapping conditioning
)

sampled_images = sdxl.sample_custom(
    prompt="realistic colorized photograph of a person",
    **sampling_kwargs,
)

### Visualize results

In [None]:
# Visualize input images (batch) and sampled images (sampled_images) next to each other
import matplotlib.pyplot as plt

# One column per sample: input images on the top row, sampled images below.
# squeeze=False keeps `axes` two-dimensional even for a single-image batch,
# so the [row, col] indexing below works for every batch size.
n_images = batch.shape[0]
fig, axes = plt.subplots(2, n_images, figsize=(2 * n_images, 4), squeeze=False)
for i in range(n_images):
    # Move the first axis last for imshow and rescale with x * 0.5 + 0.5
    # (assumes the batch is normalised to [-1, 1]; TODO confirm). Clipping
    # keeps the floats inside imshow's valid [0, 1] range, consistent with
    # the later plotting cells.
    img = (batch[i].permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5).clip(0, 1)
    axes[0, i].imshow(img)
    axes[0, i].axis('off')
    axes[1, i].imshow(sampled_images[i])
    axes[1, i].axis('off')
plt.tight_layout()
plt.show()

## Comparison of Conditioning Options

#### Style reconstructions

In [None]:
# Build a fresh SD-XL pipeline for the style-only adapter
sdxl = SDXL(
    pipeline_type="diffusers.StableDiffusionXLPipeline",
    model_name="stabilityai/stable-diffusion-xl-base-1.0",
    local_files_only=False,
    guidance_scale=10,
)
sdxl = sdxl.cuda().eval().to(dtype)

# Attach only the style adapter and predict phi for the batch
cfg_mask = add_adapter(sdxl, style_cfg_base)
phi = sdxl.predict_phi(batch.to(dtype), branch_idx=0)

# Sample one style reconstruction per input image (empty prompt)
sampling_kwargs = dict(
    num_images_per_prompt=batch.shape[0],
    cs=[phi],
    generator=get_generator(),
    cfg_mask=cfg_mask,
    skip_encode=True,
    skip_mapping=True,
)
style = sdxl.sample_custom(prompt="", **sampling_kwargs)

#### Content Reconstructions

In [None]:
# Build a fresh SD-XL pipeline for the content-only adapter
sdxl = SDXL(
    pipeline_type="diffusers.StableDiffusionXLPipeline",
    model_name="stabilityai/stable-diffusion-xl-base-1.0",
    local_files_only=False,
    guidance_scale=10,
)
sdxl = sdxl.cuda().eval().to(dtype)

# Attach only the content adapter and predict phi for the batch
cfg_mask = add_adapter(sdxl, content_cfg_base)
phi = sdxl.predict_phi(batch.to(dtype), branch_idx=0)

# Sample one content reconstruction per input image (empty prompt)
sampling_kwargs = dict(
    num_images_per_prompt=batch.shape[0],
    cs=[phi],
    generator=get_generator(),
    cfg_mask=cfg_mask,
    skip_encode=True,
    skip_mapping=True,
)
content = sdxl.sample_custom(prompt="", **sampling_kwargs)

#### Style + Content reconstructions

In [None]:
# Build a fresh SD-XL pipeline for the full (style + content) adapter
sdxl = SDXL(
    pipeline_type="diffusers.StableDiffusionXLPipeline",
    model_name="stabilityai/stable-diffusion-xl-base-1.0",
    local_files_only=False,
    guidance_scale=10,
)
sdxl = sdxl.cuda().eval().to(dtype)

# Attach the full adapter and predict phi for the batch
cfg_mask = add_adapter(sdxl, full_cfg_base)
phi = sdxl.predict_phi(batch.to(dtype), branch_idx=0)

# Sample one style + content reconstruction per input image (empty prompt)
sampling_kwargs = dict(
    num_images_per_prompt=batch.shape[0],
    cs=[phi],
    generator=get_generator(),
    cfg_mask=cfg_mask,
    skip_encode=True,
    skip_mapping=True,
)
content_style = sdxl.sample_custom(prompt="", **sampling_kwargs)

#### Visualize all images

In [None]:
import matplotlib.pyplot as plt

# Four rows (inputs + three reconstruction variants), one column per sample
n_cols = batch.shape[0]
fig, ax = plt.subplots(4, n_cols, figsize=(2 * n_cols, 2.2 * 4), squeeze=False)

# Row 0: original images, rescaled with x * 0.5 + 0.5 and clipped to [0, 1]
ax[0, 0].set_title("Original Images", loc="left")
for col in range(n_cols):
    img = (batch[col].permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5).clip(0, 1)
    ax[0, col].imshow(img)
    ax[0, col].axis('off')

# Rows 1-3: reconstructions from style, content, and style + content
reconstruction_rows = [
    ("Reconstruction based on Style", style),
    ("Reconstruction based on Content", content),
    ("Reconstruction based on Style + Content", content_style),
]
for row, (title, images) in enumerate(reconstruction_rows, start=1):
    ax[row, 0].set_title(title, loc="left")
    for col in range(n_cols):
        ax[row, col].imshow(images[col])
        ax[row, col].axis('off')

plt.show()

## LoRA Scale Ablation

### Style reconstructions

In [None]:
# Sweep the LoRA scale of the style-only adapter: one set of samples per scale
style_scales = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6]
results = {}

for scale in style_scales:
    print(f"Testing lora scale: {scale}")

    # Build a fresh pipeline for this scale
    sdxl = SDXL(
        pipeline_type="diffusers.StableDiffusionXLPipeline",
        model_name="stabilityai/stable-diffusion-xl-base-1.0",
        local_files_only=False,
        guidance_scale=10,
    )
    sdxl = sdxl.cuda().eval().to(dtype)

    # Override the scale on a private copy so the base config stays untouched
    style_cfg = copy.deepcopy(style_cfg_base)
    style_cfg["lora"]["style"]["config"].update(lora_scale=scale)
    cfg_mask = add_adapter(sdxl, style_cfg)

    # Predict phi and sample one style image per input (empty prompt)
    phi = sdxl.predict_phi(batch.to(dtype), branch_idx=0)
    sampling_kwargs = dict(
        num_images_per_prompt=batch.shape[0],
        cs=[phi],
        generator=get_generator(),
        cfg_mask=cfg_mask,
        skip_encode=True,
        skip_mapping=True,
    )
    style = sdxl.sample_custom(prompt="", **sampling_kwargs)

    results[scale] = style

# Grid: inputs on the top row, then one row per tested scale
n_rows, n_cols = len(results) + 1, batch.shape[0]
fig, ax = plt.subplots(n_rows, n_cols, figsize=(2.2 * n_cols, 2.2 * n_rows), squeeze=False)

ax[0, 0].set_title("Original Images", loc="left")
for col in range(n_cols):
    original = batch[col].permute(1, 2, 0).cpu().numpy()
    ax[0, col].imshow((original * 0.5 + 0.5).clip(0, 1))
    ax[0, col].axis('off')

for row, (scale, images) in enumerate(results.items(), start=1):
    ax[row, 0].set_title(f"LoRA Scale: {scale:.1f}", loc="left")
    for col, img in enumerate(images):
        ax[row, col].imshow(img)
        ax[row, col].axis('off')

plt.tight_layout()
plt.show()

### Content reconstructions

In [None]:
# Sweep the LoRA scale of the content-only adapter: one set of samples per scale
content_scales = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6]
results = {}

for scale in content_scales:
    print(f"Testing lora scale: {scale}")

    # Build a fresh pipeline for this scale
    sdxl = SDXL(
        pipeline_type="diffusers.StableDiffusionXLPipeline",
        model_name="stabilityai/stable-diffusion-xl-base-1.0",
        local_files_only=False,
        guidance_scale=10,
    )
    sdxl = sdxl.cuda().eval().to(dtype)

    # Override the scale on a private copy so the base config stays untouched
    content_cfg = copy.deepcopy(content_cfg_base)
    content_cfg["lora"]["style"]["config"].update(lora_scale=scale)
    cfg_mask = add_adapter(sdxl, content_cfg)

    # Predict phi and sample one content image per input (empty prompt)
    phi = sdxl.predict_phi(batch.to(dtype), branch_idx=0)
    sampling_kwargs = dict(
        num_images_per_prompt=batch.shape[0],
        cs=[phi],
        generator=get_generator(),
        cfg_mask=cfg_mask,
        skip_encode=True,
        skip_mapping=True,
    )
    content = sdxl.sample_custom(prompt="", **sampling_kwargs)

    results[scale] = content

# Grid: inputs on the top row, then one row per tested scale
n_rows, n_cols = len(results) + 1, batch.shape[0]
fig, ax = plt.subplots(n_rows, n_cols, figsize=(2.2 * n_cols, 2.2 * n_rows), squeeze=False)

ax[0, 0].set_title("Original Images", loc="left")
for col in range(n_cols):
    original = batch[col].permute(1, 2, 0).cpu().numpy()
    ax[0, col].imshow((original * 0.5 + 0.5).clip(0, 1))
    ax[0, col].axis('off')

for row, (scale, images) in enumerate(results.items(), start=1):
    ax[row, 0].set_title(f"LoRA Scale: {scale:.1f}", loc="left")
    for col, img in enumerate(images):
        ax[row, col].imshow(img)
        ax[row, col].axis('off')

plt.tight_layout()
plt.show()

### Style + Content reconstructions

In [None]:
# Sweep the LoRA scale of the full (style + content) adapter: one set of samples per scale
full_scales = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6]
results = {}

for scale in full_scales:
    print(f"Testing lora scale: {scale}")

    # Build a fresh pipeline for this scale
    sdxl = SDXL(
        pipeline_type="diffusers.StableDiffusionXLPipeline",
        model_name="stabilityai/stable-diffusion-xl-base-1.0",
        local_files_only=False,
        guidance_scale=10,
    )
    sdxl = sdxl.cuda().eval().to(dtype)

    # Override the scale on a private copy so the base config stays untouched
    full_cfg = copy.deepcopy(full_cfg_base)
    full_cfg["lora"]["style"]["config"].update(lora_scale=scale)
    cfg_mask = add_adapter(sdxl, full_cfg)

    # Predict phi and sample one image per input (empty prompt)
    phi = sdxl.predict_phi(batch.to(dtype), branch_idx=0)
    sampling_kwargs = dict(
        num_images_per_prompt=batch.shape[0],
        cs=[phi],
        generator=get_generator(),
        cfg_mask=cfg_mask,
        skip_encode=True,
        skip_mapping=True,
    )
    content_style = sdxl.sample_custom(prompt="", **sampling_kwargs)

    results[scale] = content_style

# Grid: inputs on the top row, then one row per tested scale
n_rows, n_cols = len(results) + 1, batch.shape[0]
fig, ax = plt.subplots(n_rows, n_cols, figsize=(2.2 * n_cols, 2.2 * n_rows), squeeze=False)

ax[0, 0].set_title("Original Images", loc="left")
for col in range(n_cols):
    original = batch[col].permute(1, 2, 0).cpu().numpy()
    ax[0, col].imshow((original * 0.5 + 0.5).clip(0, 1))
    ax[0, col].axis('off')

for row, (scale, images) in enumerate(results.items(), start=1):
    ax[row, 0].set_title(f"LoRA Scale: {scale:.1f}", loc="left")
    for col, img in enumerate(images):
        ax[row, col].imshow(img)
        ax[row, col].axis('off')

plt.tight_layout()
plt.show()

## Guidance Scale Ablation

### Style reconstructions

In [None]:
# Sweep the guidance scale (0 through 10) with the style-only adapter
guidance_scales = list(range(11))
results = {}

for scale in guidance_scales:
    print(f"Testing guidance scale: {scale}")

    # Build a fresh pipeline that uses this guidance scale
    sdxl = SDXL(
        pipeline_type="diffusers.StableDiffusionXLPipeline",
        model_name="stabilityai/stable-diffusion-xl-base-1.0",
        local_files_only=False,
        guidance_scale=scale,
    )
    sdxl = sdxl.cuda().eval().to(dtype)

    # Attach the style-only adapter with its base configuration (LoRA scale left unchanged)
    cfg_mask = add_adapter(sdxl, style_cfg_base)

    # Predict phi and sample one style image per input (empty prompt)
    phi = sdxl.predict_phi(batch.to(dtype), branch_idx=0)
    sampling_kwargs = dict(
        num_images_per_prompt=batch.shape[0],
        cs=[phi],
        generator=get_generator(),
        cfg_mask=cfg_mask,
        skip_encode=True,
        skip_mapping=True,
    )
    style = sdxl.sample_custom(prompt="", **sampling_kwargs)

    results[scale] = style

# Grid: inputs on the top row, then one row per tested guidance scale
n_rows, n_cols = len(results) + 1, batch.shape[0]
fig, ax = plt.subplots(n_rows, n_cols, figsize=(2.2 * n_cols, 2.2 * n_rows), squeeze=False)

ax[0, 0].set_title("Original Images", loc="left")
for col in range(n_cols):
    original = batch[col].permute(1, 2, 0).cpu().numpy()
    ax[0, col].imshow((original * 0.5 + 0.5).clip(0, 1))
    ax[0, col].axis('off')

for row, (scale, images) in enumerate(results.items(), start=1):
    ax[row, 0].set_title(f"Guidance Scale: {scale:.1f}", loc="left")
    for col, img in enumerate(images):
        ax[row, col].imshow(img)
        ax[row, col].axis('off')

plt.tight_layout()
plt.show()