# Multi-scale overlay guidance (image + optional text prompts)

This notebook adds a multi-scale CLIP loss: the full-resolution (large-scale) term preserves the global silhouette, while the downscaled (small-scale) terms enrich texture and detail.

In [None]:
import numpy as np
import rp
import torch
import torch.nn as nn
import source.stable_diffusion as sd
from easydict import EasyDict
from source.learnable_textures import LearnableImageFourier
from source.stable_diffusion_labels import NegativeLabel
from itertools import chain
from source.clip import get_clip_image_similarity, get_clip_logits


In [None]:
# Image prompts (required): one target image per rotated view (a, b, c, d are
# paired with views w, x, y, z in the training loop). Each file is loaded and
# passed through rp.as_float_image and then rp.as_rgb_image.
image_prompt_a, image_prompt_b, image_prompt_c, image_prompt_d = (
    rp.as_rgb_image(rp.as_float_image(rp.load_image(prompt_path)))
    for prompt_path in (
        'source/GI.jpg',
        'source/HI.jpg',
        'source/HSR.jpg',
        'source/zzz.jpg',
    )
)


In [None]:
# Text prompts (optional). Leave '' to disable.
text_prompt_w = ''
text_prompt_x = ''
text_prompt_y = ''
text_prompt_z = ''

text_prompt_weight = 0.5  # 0 disables text guidance


In [None]:
# Multi-scale loss configuration.
# Each entry of `multi_scales` is a resize factor applied to the composite and
# prompt images before the CLIP image-similarity term (1.0 = no resize).
# `multi_scale_weights[i]` multiplies the term computed at `multi_scales[i]`.
# Larger scales get larger weights to keep global structure; smaller scales
# contribute detail with less weight.
multi_scales = [1.0, 0.5, 0.25]
multi_scale_weights = [1.0, 0.5, 0.25]

# Validate with explicit raises rather than `assert`, which `python -O` strips.
if len(multi_scales) != len(multi_scale_weights):
    raise ValueError(
        f'multi_scales ({len(multi_scales)}) and multi_scale_weights '
        f'({len(multi_scale_weights)}) must have the same length'
    )
# A non-positive factor would yield an empty (or negative) resize target.
if any(scale <= 0 for scale in multi_scales):
    raise ValueError(f'multi_scales must all be positive, got {multi_scales}')


In [None]:
# Shared Stable Diffusion instance; only its `device` is used in this cell, and
# every learnable tensor / prompt tensor below is moved onto it.
s = sd._get_stable_diffusion_singleton()
DEVICE = s.device

# The two learnable layers that get physically stacked. Both are built with
# LearnableImageFourier(512, 512, 3) (presumably height, width, channels --
# confirm against source.learnable_textures). The bottom layer is constructed
# first, then the top layer, matching the original initialization order.
bottom_image, top_image = (
    LearnableImageFourier(512, 512, 3).to(DEVICE) for _ in range(2)
)

def simulate_overlay(bottom, top, clean=True):
    """Combine two layers multiplicatively into a single overlay image.

    The result is ``tanh(clamp(bottom**exp * top**exp * brightness, 0, 99))``.

    With ``clean=True`` the parameters are fixed (``exp=1``, ``brightness=3``).
    With ``clean=False`` they are randomized via ``rp.random_float``:
    ``exp`` in [.5, 1], ``brightness`` in [1, 5] and a black level in [0, .5].
    Each layer is then ``rp.blend``-ed toward that black level by a random
    amount before combining.

    Args:
        bottom: tensor for the lower layer.
        top: tensor for the upper layer, same shape as ``bottom`` (or
            broadcastable to it).
        clean: if False, apply the randomized augmentation described above.

    Returns:
        Tensor of the combined overlay; values lie in [0, tanh(99)].
    """
    if not clean:
        # Draw order matters for reproducibility: exp, brightness, black,
        # then one blend factor for the bottom layer and one for the top.
        exp = rp.random_float(.5, 1)
        brightness = rp.random_float(1, 5)
        black = rp.random_float(0, .5)
        bottom = rp.blend(bottom, black, rp.random_float())
        top = rp.blend(top, black, rp.random_float())
    else:
        exp, brightness = 1, 3

    product = bottom ** exp * top ** exp
    return (product * brightness).clamp(0, 99).tanh()

def bottom_image_torch():
    """Render and return the current bottom layer by calling ``bottom_image``."""
    rendered = bottom_image()
    return rendered

def top_image_torch():
    """Render and return the current top layer by calling ``top_image``."""
    rendered = top_image()
    return rendered

def _rotated_overlay(quarter_turns):
    """Build a zero-argument renderer for one view of the stacked layers.

    The renderer overlays the bottom layer with the top layer rotated by
    ``quarter_turns`` quarter turns via ``rot90`` over dims 1 and 2, using
    ``simulate_overlay``'s default (clean) settings. ``quarter_turns`` is bound
    as a parameter, so each renderer keeps its own rotation.
    """
    return lambda: simulate_overlay(
        bottom_image_torch(),
        top_image_torch().rot90(k=quarter_turns, dims=[1, 2]),
    )


# Views w, x, y, z = top layer rotated by 0, 1, 2 and 3 quarter turns.
learnable_image_w, learnable_image_x, learnable_image_y, learnable_image_z = (
    _rotated_overlay(k) for k in range(4)
)

# Convert each image prompt to a torch image on DEVICE. The source tuple is
# evaluated before any name is rebound, so every conversion reads the numpy
# version of its prompt.
image_prompt_a, image_prompt_b, image_prompt_c, image_prompt_d = (
    rp.as_torch_image(numpy_prompt).to(DEVICE)
    for numpy_prompt in (image_prompt_a, image_prompt_b, image_prompt_c, image_prompt_d)
)

# One renderer per rotated view. The order must match the image/text prompt
# lists zipped together in the training loop.
learnable_images = [learnable_image_w, learnable_image_x, learnable_image_y, learnable_image_z]

# Per-view loss weights, rescaled so they average to 1 (uniform -> all 1.0).
weights = np.array([1, 1, 1, 1], dtype=np.float32)
weights = weights / weights.sum()
weights = weights * len(weights)

# Materialize the parameters as a list. A bare itertools.chain is a one-shot
# iterator: once SGD consumed it, `params` would be empty for any later use
# (e.g. gradient clipping or building a second optimizer).
params = list(chain(bottom_image.parameters(), top_image.parameters()))
optim = torch.optim.SGD(params, lr=1e-3)


In [None]:
def multiscale_clip_loss(composite, prompt_image, text_prompt):
    """Negative multi-scale CLIP similarity between a composite and its prompts.

    For every ``(scale, w)`` pair in the module-level ``multi_scales`` /
    ``multi_scale_weights``, ``w * get_clip_image_similarity(comp_s, prompt_s)``
    is subtracted from the loss. At ``scale == 1.0`` both images are used
    unchanged. Otherwise both are resized to ``scale`` times the composite's
    height and width. If ``text_prompt`` is non-empty and ``text_prompt_weight``
    is nonzero, ``text_prompt_weight * get_clip_logits(composite, text_prompt)[0]``
    (full resolution only) is subtracted as well.

    Args:
        composite: rendered overlay as a torch image. Its last two dims are
            taken as (height, width), which is assumed -- TODO confirm for
            images rendered by LearnableImageFourier.
        prompt_image: target image prompt as a torch image.
        text_prompt: optional text prompt; ``''`` disables the text term.

    Returns:
        The loss to minimise. It is presumably a scalar tensor, depending on
        what the CLIP helpers return. It is the float ``0.0`` only if there
        are no scales and no text term.
    """
    # Derive the resize target from the composite instead of hard-coding 512,
    # so the loss keeps working if the texture resolution changes.
    height, width = composite.shape[-2:]

    loss = 0.0
    for scale, w in zip(multi_scales, multi_scale_weights):
        if scale == 1.0:
            comp_s = composite
            prompt_s = prompt_image
        else:
            # Never shrink below one pixel. Assumes rp.torch_resize_image takes
            # (height, width); for the square 512x512 composites used here the
            # order is irrelevant.
            size = (max(1, int(height * scale)), max(1, int(width * scale)))
            comp_s = rp.torch_resize_image(composite, size)
            prompt_s = rp.torch_resize_image(prompt_image, size)
        loss = loss - w * get_clip_image_similarity(comp_s, prompt_s)

    if text_prompt_weight and text_prompt:
        text_logit = get_clip_logits(composite, text_prompt)[0]
        loss = loss - text_prompt_weight * text_logit
    return loss


In [None]:
NUM_ITER = 10000
DISPLAY_INTERVAL = 200

text_prompts = [text_prompt_w, text_prompt_x, text_prompt_y, text_prompt_z]

# (renderer, loss weight, image prompt, text prompt) for each rotated view.
# These never change during training, so build the list once rather than on
# every iteration.
training_views = list(zip(
    learnable_images,
    weights,
    [image_prompt_a, image_prompt_b, image_prompt_c, image_prompt_d],
    text_prompts,
))


def _progress_grid():
    """Return a numpy image grid: the four views on the first row, then the
    bottom and top layers on the second row."""
    view_row = [
        rp.as_numpy_image(view().detach().cpu())
        for view in (learnable_image_w, learnable_image_x, learnable_image_y, learnable_image_z)
    ]
    layer_row = [
        rp.as_numpy_image(layer().detach().cpu())
        for layer in (bottom_image, top_image)
    ]
    return rp.as_numpy_image(rp.grid_concatenated_images([view_row, layer_row]))


ims = []  # snapshot grids, one per DISPLAY_INTERVAL iterations
display_eta = rp.eta(NUM_ITER, title='Status: ')

try:
    for iter_num in range(NUM_ITER):
        display_eta(iter_num)
        # Sample one view per step.
        for learnable_image, weight, prompt_image, text_prompt in rp.random_batch(
            training_views,
            batch_size=1,
        ):
            composite = learnable_image()
            loss = multiscale_clip_loss(composite, prompt_image, text_prompt)
            # Scale by the per-view weight so `weights` actually takes effect.
            # float() keeps the numpy scalar from interfering with the tensor op.
            # NOTE(review): retain_graph=True looks unnecessary because the
            # graph is rebuilt every step. Confirm LearnableImageFourier keeps
            # no autograd state across calls before dropping it, which would
            # lower peak memory.
            (loss * float(weight)).backward(retain_graph=True)

        with torch.no_grad():
            if iter_num and not iter_num % (DISPLAY_INTERVAL * 50):
                from IPython.display import clear_output
                clear_output()
            if not iter_num % DISPLAY_INTERVAL:
                im = _progress_grid()
                ims.append(im)
                rp.display_image(im)

        optim.step()
        optim.zero_grad()
except KeyboardInterrupt:
    print()
    print(f'Interrupted early at iteration {iter_num}')
    if ims:
        rp.display_image(ims[-1])
