# Multi-scale single-image guidance

Optimize a single learnable image with a multi-scale CLIP loss (no overlay, no rotation).

In [None]:
import numpy as np
import rp
import torch
import source.stable_diffusion as sd
from source.learnable_textures import LearnableImageFourier
from source.clip import get_clip_image_similarity, get_clip_logits


In [None]:
# Text prompt (optional). Leave every prompt as '' to disable text guidance.
prompt_a = ''
prompt_b = ''
prompt_c = ''
text_prompt_weight = 0.5  # 0 disables text guidance

# The optimization loop reads a single `text_prompt`, which was previously
# never defined (NameError on the first iteration). Only one image is
# optimized here, so use the first non-empty prompt; '' means "no text
# guidance".
text_prompt = prompt_a or prompt_b or prompt_c


In [None]:
# Multi-scale loss configuration.
# Each (scale, weight) pair adds one term to the loss: `scale` is the resize
# factor applied to the image and `weight` multiplies that scale's term.
multi_scales, multi_scale_weights = (
    [1.0, 0.5, 0.25],  # resize factors; 1.0 uses the image as-is
    [1.0, 0.5, 0.25],  # per-scale loss weights, matched by position
)
# The two lists are walked in lockstep, so their lengths have to agree.
assert len(multi_scale_weights) == len(multi_scales)


In [None]:
# Run on whatever device the Stable Diffusion singleton reports.
s = sd._get_stable_diffusion_singleton()
DEVICE = s.device

# The image being optimized, constructed with (512, 512, 3) and moved to DEVICE.
learnable_image = LearnableImageFourier(512, 512, 3)
learnable_image = learnable_image.to(DEVICE)

# Plain SGD over the learnable image's parameters.
optim = torch.optim.SGD(params=learnable_image.parameters(), lr=1e-3)


In [None]:
def multiscale_clip_loss(composite, text_prompt):
    """Return the weighted multi-scale CLIP text loss for ``composite``.

    For every ``(scale, weight)`` pair taken from the module-level
    ``multi_scales`` / ``multi_scale_weights`` lists, the image is resized to
    ``int(512 * scale)`` pixels per side (a scale of exactly 1.0 uses it
    unchanged) and ``weight * text_prompt_weight * logit`` is subtracted from
    the loss, where ``logit`` is element 0 of
    ``get_clip_logits(image, text_prompt)``. Minimizing the result therefore
    pushes that logit up.

    Args:
        composite: Image tensor to score.
            NOTE(review): the hard-coded 512 assumes a 512x512 image, as
            constructed elsewhere in this notebook -- confirm before reusing
            with another resolution.
        text_prompt: Prompt forwarded to ``get_clip_logits``. A falsy value
            (e.g. ``''``) disables text guidance.

    Returns:
        A tensor (presumably scalar; depends on what ``get_clip_logits``
        returns). When guidance is disabled -- falsy prompt, zero
        ``text_prompt_weight``, or empty scale lists -- a scalar zero that is
        still attached to ``composite``'s graph is returned, so the caller's
        ``loss.backward()`` works and yields zero gradients instead of
        failing on a plain Python float.
    """
    loss = 0.0
    # Test the prompt first: the disabled path then skips both the resize and
    # the CLIP forward pass entirely.
    if text_prompt and text_prompt_weight:
        for scale, w in zip(multi_scales, multi_scale_weights):
            if scale == 1.0:
                comp_s = composite
            else:
                size = int(512 * scale)
                comp_s = rp.torch_resize_image(composite, (size, size))
            text_logit = get_clip_logits(comp_s, text_prompt)[0]
            loss = loss - w * text_prompt_weight * text_logit

    if not torch.is_tensor(loss):
        # Nothing contributed to the loss; hand back a differentiable zero
        # rather than the float accumulator.
        loss = composite.sum() * 0.0
    return loss


In [None]:
# Optimization loop: render the learnable image, score it with the
# multi-scale CLIP loss, take one optimizer step, and periodically show a
# snapshot. Interrupting the kernel stops early and re-displays the last
# snapshot.
NUM_ITER = 8000
DISPLAY_INTERVAL = 200

ims = []  # snapshots collected every DISPLAY_INTERVAL iterations
display_eta = rp.eta(NUM_ITER, title='Status: ')

# Bound up front so the KeyboardInterrupt handler can always report it, even
# if the interrupt arrives before the loop has assigned it.
iter_num = 0

try:
    for iter_num in range(NUM_ITER):
        display_eta(iter_num)
        composite = learnable_image()
        loss = multiscale_clip_loss(composite, text_prompt)

        # The loss can be a plain Python float when no term contributed to it
        # (e.g. an empty prompt); a float has no .backward(), so only
        # backpropagate real tensors.
        if torch.is_tensor(loss):
            # NOTE(review): retain_graph=True is kept from the original. It is
            # only needed if part of the graph is reused across iterations
            # (e.g. cached inside get_clip_logits) -- confirm before removing.
            loss.backward(retain_graph=True)

        with torch.no_grad():
            # Clear accumulated notebook output every 50 snapshots. With the
            # constants above that is every 10000 iterations, which is beyond
            # NUM_ITER, so it only fires if NUM_ITER is raised.
            if iter_num and not iter_num % (DISPLAY_INTERVAL * 50):
                from IPython.display import clear_output
                clear_output()
            if not iter_num % DISPLAY_INTERVAL:
                img = composite.detach().cpu()
                im = rp.as_numpy_image(img)
                ims.append(im)
                rp.display_image(im)

        optim.step()
        optim.zero_grad()
except KeyboardInterrupt:
    print()
    print(f'Interrupted early at iteration {iter_num}')
    if ims:
        rp.display_image(ims[-1])
