## Skip Inject

Implementation of skip inject ([paper](https://arxiv.org/abs/2501.14524)) using hooks

In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import json
import os
import torch
import gc
from argparse import Namespace


from main import main
import yaml
import gc

# Target device is fixed to the "mps" backend (no auto-detection is done here).
device = 'mps'

# Checkpoint id plus the loading options forwarded to `main` by later cells.
model = "CompVis/stable-diffusion-v1-4"
variant = 'fp16'
model_name = 'auto'
image_size = 512
float_ = torch.float16

# Every generated image / metadata file is written under this folder.
output_folder = 'outputs'
os.makedirs(output_folder, exist_ok=True)

# Module paths used as skip-injection keys.
_UP0_RES0 = 'unet.up_blocks.0.resnets.0'
_UP0_RES1 = 'unet.up_blocks.0.resnets.1'
_UP0_RES2 = 'unet.up_blocks.0.resnets.2'
_UP1_RES0 = 'unet.up_blocks.1.resnets.0'

# Each inner list is one combination of keys to try: the four single
# layers first, then progressively larger groups.
selected_skip_keys = [
    [_UP0_RES0],
    [_UP0_RES1],
    [_UP0_RES2],
    [_UP1_RES0],
    [_UP0_RES0, _UP1_RES0],
    [_UP0_RES0, _UP0_RES1, _UP0_RES2],
    [_UP0_RES0, _UP0_RES1, _UP0_RES2, _UP1_RES0],
]



  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Experiment sweep: for every test in the YAML file and every target prompt,
# generate images A and B once, then generate one image C per skip-key set
# and hyperparameter pair, saving each image alongside a JSON metadata file.

# Optional per-run overrides; empty here, so the YAML values are always used.
args = {}

yml_file = 'data/pnp/wild-ti2i-fake.yaml'

# The file should hold a list of dicts (each dict is one test).  An empty
# file makes safe_load return None, so fall back to an empty list instead of
# crashing in enumerate() below.
with open(yml_file, "r", encoding="utf-8") as f:
    tests = yaml.safe_load(f) or []

os.makedirs(output_folder, exist_ok=True)

# Per-test settings collected across iterations so config.json keeps a
# record of every test instead of only the last one processed.
per_test_settings = []

for idx, test in enumerate(tests):
    scale = test.get("scale", 7.5)
    seed = test.get("seed", 0)
    ddim_steps = test.get("ddim_steps", 50)
    source_prompt = test.get("source_prompt", "")
    # `or []` also covers a key that is present but null in the YAML.
    target_prompts = test.get("target_prompts") or []

    # Models whose id contains 'turbo' or 'schnell' run with guidance
    # disabled and a fixed step count of 3.
    if 'turbo' in model or 'schnell' in model:
        scale = 0.0
        ddim_steps = 3

    # Models whose id contains 'kandinsky' are capped at 30 steps.
    if 'kandinsky' in model:
        ddim_steps = min(ddim_steps, 30)

    print(f"\n--- Running test #{idx+1} ---")
    print(f"scale={scale}, seed={seed}, ddim_steps={ddim_steps}")
    print(f"source_prompt='{source_prompt}'")
    print(f"target_prompts={target_prompts}")

    for j, target_prompt in enumerate(target_prompts):
        test_tag = f"test{idx+1}_pair{j+1}"

        # ---- Generate A & B once ----
        # Note the mapping: prompt_A is the *target* prompt and prompt_B is
        # the *source* prompt.
        base_args = {
            'out_dir': output_folder,
            'prompt_A': target_prompt,
            'variant': variant,
            'device': device,
            'prompt_B': source_prompt,
            'image_size': image_size,
            'model': model,
            'model_name': model_name,
            # `scale` was already forced to 0.0 above for turbo/schnell.
            'guidance_scale': scale,
            'num_inference_steps': args.get('num_inference_steps', ddim_steps),
            'seed': seed,
            'float': float_,
            'timesteps': [1000, 0],
            'switch_guidance': {},
            'selected_skip_keys': selected_skip_keys[0]
        }

        print(f"Generating A & B for {source_prompt} -> {target_prompt}")
        # NOTE(review): injected_skips is produced with only the first
        # skip-key set but is reused below for every set -- confirm that
        # `main` records the skips needed by the other sets as well.
        image_A, image_B, injected_skips, pipe_B = main(
            Namespace(**base_args), save_results=False, save_b=True
        )
        image_A.save(os.path.join(output_folder, f"A_{test_tag}.png"))
        image_B.save(os.path.join(output_folder, f"B_{test_tag}.png"))

        # ---- Loop over hyperparameters for C ----
        # The two lists are paired element-wise by zip() (not crossed).
        switch_guidance_list = [{}]#, 0.9, 1.2]
        timestep_list = [[1000, 200]]#, [1000, 100], [1000,200]]

        for skips in selected_skip_keys:
            print(skips)
            # Tag each key as "<last component>_<third-from-last component>",
            # e.g. 'unet.up_blocks.0.resnets.1' -> '1_0'.
            key_parts = [s.split('.') for s in skips]
            skip_tag = f"skips_{'_'.join(p[-1] + '_' + p[-3] for p in key_parts)}"

            for sg, ts in zip(switch_guidance_list, timestep_list):
                hyper_args = base_args.copy()
                hyper_args.update({
                    'switch_guidance': sg,
                    'timesteps': ts,
                    'selected_skip_keys': skips
                })

                print(f"Generating C with skip={skip_tag}, SG={sg}, timesteps={ts}")
                # Reuse the skips and pipeline returned by the A/B run.
                image_C = main(
                    Namespace(**hyper_args),
                    injected_skips=injected_skips,
                    pipe_B=pipe_B,
                    save_results=False,
                )

                # Save C with detailed name
                sg_tag = f"SG{sg}_T{ts[0]}-{ts[1]}"
                filename = f"C_{test_tag}_{skip_tag}_{sg_tag}.png"
                image_C.save(os.path.join(output_folder, filename))

                # Save metadata
                metadata = {
                    "test_pair": test_tag,
                    "source_prompt": source_prompt,
                    "target_prompt": target_prompt,
                    "scale": scale,
                    "seed": seed,
                    "ddim_steps": ddim_steps,
                    "skip_injection": skips,
                    "switch_guidance": sg,
                    "timesteps": ts
                }
                meta_name = f"metadata_{test_tag}_{skip_tag}_{sg_tag}.json"
                with open(os.path.join(output_folder, meta_name), "w", encoding="utf-8") as f:
                    json.dump(metadata, f, indent=4)

        # Drop the pipeline reference before the next pair loads a new one.
        del pipe_B
        gc.collect()

        #time.sleep(5)

    # Save general config for this test.  The file path and the top-level
    # keys are unchanged (they describe the most recent test); the added
    # "tests" list preserves the settings of every test run so far, which a
    # plain overwrite would lose.
    per_test_settings.append({
        "test": idx + 1,
        "scale": scale,
        "seed": seed,
        "ddim_steps": ddim_steps,
    })
    config = {
        "scale": scale,
        "seed": seed,
        "ddim_steps": ddim_steps,
        "model": model,
        "selected_skip_keys": selected_skip_keys,
        "tests": per_test_settings
    }
    config_path = os.path.join(output_folder, "config.json")
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=4)
    print(f"Saved experiment config to {config_path}")


In [None]:
from utils.utils_test import generate_triplet

# One cat -> dog prompt pair run through the project's triplet helper, using
# the model settings and the full list of skip-key sets defined earlier in
# this notebook.
triplet_kwargs = {
    "source_prompt": "a cat on a chair",
    "target_prompt": "a dog on a chair",
    "output_folder": "outputs/triplet_test",
    "model": model,
    "model_name": model_name,
    "variant": variant,
    "device": device,
    "image_size": image_size,
    "selected_skip_keys": selected_skip_keys,
    # Generation entry point imported from `main`.
    "main_fn": main,
}

generate_triplet(**triplet_kwargs)
