# CTRLorALTer

This notebook visualizes the approach of optimization in the CTRLorALTer space of a Stable Diffusion 1.5 model using a sample batch.

## Setup

In [None]:
import matplotlib.pyplot as plt
import copy
from src.ctrloralter.model import SD15
import torch

In [None]:
from torch import Generator

# One fixed seed shared by every sampling call, for reproducibility
SEED = 42

def get_generator(seed=SEED, device="cuda"):
	"""Create a new ``torch.Generator`` on ``device`` seeded with ``seed``.

	A fresh generator object is built on every call, so handing
	``get_generator()`` to each sampling call restarts the random stream
	from the same state.
	"""
	generator = Generator(device=device)
	generator.manual_seed(seed)  # seeds the generator in place
	return generator

### Load Evaluation Batch

In [None]:
# Evaluation batch, deserialized directly onto the GPU. The cells below index it
# as batch[i] and permute (C, H, W) -> (H, W, C) for plotting.
batch = torch.load(
	"../data/ffhq/eval_batch/size_512.pt",
	map_location="cuda",
)

### Load Adapter

In [None]:
# Base configuration: the checkpoint root plus an initially empty "lora" section.
# add_adapters() fills the "lora" section on a deep copy, so this dict stays pristine.
raw_cfg = dict(
	ckpt_path="ctrloralter/checkpoints",
	lora=dict(),
)

#### Style Adapter (Required)

In [None]:
from src.ctrloralter.annotators.openclip import VisionModel
from src.ctrloralter.mapper_network import SimpleMapper

# Base configuration of the style adapter. Later cells deep-copy this dict and
# override config["lora_scale"] or "ckpt_path" instead of editing it in place.
style_cfg_base = dict(
	enable="always",
	optimize=False,
	ckpt_path="ctrloralter/checkpoints/sd15-style-cross-160-h",
	ignore_check=False,
	cfg=True,
	transforms=[],
	# LoRA hyper-parameters of the adapter itself
	config=dict(
		lora_scale=1.0,
		rank=160,
		c_dim=1024,
		adaption_mode="only_cross",
		lora_cls="SimpleLoraLinear",
		broadcast_tokens=True,
	),
	# Annotator and mapper are instantiated once here and stored in the config
	encoder=VisionModel(clip_model="laion/CLIP-ViT-H-14-laion2B-s32B-b79K", local_files_only=False),
	mapper_network=SimpleMapper(1024, 1024),
)

# Fine-tuned style checkpoint: FID best after 36 epochs
STYLE_FINETUNED_PATH = "/BS/optdif/work/models/sd_lora/version_2/checkpoints/epoch_036"

#### Depth Structure Adapter

In [None]:
from src.ctrloralter.annotators.midas import DepthEstimator
from src.ctrloralter.mapper_network import FixedStructureMapper15

# Base configuration of the depth structure adapter. Later cells deep-copy this
# dict and override config["lora_scale"] or "ckpt_path" instead of editing it in place.
depth_cfg_base = dict(
	enable="always",
	optimize=False,
	ckpt_path="ctrloralter/checkpoints/sd15-depth-128-only-res",
	ignore_check=False,
	cfg=False,
	transforms=[],
	# LoRA hyper-parameters of the adapter itself
	config=dict(
		lora_scale=1.0,
		rank=128,
		c_dim=128,
		adaption_mode="only_res_conv",
		lora_cls="NewStructLoRAConv",
	),
	# Annotator and mapper are instantiated once here and stored in the config
	encoder=DepthEstimator(size=512, local_files_only=False),
	mapper_network=FixedStructureMapper15(c_dim=128),
)

# Fine-tuned style + depth checkpoint: FID best after one epoch
STYLE_DEPTH_FINETUNED_PATH = "/BS/optdif/work/models/sd_lora/version_3/checkpoints/epoch_000"

#### HED Structure Adapter

In [None]:
from src.ctrloralter.annotators.hed import TorchHEDdetector
from src.ctrloralter.mapper_network import FixedStructureMapper15

# Base configuration of the HED structure adapter. Later cells deep-copy this
# dict and override config["lora_scale"] or "ckpt_path" instead of editing it in place.
hed_cfg_base = dict(
	enable="always",
	optimize=False,
	ckpt_path="ctrloralter/checkpoints/sd15-hed-128-only-res",
	ignore_check=False,
	cfg=False,
	transforms=[],
	# LoRA hyper-parameters of the adapter itself
	config=dict(
		lora_scale=1.0,
		rank=128,
		c_dim=128,
		adaption_mode="only_res_conv",
		lora_cls="NewStructLoRAConv",
	),
	# Annotator and mapper are instantiated once here and stored in the config
	encoder=TorchHEDdetector(size=512, local_files_only=False),
	mapper_network=FixedStructureMapper15(c_dim=128),
)

# Fine-tuned style + HED checkpoint: loss best after five epochs, FID just gets worse
# (the path below points at epoch_000)
STYLE_HED_FINETUNED_PATH = "/BS/optdif/work/models/sd_lora/version_6/checkpoints/epoch_000"

#### Add Adapters to Model

In [None]:
from omegaconf import OmegaConf
from src.ctrloralter.utils import add_lora_from_config

def add_adapters(model, raw_cfg, style_cfg=None, depth_cfg=None, hed_cfg=None, device="cuda"):
	"""Attach the selected LoRA adapters to ``model``.

	A deep copy of ``raw_cfg`` is extended, so ``raw_cfg`` itself is never
	modified. The style adapter is stored under ``cfg["lora"]["style"]`` and the
	structure adapter (depth or HED) under ``cfg["lora"]["struct"]``. Both
	structure adapters share that single "struct" slot, so at most one of them
	may be given.

	Args:
		model: Model the adapters are added to (forwarded to ``add_lora_from_config``).
		raw_cfg: Base configuration containing a ``"lora"`` mapping.
		style_cfg: Optional style adapter configuration.
		depth_cfg: Optional depth structure adapter configuration.
		hed_cfg: Optional HED structure adapter configuration.
		device: Device forwarded to ``add_lora_from_config``.

	Returns:
		The return value of ``add_lora_from_config``; this notebook passes it to
		``sample_custom`` as ``cfg_mask``.

	Raises:
		ValueError: If both ``depth_cfg`` and ``hed_cfg`` are given.
	"""
	# Previously the HED config was silently dropped when a depth config was also passed
	if depth_cfg is not None and hed_cfg is not None:
		raise ValueError(
			"depth_cfg and hed_cfg both target the single 'struct' adapter slot; pass only one of them"
		)

	cfg = copy.deepcopy(raw_cfg)
	lora_cfg = cfg["lora"]

	if style_cfg is not None:
		lora_cfg["style"] = style_cfg

	structure_cfg = depth_cfg if depth_cfg is not None else hed_cfg
	if structure_cfg is not None:
		lora_cfg["struct"] = structure_cfg

	# wrap it in a DictConfig; allow_objects lets the config hold the encoder/mapper instances
	cfg = OmegaConf.create(cfg, flags={"allow_objects": True})

	return add_lora_from_config(model, cfg, device=device)

## Workflow

### Load Model

In [None]:
# Build the Stable Diffusion 1.5 wrapper first ...
sd15 = SD15(
	pipeline_type="diffusers.StableDiffusionPipeline",
	model_name="runwayml/stable-diffusion-v1-5",
	local_files_only=False,
)
# ... then move it to the GPU and switch it to eval mode
sd15 = sd15.cuda().eval()

### Add adapters to model

In [None]:
# Register only the style adapter; the returned value is kept as cfg_mask and
# handed to sample_custom further down.
cfg_mask = add_adapters(
	sd15,
	raw_cfg,
	style_cfg=style_cfg_base,
)

### Predict phi

In [None]:
# Predict the (un-optimized) conditioning phi for the whole batch from branch 0,
# which is used as the style conditioning when sampling.
phi = sd15.predict_phi(
	batch,
	branch_idx=0,
)

### Sample Images

Sample images from the model using the obtained $\varphi$ as the condition. Note that these $\varphi$ have not been optimized; they are the direct output of the global mapper of the style adapter. The sampled images should therefore not be seen as optimized images, but rather as some form of reconstruction of the input images.

In [None]:
# NOTE(review): only the style adapter was registered on the model above, yet a
# second (structure) conditioning is passed here — confirm how sample_custom treats it.
n_samples = batch.shape[0]
conditionings = [
	phi,    # style conditioning
	batch,  # structure conditioning
]
sampled_images = sd15.sample_custom(
	prompt="",
	num_images_per_prompt=n_samples,
	cs=conditionings,
	generator=get_generator(),
	cfg_mask=cfg_mask,  # use classifier-free guidance mask
	skip_encode=[0],    # conditioning 0 (style) is already given as phi, so it is not encoded again
	skip_mapping=[0],   # ... and not mapped again
)

### Visualize results

In [None]:
# Visualize input images (batch) and sampled images (sampled_images) next to each other
import matplotlib.pyplot as plt

n_images = batch.shape[0]
# squeeze=False keeps `axes` two-dimensional even when the batch holds a single image
fig, axes = plt.subplots(2, n_images, figsize=(2*n_images, 4), squeeze=False)
for i in range(n_images):
	# De-normalize (inputs presumably live in [-1, 1] — TODO confirm) and clip so that
	# imshow receives floats inside [0, 1], as the later figures in this notebook do
	img = (batch[i].permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5).clip(0, 1)
	axes[0, i].imshow(img)
	axes[0, i].axis('off')
	axes[1, i].imshow(sampled_images[i])
	axes[1, i].axis('off')
plt.tight_layout()
plt.show()

## Supplements

### Visualize Depth Maps

In [None]:
from src.ctrloralter.annotators.midas import DepthEstimator

de = DepthEstimator(size=512, local_files_only=False).to("cuda").eval()

# Pure inference: skip autograd bookkeeping so the result can be converted to NumPy directly
with torch.no_grad():
	depths = de(batch)
	depths = depths.mean(dim=1, keepdim=True)  # Average over the color

# Visualize input images (batch) and their corresponding depth maps
n_images = batch.shape[0]
# squeeze=False keeps `axes` two-dimensional even when the batch holds a single image
fig, axes = plt.subplots(2, n_images, figsize=(2*n_images, 4), squeeze=False)
for i in range(n_images):
	# clip after de-normalizing so imshow receives floats inside [0, 1]
	axes[0, i].imshow((batch[i].permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5).clip(0, 1))
	axes[0, i].axis('off')
	axes[1, i].imshow(depths[i].permute(1, 2, 0).cpu().numpy(), cmap='gray')
	axes[1, i].axis('off')
plt.tight_layout()
plt.show()

### Visualize HED Maps

In [None]:
from src.ctrloralter.annotators.hed import TorchHEDdetector

hed = TorchHEDdetector(size=512, local_files_only=False).to("cuda").eval()

# Pure inference: skip autograd bookkeeping so the result can be converted to NumPy directly
with torch.no_grad():
	edges = hed(batch)
	edges = edges.mean(dim=1, keepdim=True)  # Average over the color

# Visualize input images (batch) and their corresponding HED edge maps
n_images = batch.shape[0]
# squeeze=False keeps `axes` two-dimensional even when the batch holds a single image
fig, axes = plt.subplots(2, n_images, figsize=(2*n_images, 4), squeeze=False)
for i in range(n_images):
	# clip after de-normalizing so imshow receives floats inside [0, 1]
	axes[0, i].imshow((batch[i].permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5).clip(0, 1))
	axes[0, i].axis('off')
	axes[1, i].imshow(edges[i].permute(1, 2, 0).cpu().numpy(), cmap='gray')
	axes[1, i].axis('off')
plt.tight_layout()
plt.show()

## Comparison of Conditioning Options

#### Direct style reconstructions (without structure adapter)

In [None]:
# Start from a freshly loaded SD 1.5 model
sd15 = SD15(
	pipeline_type="diffusers.StableDiffusionPipeline",
	model_name="runwayml/stable-diffusion-v1-5",
	local_files_only=False,
).cuda().eval()

# Style adapter only — no structure adapter in this configuration
cfg_mask = add_adapters(sd15, raw_cfg, style_cfg=style_cfg_base)

# Conditioning phi for the whole batch
phi = sd15.predict_phi(batch, branch_idx=0)

# One reconstruction per input image, conditioned on phi alone.
# phi is passed as-is, hence encoding and mapping are skipped for every conditioning.
n_samples = batch.shape[0]
style = sd15.sample_custom(
	prompt="",
	num_images_per_prompt=n_samples,
	cs=[phi],
	generator=get_generator(),
	cfg_mask=cfg_mask,
	skip_encode=True,
	skip_mapping=True,
)

#### Style + Depth Structure Adapter

In [None]:
# Configure model
sd15 = SD15(
	pipeline_type="diffusers.StableDiffusionPipeline",
	model_name="runwayml/stable-diffusion-v1-5",
	local_files_only=False,
).cuda().eval()

# Depth maps are only needed for the comparison figure; sampling below receives the
# raw batch as structure conditioning. Inference only, so no autograd graph is built.
depth_model = DepthEstimator(size=512, local_files_only=False).to("cuda").eval()
with torch.no_grad():
	depth_maps = depth_model(batch)
	depth_maps = depth_maps.mean(dim=1, keepdim=True)  # average over color channels

# Add style and depth adapters
cfg_mask = add_adapters(sd15, raw_cfg, style_cfg=style_cfg_base, depth_cfg=depth_cfg_base)

# Predict phi
phi = sd15.predict_phi(batch, branch_idx=0)

# Sample images
style_depth = sd15.sample_custom(
	prompt="",
	num_images_per_prompt=batch.shape[0],
	cs=[phi, batch],  # style and structure conditioning
	generator=get_generator(),
	cfg_mask=cfg_mask,
	skip_encode=[0],   # conditioning 0 (style) is already given as phi
	skip_mapping=[0],
)

#### Style + HED Structure Adapter

In [None]:
# Configure model
sd15 = SD15(
	pipeline_type="diffusers.StableDiffusionPipeline",
	model_name="runwayml/stable-diffusion-v1-5",
	local_files_only=False,
).cuda().eval()

# HED maps are only needed for the comparison figure; sampling below receives the
# raw batch as structure conditioning. Inference only, so no autograd graph is built.
hed_model = TorchHEDdetector(size=512, local_files_only=False).to("cuda").eval()
with torch.no_grad():
	hed_maps = hed_model(batch)
	hed_maps = hed_maps.mean(dim=1, keepdim=True)  # average over color channels

# Add style and HED adapters
cfg_mask = add_adapters(sd15, raw_cfg, style_cfg=style_cfg_base, hed_cfg=hed_cfg_base)

# Predict phi
phi = sd15.predict_phi(batch, branch_idx=0)

# Sample images
style_hed = sd15.sample_custom(
	prompt="",
	num_images_per_prompt=batch.shape[0],
	cs=[phi, batch],  # style and structure conditioning
	generator=get_generator(),
	cfg_mask=cfg_mask,
	skip_encode=[0],   # conditioning 0 (style) is already given as phi
	skip_mapping=[0],
)

#### Visualize all images

In [None]:
import matplotlib.pyplot as plt

n_images = batch.shape[0]
fig, ax = plt.subplots(6, n_images, figsize=(2*n_images, 2.2*6), squeeze=False)

# One entry per figure row: (title, image source, is_tensor, colormap).
# Tensor rows are de-normalized and clipped to [0, 1]; sampled rows are shown as returned.
rows = [
	("Original Images", batch, True, None),
	("Reconstruction based on Style", style, False, None),
	("Depth Maps", depth_maps, True, 'gray'),
	("Reconstruction based on Style + Depth", style_depth, False, None),
	("HED Maps", hed_maps, True, 'gray'),
	("Reconstruction based on Style + HED", style_hed, False, None),
]

for r, (title, images, is_tensor, cmap) in enumerate(rows):
	# The row title is attached to the left-most panel
	ax[r, 0].set_title(title, loc="left")
	for c in range(n_images):
		if is_tensor:
			img = (images[c].permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5).clip(0, 1)
		else:
			img = images[c]
		ax[r, c].imshow(img, cmap=cmap)
		ax[r, c].axis('off')

plt.show()

## LoRA Scale Ablation

### Direct style reconstructions (without structure adapter)

In [None]:
style_scales = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6]
n_images = batch.shape[0]

# Maps each LoRA scale to the images sampled with it
results = {}

for scale in style_scales:
	print(f"Testing lora scale: {scale}")

	# A new model is loaded for every scale
	sd15 = SD15(
		pipeline_type="diffusers.StableDiffusionPipeline",
		model_name="runwayml/stable-diffusion-v1-5",
		local_files_only=False,
	).cuda().eval()

	# Work on a copy so that style_cfg_base keeps its default scale
	style_cfg = copy.deepcopy(style_cfg_base)
	style_cfg["config"]["lora_scale"] = scale
	cfg_mask = add_adapters(sd15, raw_cfg, style_cfg=style_cfg)

	# Conditioning phi for the whole batch
	phi = sd15.predict_phi(batch, branch_idx=0)

	# Style-only reconstruction at this scale
	style = sd15.sample_custom(
		prompt="",
		num_images_per_prompt=n_images,
		cs=[phi],
		generator=get_generator(),
		cfg_mask=cfg_mask,
		skip_encode=True,
		skip_mapping=True,
	)
	results[scale] = style

# Figure: the first row shows the inputs, then one row per LoRA scale
n_rows = len(results) + 1
fig, ax = plt.subplots(n_rows, n_images, figsize=(2.2*n_images, 2.2*n_rows), squeeze=False)

ax[0, 0].set_title("Original Images", loc="left")
for col in range(n_images):
	original = (batch[col].permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5).clip(0, 1)
	ax[0, col].imshow(original)
	ax[0, col].axis('off')

for row, (scale, images) in enumerate(results.items(), start=1):
	ax[row, 0].set_title(f"LoRA Scale: {scale:.1f}", loc="left")
	for col, img in enumerate(images):
		ax[row, col].imshow(img)
		ax[row, col].axis('off')

plt.tight_layout()
plt.show()

### Style + Depth Structure Adapter

In [None]:
# With two scales being ablated, the effect is shown on a single image only
eval_sample = batch[6].clone()
eval_input = eval_sample.unsqueeze(0)  # same image with a leading batch dimension of 1

style_scales = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6]
depth_scales = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6]
# Alternative, finer grids:
# style_scales = [0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
# depth_scales = [0.2, 0.3, 0.4, 0.5, 0.6]

# Maps (style_scale, depth_scale) to the images sampled with that pair
results = {}
for style_scale in style_scales:
	for depth_scale in depth_scales:
		print(f"Testing lora scale: {style_scale}, {depth_scale}")

		# A new model is loaded for every scale pair
		sd15 = SD15(
			pipeline_type="diffusers.StableDiffusionPipeline",
			model_name="runwayml/stable-diffusion-v1-5",
			local_files_only=False,
		).cuda().eval()

		# Work on copies so that the base configs keep their default scales
		style_cfg = copy.deepcopy(style_cfg_base)
		style_cfg["config"]["lora_scale"] = style_scale
		depth_cfg = copy.deepcopy(depth_cfg_base)
		depth_cfg["config"]["lora_scale"] = depth_scale
		cfg_mask = add_adapters(sd15, raw_cfg, style_cfg=style_cfg, depth_cfg=depth_cfg)

		# Conditioning phi for the single evaluation image
		phi = sd15.predict_phi(eval_input, branch_idx=0)
		print(f"Phi shape: {phi.shape}")

		# Style + depth reconstruction for this scale pair
		style_depth = sd15.sample_custom(
			prompt="",
			num_images_per_prompt=1,
			cs=[phi, eval_input],  # style and structure conditioning
			generator=get_generator(),
			cfg_mask=cfg_mask,
			skip_encode=[0],
			skip_mapping=[0],
		)
		results[(style_scale, depth_scale)] = style_depth

# Figure: row 0 / column 0 hold the labels, the top-left panel the original image
n_rows = len(style_scales) + 1
n_cols = len(depth_scales) + 1
fig, ax = plt.subplots(n_rows, n_cols, figsize=(2.2*n_cols, 2.2*n_rows))

ax[0, 0].set_title("Original Image")
ax[0, 0].imshow((eval_sample.permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5).clip(0, 1))
ax[0, 0].axis('off')

# Column headers: depth scales
for col, d in enumerate(depth_scales, start=1):
	ax[0, col].text(0.5, 0.5, f"Depth Scale:\n{d:.1f}", fontsize=14, va="center", ha="center")
	ax[0, col].axis('off')

# Row headers (style scales) followed by the images of that row
for row, style_scale in enumerate(style_scales, start=1):
	ax[row, 0].text(0.5, 0.5, f"Style Scale:\n{style_scale:.1f}", fontsize=14, va="center", ha="center")
	ax[row, 0].axis('off')
	for col, depth_scale in enumerate(depth_scales, start=1):
		ax[row, col].imshow(results[(style_scale, depth_scale)][0])
		ax[row, col].axis('off')

plt.tight_layout()
plt.show()

### Style + HED Structure Adapter

In [None]:
# With two scales being ablated, the effect is shown on a single image only
eval_sample = batch[6].clone()
eval_input = eval_sample.unsqueeze(0)  # same image with a leading batch dimension of 1

style_scales = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6]
hed_scales = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6]

# Maps (style_scale, hed_scale) to the images sampled with that pair
results = {}
for style_scale in style_scales:
	for hed_scale in hed_scales:
		print(f"Testing lora scale: {style_scale}, {hed_scale}")

		# A new model is loaded for every scale pair
		sd15 = SD15(
			pipeline_type="diffusers.StableDiffusionPipeline",
			model_name="runwayml/stable-diffusion-v1-5",
			local_files_only=False,
		).cuda().eval()

		# Work on copies so that the base configs keep their default scales
		style_cfg = copy.deepcopy(style_cfg_base)
		style_cfg["config"]["lora_scale"] = style_scale
		hed_cfg = copy.deepcopy(hed_cfg_base)
		hed_cfg["config"]["lora_scale"] = hed_scale
		cfg_mask = add_adapters(sd15, raw_cfg, style_cfg=style_cfg, hed_cfg=hed_cfg)

		# Conditioning phi for the single evaluation image
		phi = sd15.predict_phi(eval_input, branch_idx=0)
		print(f"Phi shape: {phi.shape}")

		# Style + HED reconstruction for this scale pair
		style_hed = sd15.sample_custom(
			prompt="",
			num_images_per_prompt=1,
			cs=[phi, eval_input],  # style and structure conditioning
			generator=get_generator(),
			cfg_mask=cfg_mask,
			skip_encode=[0],
			skip_mapping=[0],
		)
		results[(style_scale, hed_scale)] = style_hed

# Figure: row 0 / column 0 hold the labels, the top-left panel the original image
n_rows = len(style_scales) + 1
n_cols = len(hed_scales) + 1
fig, ax = plt.subplots(n_rows, n_cols, figsize=(2.2*n_cols, 2.2*n_rows))

ax[0, 0].set_title("Original Image")
ax[0, 0].imshow((eval_sample.permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5).clip(0, 1))
ax[0, 0].axis('off')

# Column headers: HED scales
for col, d in enumerate(hed_scales, start=1):
	ax[0, col].text(0.5, 0.5, f"HED Scale:\n{d:.1f}", fontsize=14, va="center", ha="center")
	ax[0, col].axis('off')

# Row headers (style scales) followed by the images of that row
for row, style_scale in enumerate(style_scales, start=1):
	ax[row, 0].text(0.5, 0.5, f"Style Scale:\n{style_scale:.1f}", fontsize=14, va="center", ha="center")
	ax[row, 0].axis('off')
	for col, hed_scale in enumerate(hed_scales, start=1):
		ax[row, col].imshow(results[(style_scale, hed_scale)][0])
		ax[row, col].axis('off')

plt.tight_layout()
plt.show()

## LoRA Fine-Tuning Analysis

### Style Finetuned

In [None]:
# Two runs that differ only in the style checkpoint:
#   "initial"   -> checkpoint of style_cfg_base (None = no override)
#   "finetuned" -> STYLE_FINETUNED_PATH
style_runs = {}
for run_name, ckpt_path in (("initial", None), ("finetuned", STYLE_FINETUNED_PATH)):
	# A new model is loaded for every run
	sd15 = SD15(
		pipeline_type="diffusers.StableDiffusionPipeline",
		model_name="runwayml/stable-diffusion-v1-5",
		local_files_only=False,
	).cuda().eval()

	# Only copy the base config when the checkpoint has to be overridden
	style_cfg = style_cfg_base
	if ckpt_path is not None:
		style_cfg = copy.deepcopy(style_cfg_base)
		style_cfg["ckpt_path"] = ckpt_path
	cfg_mask = add_adapters(sd15, raw_cfg, style_cfg=style_cfg)

	# Conditioning phi for the whole batch
	phi = sd15.predict_phi(batch, branch_idx=0)

	# Style-only reconstruction
	style_runs[run_name] = sd15.sample_custom(
		prompt="",
		num_images_per_prompt=batch.shape[0],
		cs=[phi],  # style conditioning only
		generator=get_generator(),
		cfg_mask=cfg_mask,
		skip_encode=True,
		skip_mapping=True,
	)

initial_style = style_runs["initial"]
finetuned_style = style_runs["finetuned"]

# Figure: inputs on top, then one row per run
n_images = batch.shape[0]
fig, ax = plt.subplots(3, n_images, figsize=(2*n_images, 2.2*3), squeeze=False)

ax[0, 0].set_title("Original Images", loc="left")
for col in range(n_images):
	img = (batch[col].permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5).clip(0, 1)
	ax[0, col].imshow(img)
	ax[0, col].axis('off')

sampled_rows = [
	("Reconstruction based on Initial Style", initial_style),
	("Reconstruction based on Finetuned Style", finetuned_style),
]
for row, (title, images) in enumerate(sampled_rows, start=1):
	ax[row, 0].set_title(title, loc="left")
	for col in range(n_images):
		ax[row, col].imshow(images[col])
		ax[row, col].axis('off')

plt.show()

### Style + Depth

In [None]:
# Four runs, each given as (style checkpoint override, depth checkpoint override).
# None keeps the checkpoint of the corresponding *_cfg_base.
# NOTE(review): the fully fine-tuned run loads the style adapter from
# STYLE_DEPTH_FINETUNED_PATH rather than STYLE_FINETUNED_PATH — presumably a jointly
# fine-tuned checkpoint; confirm.
variants = {
	"initial_style_initial_depth": (None, None),
	"initial_style_finetuned_depth": (None, STYLE_DEPTH_FINETUNED_PATH),
	"finetuned_style_initial_depth": (STYLE_FINETUNED_PATH, None),
	"finetuned_style_finetuned_depth": (STYLE_DEPTH_FINETUNED_PATH, STYLE_DEPTH_FINETUNED_PATH),
}

reconstructions = {}
for variant, (style_ckpt, depth_ckpt) in variants.items():
	# A new model is loaded for every run
	sd15 = SD15(
		pipeline_type="diffusers.StableDiffusionPipeline",
		model_name="runwayml/stable-diffusion-v1-5",
		local_files_only=False,
	).cuda().eval()

	# Only copy a base config when its checkpoint has to be overridden
	style_cfg = style_cfg_base
	if style_ckpt is not None:
		style_cfg = copy.deepcopy(style_cfg_base)
		style_cfg["ckpt_path"] = style_ckpt
	depth_cfg = depth_cfg_base
	if depth_ckpt is not None:
		depth_cfg = copy.deepcopy(depth_cfg_base)
		depth_cfg["ckpt_path"] = depth_ckpt

	cfg_mask = add_adapters(sd15, raw_cfg, style_cfg=style_cfg, depth_cfg=depth_cfg)
	phi = sd15.predict_phi(batch, branch_idx=0)

	reconstructions[variant] = sd15.sample_custom(
		prompt="",
		num_images_per_prompt=batch.shape[0],
		cs=[phi, batch],  # style and structure conditioning
		generator=get_generator(),
		cfg_mask=cfg_mask,
		skip_encode=[0],
		skip_mapping=[0],
	)

initial_style_initial_depth = reconstructions["initial_style_initial_depth"]
initial_style_finetuned_depth = reconstructions["initial_style_finetuned_depth"]
finetuned_style_initial_depth = reconstructions["finetuned_style_initial_depth"]
finetuned_style_finetuned_depth = reconstructions["finetuned_style_finetuned_depth"]

# Figure: inputs on top, then one row per run
n_images = batch.shape[0]
fig, ax = plt.subplots(5, n_images, figsize=(2*n_images, 2.2*5), squeeze=False)

ax[0, 0].set_title("Original Images", loc="left")
for col in range(n_images):
	img = (batch[col].permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5).clip(0, 1)
	ax[0, col].imshow(img)
	ax[0, col].axis('off')

sampled_rows = [
	("Reconstruction based on Initial Style + Initial Depth", initial_style_initial_depth),
	("Reconstruction based on Initial Style + Finetuned Depth", initial_style_finetuned_depth),
	("Reconstruction based on Finetuned Style + Initial Depth", finetuned_style_initial_depth),
	("Reconstruction based on Finetuned Style + Finetuned Depth", finetuned_style_finetuned_depth),
]
for row, (title, images) in enumerate(sampled_rows, start=1):
	ax[row, 0].set_title(title, loc="left")
	for col in range(n_images):
		ax[row, col].imshow(images[col])
		ax[row, col].axis('off')

plt.show()

### Style + HED Structure Adapter

In [None]:
# Four runs, each given as (style checkpoint override, HED checkpoint override).
# None keeps the checkpoint of the corresponding *_cfg_base.
# NOTE(review): the fully fine-tuned run loads the style adapter from
# STYLE_HED_FINETUNED_PATH rather than STYLE_FINETUNED_PATH — presumably a jointly
# fine-tuned checkpoint; confirm.
variants = {
	"initial_style_initial_hed": (None, None),
	"initial_style_finetuned_hed": (None, STYLE_HED_FINETUNED_PATH),
	"finetuned_style_initial_hed": (STYLE_FINETUNED_PATH, None),
	"finetuned_style_finetuned_hed": (STYLE_HED_FINETUNED_PATH, STYLE_HED_FINETUNED_PATH),
}

reconstructions = {}
for variant, (style_ckpt, hed_ckpt) in variants.items():
	# A new model is loaded for every run
	sd15 = SD15(
		pipeline_type="diffusers.StableDiffusionPipeline",
		model_name="runwayml/stable-diffusion-v1-5",
		local_files_only=False,
	).cuda().eval()

	# Only copy a base config when its checkpoint has to be overridden
	style_cfg = style_cfg_base
	if style_ckpt is not None:
		style_cfg = copy.deepcopy(style_cfg_base)
		style_cfg["ckpt_path"] = style_ckpt
	hed_cfg = hed_cfg_base
	if hed_ckpt is not None:
		hed_cfg = copy.deepcopy(hed_cfg_base)
		hed_cfg["ckpt_path"] = hed_ckpt

	cfg_mask = add_adapters(sd15, raw_cfg, style_cfg=style_cfg, hed_cfg=hed_cfg)
	phi = sd15.predict_phi(batch, branch_idx=0)

	reconstructions[variant] = sd15.sample_custom(
		prompt="",
		num_images_per_prompt=batch.shape[0],
		cs=[phi, batch],  # style and structure conditioning
		generator=get_generator(),
		cfg_mask=cfg_mask,
		skip_encode=[0],
		skip_mapping=[0],
	)

initial_style_initial_hed = reconstructions["initial_style_initial_hed"]
initial_style_finetuned_hed = reconstructions["initial_style_finetuned_hed"]
finetuned_style_initial_hed = reconstructions["finetuned_style_initial_hed"]
finetuned_style_finetuned_hed = reconstructions["finetuned_style_finetuned_hed"]

# Figure: inputs on top, then one row per run
n_images = batch.shape[0]
fig, ax = plt.subplots(5, n_images, figsize=(2*n_images, 2.2*5), squeeze=False)

ax[0, 0].set_title("Original Images", loc="left")
for col in range(n_images):
	img = (batch[col].permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5).clip(0, 1)
	ax[0, col].imshow(img)
	ax[0, col].axis('off')

sampled_rows = [
	("Reconstruction based on Initial Style + Initial HED", initial_style_initial_hed),
	("Reconstruction based on Initial Style + Finetuned HED", initial_style_finetuned_hed),
	("Reconstruction based on Finetuned Style + Initial HED", finetuned_style_initial_hed),
	("Reconstruction based on Finetuned Style + Finetuned HED", finetuned_style_finetuned_hed),
]
for row, (title, images) in enumerate(sampled_rows, start=1):
	ax[row, 0].set_title(title, loc="left")
	for col in range(n_images):
		ax[row, col].imshow(images[col])
		ax[row, col].axis('off')

plt.show()