# CTRLorALTer

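This notebook illustrates the starting point for optimization in the CTRLorALTer latent space. For a sample batch, it predicts the style embeddings $\varphi$ and visualizes the images sampled from them, with and without structure (depth / HED) conditioning.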

## Setup

### Load Data

In [None]:
# Build the FFHQ data module (smile scores as the property) used throughout this notebook.
from src.dataloader.ffhq import FFHQWeightedDataset
from src.dataloader.weighting import DataWeighter
from argparse import Namespace
from torchvision import transforms

data_opts = dict(
    # data locations
    img_dir="../data/ffhq/images1024x1024",
    img_tensor_dir="../data/ffhq/pt_images",
    attr_path="../data/ffhq/ffhq_smile_scores.json",
    # property value bounds
    max_property_value=5,
    min_property_value=0,
    # loader settings
    batch_size=8,
    num_workers=0,
    val_split=0,
    data_device="cuda",
    aug=True,
    # weighting options (presumably read by DataWeighter)
    weight_type="uniform",
    rank_weight_k=1e-3,
    weight_quantile=None,
    dbas_noise=None,
    rwr_alpha=None,
)
args = Namespace(**data_opts)

transform = transforms.Compose(
    [
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        # maps [0, 1] pixel values to [-1, 1]; plots undo this with `* 0.5 + 0.5`
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

weighter = DataWeighter(args)
datamodule = FFHQWeightedDataset(args, weighter, transform=transform)

In [None]:
# Draw a single batch from the training loader and move it to the GPU.
train_loader = datamodule.train_dataloader()
batch = next(iter(train_loader)).to("cuda")
batch.shape

### Load Model

In [None]:
from src.ctrloralter.model import SD15

# Stable Diffusion 1.5 backbone, moved to the GPU and switched to eval mode.
sd15 = SD15(
    pipeline_type="diffusers.StableDiffusionPipeline",
    model_name="runwayml/stable-diffusion-v1-5",
    local_files_only=False,
)
sd15 = sd15.cuda().eval()

### Load Adapter

In [None]:
# Base adapter configuration; add_adapters fills in the "lora" section of a copy of it.
raw_cfg = dict(
    ckpt_path="ctrloralter/checkpoints",
    lora={},
)

#### Style Adapter (Required)

In [None]:
from src.ctrloralter.annotators.openclip import VisionModel
from src.ctrloralter.mapper_network import SimpleMapper

# Style adapter: a cross-attention LoRA conditioned on OpenCLIP ViT-H image features.
# Keyword arguments are evaluated left to right, so the encoder is still built before the mapper.
style_cfg = dict(
    enable="always",
    optimize=False,
    ckpt_path="ctrloralter/checkpoints/sd15-style-cross-160-h",
    ignore_check=False,
    cfg=True,
    transforms=[],
    config=dict(
        lora_scale=1.0,
        rank=160,
        c_dim=1024,  # matches the 1024-d output of the mapper below
        adaption_mode="only_cross",
        lora_cls="SimpleLoraLinear",
        broadcast_tokens=True,
    ),
    encoder=VisionModel(clip_model="laion/CLIP-ViT-H-14-laion2B-s32B-b79K", local_files_only=False),
    mapper_network=SimpleMapper(1024, 1024),
)

#### Depth Structure Adapter (Optional)

In [None]:
from src.ctrloralter.annotators.midas import DepthEstimator
from src.ctrloralter.mapper_network import FixedStructureMapper15

# Depth structure adapter: a LoRA on the residual convolutions, driven by estimated depth.
# It is applied at a reduced scale (0.35) relative to the style adapter.
depth_cfg = dict(
    enable="always",
    optimize=False,
    ckpt_path="ctrloralter/checkpoints/sd15-depth-128-only-res",
    ignore_check=False,
    cfg=False,
    transforms=[],
    config=dict(
        lora_scale=0.35,
        rank=128,
        c_dim=128,  # matches FixedStructureMapper15(c_dim=128) below
        adaption_mode="only_res_conv",
        lora_cls="NewStructLoRAConv",
    ),
    encoder=DepthEstimator(size=512, local_files_only=False),
    mapper_network=FixedStructureMapper15(c_dim=128),
)

#### HED Structure Adapter (Optional)

In [None]:
from src.ctrloralter.annotators.hed import TorchHEDdetector
from src.ctrloralter.mapper_network import FixedStructureMapper15

# HED structure adapter: same LoRA layout as the depth adapter, driven by HED edge maps.
hed_cfg = dict(
    enable="always",
    optimize=False,
    ckpt_path="ctrloralter/checkpoints/sd15-hed-128-only-res",
    ignore_check=False,
    cfg=False,
    transforms=[],
    config=dict(
        lora_scale=1.0,
        rank=128,
        c_dim=128,  # matches FixedStructureMapper15(c_dim=128) below
        adaption_mode="only_res_conv",
        lora_cls="NewStructLoRAConv",
    ),
    encoder=TorchHEDdetector(size=512, local_files_only=False),
    mapper_network=FixedStructureMapper15(c_dim=128),
)

### Add Adapters to Model

In [None]:
from omegaconf import OmegaConf
from src.ctrloralter.utils import add_lora_from_config
import copy

def add_adapters(model, raw_cfg, style_cfg=None, depth_cfg=None, hed_cfg=None, device="cuda"):
    """Attach CTRLorALTer LoRA adapters to ``model`` according to the given configs.

    Args:
        model: The SD15 model to equip with adapters. It is handed to
            ``add_lora_from_config``, whose return value is discarded, so the model is
            presumably modified in place.
        raw_cfg: Base config dict (e.g. ``ckpt_path``). It is deep-copied and never
            mutated. A missing ``"lora"`` entry is created on the copy.
        style_cfg: Optional style-adapter config, registered under ``lora.style``.
        depth_cfg: Optional depth structure-adapter config, registered under
            ``lora.struct``.
        hed_cfg: Optional HED structure-adapter config, also registered under
            ``lora.struct``. Only one structure adapter fits in that slot, so
            ``hed_cfg`` is ignored (with a warning) when ``depth_cfg`` is also given.
        device: Device forwarded to ``add_lora_from_config``.
    """
    import warnings

    # Only the base config is deep-copied. The adapter configs are inserted by
    # reference, so their encoder / mapper module instances are not duplicated.
    cfg = copy.deepcopy(raw_cfg)
    lora_cfgs = cfg.setdefault("lora", {})

    # Keep this insertion order (style first, structure second). It presumably defines
    # the branch indices: predict_phi(..., branch_idx=0) is used for style and the
    # conditioning list passed to sample_custom is ordered [style, structure].
    if style_cfg is not None:
        lora_cfgs["style"] = style_cfg
    if depth_cfg is not None:
        if hed_cfg is not None:
            warnings.warn(
                "add_adapters: both depth_cfg and hed_cfg were given; only the depth "
                "structure adapter is used.",
                stacklevel=2,
            )
        lora_cfgs["struct"] = depth_cfg
    elif hed_cfg is not None:
        lora_cfgs["struct"] = hed_cfg

    # allow_objects lets the DictConfig hold arbitrary Python objects such as the
    # encoder / mapper_network module instances stored in the adapter configs.
    cfg = OmegaConf.create(cfg, flags={"allow_objects": True})

    add_lora_from_config(model, cfg, device=device)

In [None]:
# Main walkthrough: style adapter plus the depth structure adapter (no HED).
add_adapters(
    sd15,
    raw_cfg,
    style_cfg=style_cfg,
    depth_cfg=depth_cfg,
)

## Predict phi

In [None]:
# Style branch (index 0): map the input images directly to their phi embeddings.
phi = sd15.predict_phi(
    batch.to("cuda"),
    branch_idx=0,
)

## Sample Images

Sample images from the model using the obtained $\varphi$ as the condition. Note that these $\varphi$ have not been optimized; they are the direct output of the style adapter's global mapper. The sampled images should therefore not be read as optimized images but rather as a form of reconstruction of the input images.

In [None]:
# Condition 0 (phi) is already encoded and mapped, so both steps are skipped for it.
# Condition 1 (the raw batch) is passed unskipped, presumably so the structure
# adapter's own encoder and mapper process it.
conditions = [phi, batch]
sampled_images = sd15.sample_custom(
    prompt="",
    num_images_per_prompt=batch.shape[0],
    cs=conditions,
    generator=None,
    skip_encode=[0],
    skip_mapping=[0],
)

In [None]:
# Visualize input images (top row) and the corresponding samples (bottom row).
import matplotlib.pyplot as plt

n_images = batch.shape[0]
# squeeze=False keeps `axes` 2-D, so `axes[row, i]` also works for a single-image batch.
fig, axes = plt.subplots(2, n_images, figsize=(2 * n_images, 4), squeeze=False)
for i in range(n_images):
    # Undo Normalize(mean=0.5, std=0.5) and clip stray values so imshow gets valid RGB floats.
    axes[0, i].imshow((batch[i].permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5).clip(0, 1))
    axes[0, i].axis("off")
    axes[1, i].imshow(sampled_images[i])
    axes[1, i].axis("off")
plt.tight_layout()
plt.show()

## Supplements

### Visualize Depth Maps

In [None]:
from src.ctrloralter.annotators.midas import DepthEstimator
import matplotlib.pyplot as plt
import torch

de = DepthEstimator(size=512, local_files_only=False).to("cuda").eval()

# Inference only. Without no_grad the output carries an autograd graph whenever the
# estimator's weights require grad, and the `.numpy()` calls below would then raise.
with torch.no_grad():
    depths = de(batch)
depths = depths.mean(dim=1, keepdim=True)  # Average over the channel dimension -> one channel

# Visualize input images (top row) and their corresponding depth maps (bottom row).
n_images = batch.shape[0]
# squeeze=False keeps `axes` 2-D even for a single-image batch.
fig, axes = plt.subplots(2, n_images, figsize=(2 * n_images, 4), squeeze=False)
for i in range(n_images):
    # Undo the [-1, 1] normalization and clip so imshow gets valid RGB floats.
    axes[0, i].imshow((batch[i].permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5).clip(0, 1))
    axes[0, i].axis("off")
    axes[1, i].imshow(depths[i].squeeze(0).cpu().numpy(), cmap="gray")
    axes[1, i].axis("off")
plt.tight_layout()
plt.show()

### Visualize HED Maps

In [None]:
from src.ctrloralter.annotators.hed import TorchHEDdetector
import matplotlib.pyplot as plt
import torch

hed = TorchHEDdetector(size=512, local_files_only=False).to("cuda").eval()

# Inference only. Without no_grad the output carries an autograd graph whenever the
# detector's weights require grad, and the `.numpy()` calls below would then raise.
with torch.no_grad():
    edges = hed(batch)
edges = edges.mean(dim=1, keepdim=True)  # Average over the channel dimension -> one channel

# Visualize input images (top row) and their corresponding HED edge maps (bottom row).
n_images = batch.shape[0]
# squeeze=False keeps `axes` 2-D even for a single-image batch.
fig, axes = plt.subplots(2, n_images, figsize=(2 * n_images, 4), squeeze=False)
for i in range(n_images):
    # Undo the [-1, 1] normalization and clip so imshow gets valid RGB floats.
    axes[0, i].imshow((batch[i].permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5).clip(0, 1))
    axes[0, i].axis("off")
    axes[1, i].imshow(edges[i].squeeze(0).cpu().numpy(), cmap="gray")
    axes[1, i].axis("off")
plt.tight_layout()
plt.show()

## All together

#### Direct style reconstructions (without structure adapter)

In [None]:
import gc

import torch

# Release the previous model before building a new one. The right-hand side of
# `sd15 = SD15(...)` is evaluated before the name is rebound, so otherwise the old
# and the new pipeline would both be resident on the GPU at the same time.
sd15 = None
gc.collect()
torch.cuda.empty_cache()

# Configure model
sd15 = SD15(
    pipeline_type="diffusers.StableDiffusionPipeline",
    model_name="runwayml/stable-diffusion-v1-5",
    local_files_only=False,
).cuda().eval()

# Add only the style adapter.
# NOTE(review): style_cfg's encoder / mapper_network instances are reused for every model
# built in this notebook. Confirm that add_lora_from_config tolerates being applied to them repeatedly.
add_adapters(sd15, raw_cfg, style_cfg=style_cfg)

# Predict phi with the style branch (index 0)
phi = sd15.predict_phi(batch.to("cuda"), branch_idx=0)

# Sample style-only reconstructions. phi is the only condition and is already encoded
# and mapped, so index 0 is skipped for both steps. This uses the same list convention
# as the other sample_custom calls in this notebook.
style = sd15.sample_custom(
    prompt="",
    num_images_per_prompt=batch.shape[0],
    cs=[phi],
    generator=None,
    skip_encode=[0],
    skip_mapping=[0],
)

#### Style + Depth Structure Adapter

In [None]:
import gc

import torch

from src.ctrloralter.annotators.midas import DepthEstimator

# Release the previous model before building a new one. The right-hand side of
# `sd15 = SD15(...)` is evaluated before the name is rebound, so otherwise the old
# and the new pipeline would both be resident on the GPU at the same time.
sd15 = None
gc.collect()
torch.cuda.empty_cache()

# Configure model
sd15 = SD15(
    pipeline_type="diffusers.StableDiffusionPipeline",
    model_name="runwayml/stable-diffusion-v1-5",
    local_files_only=False,
).cuda().eval()

# Depth maps are only computed for the final figure. The model receives the raw `batch`
# in `cs` below and presumably derives its own depth conditioning from it.
depth_model = DepthEstimator(size=512, local_files_only=False).to("cuda").eval()
with torch.no_grad():  # inference only; keeps the maps convertible with .numpy()
    depth_maps = depth_model(batch)
depth_maps = depth_maps.mean(dim=1, keepdim=True)  # average over the channel dimension

# Add the style adapter together with the depth structure adapter
add_adapters(sd15, raw_cfg, style_cfg=style_cfg, depth_cfg=depth_cfg)

# Predict phi with the style branch (index 0)
phi = sd15.predict_phi(batch.to("cuda"), branch_idx=0)

# Sample images: phi (already encoded and mapped) for style, raw batch for structure
style_depth = sd15.sample_custom(
    prompt="",
    num_images_per_prompt=batch.shape[0],
    cs=[phi, batch],  # style and structure conditioning
    generator=None,
    skip_encode=[0],
    skip_mapping=[0],
)

#### Style + HED Structure Adapter

In [None]:
import gc

import torch

from src.ctrloralter.annotators.hed import TorchHEDdetector

# Release the previous model before building a new one. The right-hand side of
# `sd15 = SD15(...)` is evaluated before the name is rebound, so otherwise the old
# and the new pipeline would both be resident on the GPU at the same time.
sd15 = None
gc.collect()
torch.cuda.empty_cache()

# Configure model
sd15 = SD15(
    pipeline_type="diffusers.StableDiffusionPipeline",
    model_name="runwayml/stable-diffusion-v1-5",
    local_files_only=False,
).cuda().eval()

# HED maps are only computed for the final figure. The model receives the raw `batch`
# in `cs` below and presumably derives its own edge conditioning from it.
hed_model = TorchHEDdetector(size=512, local_files_only=False).to("cuda").eval()
with torch.no_grad():  # inference only; keeps the maps convertible with .numpy()
    hed_maps = hed_model(batch)
hed_maps = hed_maps.mean(dim=1, keepdim=True)  # average over the channel dimension

# Add the style adapter together with the HED structure adapter
add_adapters(sd15, raw_cfg, style_cfg=style_cfg, hed_cfg=hed_cfg)

# Predict phi with the style branch (index 0)
phi = sd15.predict_phi(batch.to("cuda"), branch_idx=0)

# Sample images: phi (already encoded and mapped) for style, raw batch for structure
style_hed = sd15.sample_custom(
    prompt="",
    num_images_per_prompt=batch.shape[0],
    cs=[phi, batch],  # style and structure conditioning
    generator=None,
    skip_encode=[0],
    skip_mapping=[0],
)

#### Visualize all images

In [None]:
import matplotlib.pyplot as plt

n_images = batch.shape[0]


def _denormalize(img):
    """Convert a (C, H, W) tensor normalized to [-1, 1] into an (H, W, C) array in [0, 1]."""
    return (img.detach().permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5).clip(0, 1)


def _map_to_array(single_channel_map):
    """Convert a (1, H, W) map tensor into an (H, W) array for a gray colormap.

    The map is not rescaled. imshow autoscales single-channel data to its own min/max,
    so a `* 0.5 + 0.5` remap changes nothing visible, and clipping could only saturate
    values outside [-1, 1].
    """
    return single_channel_map.detach().squeeze(0).cpu().numpy()


# One entry per figure row: (row title, images for each column, extra imshow kwargs).
rows = [
    ("Original Images", [_denormalize(batch[i]) for i in range(n_images)], {}),
    ("Reconstruction based on Style", [style[i] for i in range(n_images)], {}),
    ("Depth Maps", [_map_to_array(depth_maps[i]) for i in range(n_images)], {"cmap": "gray"}),
    ("Reconstruction based on Style + Depth", [style_depth[i] for i in range(n_images)], {}),
    ("HED Maps", [_map_to_array(hed_maps[i]) for i in range(n_images)], {"cmap": "gray"}),
    ("Reconstruction based on Style + HED", [style_hed[i] for i in range(n_images)], {}),
]

fig, ax = plt.subplots(len(rows), n_images, figsize=(2 * n_images, 2.2 * len(rows)), squeeze=False)
for row, (title, images, imshow_kwargs) in enumerate(rows):
    # One left-aligned title per row, attached to the row's first axes.
    ax[row, 0].set_title(title, loc="left")
    for col, image in enumerate(images):
        ax[row, col].imshow(image, **imshow_kwargs)
        ax[row, col].axis("off")

# tight_layout is deliberately not called; presumably it fights the row titles that
# extend past their first axes.
plt.show()