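## Feature Inject

Implementation of Feature Inject ([paper](https://arxiv.org/abs/2501.14524)) using hooks. For each prompt pair, the cells below save two reference generations (A and B) and a set of images C produced by injecting skip features captured from selected UNet blocks.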

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import json
import os
import torch
import gc
from argparse import Namespace


from main import main
import yaml
import gc

# Determine device.
device = 'cuda' if torch.cuda.is_available() else 'mps'

model = "stabilityai/sd-turbo" #"CompVis/stable-diffusion-v1-4" 
variant = 'fp16'
model_name ='auto'
image_size = 512 
float_ = torch.float16



The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
from utils.utils_test import generate_triplet

output_folder = 'outputs'
os.makedirs(output_folder, exist_ok=True)

# Each entry is one set of UNet block names whose skip features are injected.
selected_skip_keys = [
    ['unet.up_blocks.0.resnets.0'],
    ['unet.up_blocks.0.resnets.1'],
    ['unet.up_blocks.0.resnets.2'],
    ['unet.up_blocks.1.resnets.0'],
    ['unet.up_blocks.0.resnets.0',
     'unet.up_blocks.1.resnets.0'],
    ['unet.up_blocks.0.resnets.0',
     'unet.up_blocks.0.resnets.1',
     'unet.up_blocks.0.resnets.2'],
    ['unet.up_blocks.0.resnets.0',
     'unet.up_blocks.0.resnets.1',
     'unet.up_blocks.0.resnets.2',
     'unet.up_blocks.1.resnets.0'],
]

# Quick check on a single prompt pair. Results go to a subfolder of
# `output_folder`, so changing that one variable relocates everything.
generate_triplet(
    source_prompt="a dog on a chair",
    target_prompt="a cat on a chair",
    output_folder=os.path.join(output_folder, "triplet_test"),
    model=model,
    model_name=model_name,
    variant=variant,
    device=device,
    image_size=image_size,
    selected_skip_keys=selected_skip_keys,
    float_=float_,
    main_fn=main,  # generation entry point imported from main.py
)


In [None]:
# Evaluate skip-feature injection on every prompt pair listed in a YAML test file.
args = {}  # optional overrides, e.g. {'num_inference_steps': 4}

output_folder = 'outputs_eval'
os.makedirs(output_folder, exist_ok=True)

# Each entry is one set of UNet block names to inject when generating image C.
selected_skip_keys = [
    ['unet.up_blocks.0.resnets.0'],
    ['unet.up_blocks.0.resnets.1'],
    ['unet.up_blocks.0.resnets.2'],
    ['unet.up_blocks.1.resnets.0'],
    ['unet.up_blocks.0.resnets.0',
     'unet.up_blocks.1.resnets.0'],
    ['unet.up_blocks.0.resnets.0',
     'unet.up_blocks.0.resnets.1',
     'unet.up_blocks.0.resnets.2'],
    ['unet.up_blocks.0.resnets.0',
     'unet.up_blocks.0.resnets.1',
     'unet.up_blocks.0.resnets.2',
     'unet.up_blocks.1.resnets.0'],
]

# Hyperparameter sweep for image C. The two lists are zipped pairwise, so
# extend them together (e.g. more switch-guidance / timestep settings).
switch_guidance_list = [{}]
timestep_list = [[1000, 200]]

yml_file = 'data/pnp/wild-ti2i-fake.yaml'
with open(yml_file, "r") as f:
    tests = yaml.safe_load(f)  # expected: a list of dicts, one dict per test

for idx, test in enumerate(tests):
    scale = test.get("scale", 7.5)
    seed = test.get("seed", 0)
    ddim_steps = test.get("ddim_steps", 50)
    source_prompt = test.get("source_prompt", "")
    target_prompts = test.get("target_prompts", [])

    # turbo / schnell models: no guidance and only 3 steps.
    if 'turbo' in model or 'schnell' in model:
        scale = 0.0
        ddim_steps = 3

    if 'kandinsky' in model:
        ddim_steps = min(ddim_steps, 30)

    print(f"\n--- Running test #{idx+1} ---")
    print(f"scale={scale}, seed={seed}, ddim_steps={ddim_steps}")
    print(f"source_prompt='{source_prompt}'")
    print(f"target_prompts={target_prompts}")

    for j, target_prompt in enumerate(target_prompts):
        test_tag = f"test{idx+1}_pair{j+1}"

        # ---- Generate A & B once per pair ----
        base_args = {
            'out_dir': output_folder,
            'prompt_A': target_prompt,
            'variant': variant,
            'device': device,
            'prompt_B': source_prompt,
            'image_size': image_size,
            'model': model,
            'model_name': model_name,
            # `scale` is already forced to 0.0 above for turbo / schnell models.
            'guidance_scale': scale,
            'num_inference_steps': args.get('num_inference_steps', ddim_steps),
            'seed': seed,
            'float': float_,
            'timesteps': [1000, 0],
            'switch_guidance': {},
            'selected_skip_keys': selected_skip_keys[0]
        }

        print(f"Generating A & B for {source_prompt} -> {target_prompt}")
        image_A, image_B, injected_skips, pipe_B = main(Namespace(**base_args), save_results=False, save_b=True)
        image_A.save(os.path.join(output_folder, f"A_{test_tag}.png"))
        image_B.save(os.path.join(output_folder, f"B_{test_tag}.png"))

        # ---- Loop over hyperparameters for C, reusing the captured skips ----
        for skips in selected_skip_keys:
            print(skips)
            skip_tag = f"skips_{'_'.join([s.split('.')[-1] +'_' + s.split('.')[-3] for s in skips])}"

            for sg, ts in zip(switch_guidance_list, timestep_list):
                hyper_args = {
                    **base_args,
                    'switch_guidance': sg,
                    'timesteps': ts,
                    'selected_skip_keys': skips,
                }

                print(f"Generating C with skip={skip_tag}, SG={sg}, timesteps={ts}")
                image_C = main(Namespace(**hyper_args), injected_skips=injected_skips, pipe_B=pipe_B, save_results=False)

                # Save C with a name encoding the pair, the skips and the hyperparameters.
                sg_tag = f"SG{sg}_T{ts[0]}-{ts[1]}"
                filename = f"C_{test_tag}_{skip_tag}_{sg_tag}.png"
                image_C.save(os.path.join(output_folder, filename))

                # Save metadata next to the image.
                metadata = {
                    "test_pair": test_tag,
                    "source_prompt": source_prompt,
                    "target_prompt": target_prompt,
                    "scale": scale,
                    "seed": seed,
                    "ddim_steps": ddim_steps,
                    "skip_injection": skips,
                    "switch_guidance": sg,
                    "timesteps": ts
                }
                meta_name = f"metadata_{test_tag}_{skip_tag}_{sg_tag}.json"
                with open(os.path.join(output_folder, meta_name), "w") as f:
                    json.dump(metadata, f, indent=4)

        # Release the pipeline and the captured features before the next pair.
        del pipe_B, injected_skips
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Save the config of this test in its own file. A single shared
    # config.json would be overwritten by each subsequent test.
    config = {
        "scale": scale,
        "seed": seed,
        "ddim_steps": ddim_steps,
        "model": model,
        "selected_skip_keys": selected_skip_keys
    }
    config_path = os.path.join(output_folder, f"config_test{idx+1}.json")
    with open(config_path, "w") as f:
        json.dump(config, f, indent=4)
    print(f"Saved experiment config to {config_path}")



--- Running test #1 ---
scale=0.0, seed=50, ddim_steps=3
source_prompt='a photo of a horse in mud'
target_prompts=['a photo of a zebra in the snow', 'a photo of a husky on the grass', 'an oil painting of a white horse', 'a photo of a blue horse toy in playroom']
Generating A & B for a photo of a horse in mud -> a photo of a zebra in the snow

Initializing Pipeline A (skip capture mode)...


Loading pipeline components...: 100%|██████████| 5/5 [00:00<00:00,  8.93it/s]
You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .


In [10]:
import random
import sys

# Make the prompt lists in ./test importable. The guard keeps re-runs of this
# cell from prepending the same entry to sys.path again and again.
if './test' not in sys.path:
    sys.path.insert(0, './test')
from prompts import scribblr_prompts, stockimg_prompts, objects, backgrounds, styles

test_type = 'controlled'  # 'wild' (free-form prompt lists) or 'controlled' (templated prompts)

output_folder = 'outputs_test_2turbo_new_' + test_type
os.makedirs(output_folder, exist_ok=True)

scale = 0.0
seed = 42
ddim_steps = 4

# Dedicated seeded RNG so the sampled prompt pairs and layer subsets are
# reproducible across kernel restarts.
rng = random.Random(seed)

# An empty file written into a pair folder once all its images are saved;
# folders without it come from interrupted runs and are regenerated.
DONE_MARKER = '.done'

if test_type == 'wild':
    prompts = scribblr_prompts + stockimg_prompts

    # sample 10 distinct ordered pairs of different prompts
    sampled_prompts = rng.sample([(a, b) for a in prompts for b in prompts if a != b], 10)
    print(sampled_prompts)
else:
    sampled_prompts = []
    seen_pairs = set()
    # draw 500 random (object, background, style) pairs; identical prompts
    # and repeated pairs are dropped, so the final list may be shorter
    for _ in range(500):
        obj = rng.choice(objects)
        back = rng.choice(backgrounds)
        style = rng.choice(styles)
        prompt1 = f"A high-resolution image of a {obj} in the {back}, {style}"

        obj_2 = rng.choice(objects)
        back_2 = rng.choice(backgrounds)
        style_2 = rng.choice(styles)
        prompt2 = f"A high-resolution image of a {obj_2} in the {back_2}, {style_2}"

        pair = (prompt1, prompt2)
        if prompt1 != prompt2 and pair not in seen_pairs:
            seen_pairs.add(pair)
            sampled_prompts.append(pair)


def _generate_c_image(base_args, skips, injected_skips, pipe_B, pair_dir):
    """Generate image C injecting the captured skips of `skips` and save it in `pair_dir`."""
    skip_tag = f"skips_{'_'.join(skips)}"
    hyper_args = {**base_args, 'selected_skip_keys': skips}
    print(f"Generating C with skip={skip_tag}")
    image_C = main(Namespace(**hyper_args), injected_skips=injected_skips, pipe_B=pipe_B, save_results=False)
    image_C.save(os.path.join(pair_dir, f"C_{skip_tag}.png"))


# Generate A, B and all C variants for every sampled pair.
for prompt_A, prompt_B in sampled_prompts:
    pair_tag = f"{'_'.join(prompt_A.split())}_{'_'.join(prompt_B.split())}"
    pair_dir = os.path.join(output_folder, pair_tag)
    # Resume support: skip only pairs that finished completely.
    if os.path.exists(os.path.join(pair_dir, DONE_MARKER)):
        continue
    os.makedirs(pair_dir, exist_ok=True)

    print(f"Generating A & B for {prompt_A} -> {prompt_B}")
    base_args = {
        'out_dir': output_folder,
        'prompt_A': prompt_A,
        'variant': variant,
        'device': device,
        'prompt_B': prompt_B,
        'image_size': image_size,
        'model': model,
        'model_name': model_name,
        'guidance_scale': 0.0 if ('turbo' in model) or ('schnell' in model) else scale,
        'num_inference_steps': ddim_steps,
        'seed': seed,
        'float': float_,
        'timesteps': [1000, 0],
        'switch_guidance': {},
        'selected_skip_keys': ''
    }
    image_A, image_B, injected_skips, pipe_B = main(Namespace(**base_args), save_results=False, save_b=True)
    image_A.save(os.path.join(pair_dir, "A.png"))
    image_B.save(os.path.join(pair_dir, "B.png"))

    layer_names = list(injected_skips.keys())

    # One C per individually injected layer.
    for layer in layer_names:
        _generate_c_image(base_args, [layer], injected_skips, pipe_B, pair_dir)

    # 15 random combinations of 2 or 3 layers (needs at least 2 layers).
    if len(layer_names) >= 2:
        for _ in range(15):
            k = min(rng.choice([2, 3]), len(layer_names))
            _generate_c_image(base_args, rng.sample(layer_names, k), injected_skips, pipe_B, pair_dir)

    # Release the pipeline and the captured features before the next pair.
    del pipe_B, injected_skips
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Mark the pair complete only after every image has been written.
    open(os.path.join(pair_dir, DONE_MARKER), 'w').close()
                
                

Generating A & B for A high-resolution image of a steampunk airship in the old european village, cyberpunk aesthetic -> A high-resolution image of a vintage camera in the coral reef underwater world, japanese anime style

Initializing Pipeline A (skip capture mode)...


Loading pipeline components...: 100%|██████████| 5/5 [00:00<00:00, 11.75it/s]
You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .


Saved 52 blocks for injection.


100%|██████████| 4/4 [00:02<00:00,  1.38it/s]



Running Pipeline C (non-injected mode)...
Saved 52 blocks for injection.


100%|██████████| 4/4 [00:10<00:00,  2.67s/it]
Loading pipeline components...: 100%|██████████| 5/5 [00:00<00:00,  8.85it/s]
You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .


Generating C with skip=skips_unet.down_blocks.0.resnets.0

Initializing Pipeline A (skip capture mode)...
Saved 52 blocks for injection.


100%|██████████| 4/4 [00:05<00:00,  1.49s/it]


Generating C with skip=skips_unet.down_blocks.0.attentions.0

Initializing Pipeline A (skip capture mode)...
Saved 52 blocks for injection.


100%|██████████| 4/4 [00:07<00:00,  1.92s/it]


Generating C with skip=skips_unet.down_blocks.0.resnets.1

Initializing Pipeline A (skip capture mode)...
Saved 52 blocks for injection.


100%|██████████| 4/4 [00:07<00:00,  1.79s/it]


Generating C with skip=skips_unet.down_blocks.0.attentions.1

Initializing Pipeline A (skip capture mode)...
Saved 52 blocks for injection.


100%|██████████| 4/4 [00:08<00:00,  2.05s/it]


Generating C with skip=skips_unet.down_blocks.1.resnets.0

Initializing Pipeline A (skip capture mode)...
Saved 52 blocks for injection.


100%|██████████| 4/4 [00:07<00:00,  1.90s/it]


Generating C with skip=skips_unet.down_blocks.1.attentions.0

Initializing Pipeline A (skip capture mode)...
Saved 52 blocks for injection.


 75%|███████▌  | 3/4 [00:07<00:02,  2.44s/it]


KeyboardInterrupt: 