In [1]:
import os
import gc
import random
import argparse
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image
from tqdm import tqdm
import numpy as np
from datasets import load_dataset
from diffusers import LCMScheduler, DDPMScheduler, StableDiffusionPipeline
# import ImageReward as RM
# from torchmetrics.functional.multimodal import clip_score
# from torchmetrics.image.fid import FrechetInceptionDistance
from torchvision.transforms.functional import to_tensor, resize
from diffusers.models.attention_processor import AttnProcessor2_0

from utils.loading import load_models
from utils import p2p, generation, inversion

torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True

NUM_REVERSE_CONS_STEPS = 4
REVERSE_TIMESTEPS = [259, 519, 779, 999]
NUM_FORWARD_CONS_STEPS = 4
FORWARD_TIMESTEPS = [19, 259, 519, 779]
NUM_DDIM_STEPS = 50
START_TIMESTEP = 19

def generate_images_batch(solver, reverse_cons_model, prompts, latent):
    images = []
    generator = torch.Generator(device="cuda:0").manual_seed(42)
    controller = p2p.AttentionStore()
    images, gen_latent, latents = generation.runner(
        guidance_scale=0.0,
        tau1=1.0,
        tau2=1.0,
        is_cons_forward=True,
        model=reverse_cons_model,
        dynamic_guidance=False,
        w_embed_dim=512,
        start_time=50,
        solver=solver,
        prompt=prompts,
        controller=controller,
        generator=generator,
        latent=latent,
        return_type="image",
        # num_inference_steps=50,
    )
    return images, gen_latent, latents


def invert_images_batch(solver, prompts, images, guidance_scale, use_reverse_model=False):
    (image_gt, image_rec), latents, uncond_embeddings, latent_orig = inversion.invert(
        is_cons_inversion=True,
        # do_npi=False,
        # do_nti=True,
        w_embed_dim=512,
        stop_step=50,  # from [0, NUM_DDIM_STEPS]
        inv_guidance_scale=guidance_scale,
        dynamic_guidance=False,
        tau1=0.0,
        tau2=0.0,
        solver=solver,
        images=images,
        prompt=prompts,
        # num_inner_steps=10,
        # early_stop_epsilon=1e-5,
        seed=42,
        use_reverse_model=use_reverse_model
    )

    return image_gt, image_rec, latents[-1], latents, latent_orig


2025-04-17 12:41:36.677714: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744882896.707354  267900 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744882896.716426  267900 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
ldm_stable, reverse_cons_model, forward_cons_model = load_models(
    model_id="sd-legacy/stable-diffusion-v1-5",
    device="cuda:0",
    forward_checkpoint="checkpoints/iCD-SD1.5_19_259_519_779.safetensors",
    reverse_checkpoint="checkpoints/iCD-SD1.5_259_519_779_999.safetensors",
    r=64,
    w_embed_dim=512,
    teacher_checkpoint="checkpoints/sd15_cfg_distill.pt",
    dtype="fp16",
)
ldm_stable.unet.set_attn_processor(AttnProcessor2_0())
reverse_cons_model.unet.set_attn_processor(AttnProcessor2_0())
forward_cons_model.unet.set_attn_processor(AttnProcessor2_0())

ldm_stable.set_progress_bar_config(disable=True)
reverse_cons_model.set_progress_bar_config(disable=True)
forward_cons_model.set_progress_bar_config(disable=True)

ldm_stable.safety_checker = None
reverse_cons_model.safety_checker = None
forward_cons_model.safety_checker = None

noise_scheduler = DDPMScheduler.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="scheduler",
)

solver = generation.Generator(
    model=ldm_stable,
    noise_scheduler=noise_scheduler,
    n_steps=NUM_DDIM_STEPS,
    forward_cons_model=forward_cons_model,
    forward_timesteps=FORWARD_TIMESTEPS,
    reverse_cons_model=reverse_cons_model,
    reverse_timesteps=REVERSE_TIMESTEPS,
    num_endpoints=NUM_REVERSE_CONS_STEPS,
    num_forward_endpoints=NUM_FORWARD_CONS_STEPS,
    max_forward_timestep_index=49,
    start_timestep=START_TIMESTEP,
)

# Configure P2P components
p2p.NUM_DDIM_STEPS = NUM_DDIM_STEPS
p2p.tokenizer = ldm_stable.tokenizer
p2p.device = "cuda:0"


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

  "_class_name": "DDIMScheduler",
  "_diffusers_version": "0.31.0",
  "beta_end": 0.012,
  "beta_schedule": "scaled_linear",
  "beta_start": 0.00085,
  "clip_sample": false,
  "clip_sample_range": 1.0,
  "dynamic_thresholding_ratio": 0.995,
  "num_train_timesteps": 1000,
  "prediction_type": "epsilon",
  "rescale_betas_zero_snr": false,
  "sample_max_value": 1.0,
  "set_alpha_to_one": false,
  "steps_offset": 0,
  "thresholding": false,
  "timestep_spacing": "leading",
  "trained_betas": null
}
 is outdated. `steps_offset` should be set to 1 instead of 0. Please make sure to update the config accordingly as leaving `steps_offset` might led to incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json` file
  deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)


Forward CD is initialized with guidance embedding, dim 512


Some weights of UNet2DConditionModel were not initialized from the model checkpoint at sd-legacy/stable-diffusion-v1-5 and are newly initialized: ['time_embedding.cond_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Embedded model is loading from checkpoints/sd15_cfg_distill.pt


  unet.load_state_dict(torch.load(teacher_checkpoint))


Reverse CD is loading from checkpoints/iCD-SD1.5_259_519_779_999.safetensors
Forward CD is loading from checkpoints/iCD-SD1.5_19_259_519_779.safetensors
Endpoints reverse CTM: tensor([999, 779, 519, 259]), tensor([779, 519, 259,   0])
Endpoints forward CTM: tensor([ 19, 259, 519, 779]), tensor([259, 519, 779, 999])


In [3]:
torch.manual_seed(42)

# optimizer = torch.optim.Adam([solver.w_embedding], lr=1e-4)

data_files = {
    "test": "data/test-*-of-*.parquet",
}
dataset = load_dataset(
    "bitmind/MS-COCO",
    data_files=data_files,
    split="test",
    verification_mode="no_checks",
)
dataset_sample = dataset.select(
    random.sample(range(len(dataset)), 1000)
)

mse_latent, mse_real = [], []
diff_latents = []

In [4]:
del dataset
torch.cuda.empty_cache()
gc.collect()

38

In [5]:
solver.w_embedding.shape

torch.Size([1, 512])

In [6]:
for p in solver.model.unet.parameters():
    p.requires_grad = False
for p in reverse_cons_model.unet.parameters():
    p.requires_grad = False
for p in forward_cons_model.unet.parameters():
    p.requires_grad = False

In [7]:
solver.w_embedding.requires_grad = True

In [8]:
import itertools

def pretty_count(n):
    """utility — 1 234 567 → '1.23 M' """
    if n < 1e3:
        return str(n)
    elif n < 1e6:
        return f"{n/1e3:,.2f} K"
    elif n < 1e9:
        return f"{n/1e6:,.2f} M"
    else:
        return f"{n/1e9:,.2f} B"

def trainable_parameter_report(*modules):
    """
    Print a small table with #trainable parameters for each module.
    Pass any mix of nn.Module objects (or objects that expose .parameters()).
    """
    rows = []
    grand_total = 0
    for mod in modules:
        n = sum(p.numel() for p in mod.parameters() if p.requires_grad)
        rows.append((mod.__class__.__name__, pretty_count(n)))
        grand_total += n

    width = max(len(name) for name, _ in rows) + 3
    print("─" * (width + 15))
    for name, cnt in rows:
        print(f"{name:<{width}} {cnt:>10}")
    print("─" * (width + 15))
    print(f"{'TOTAL':<{width}} {pretty_count(grand_total):>10}")
    print("─" * (width + 15))
    return grand_total

# --- call it ------------------------------------------------
trainable_parameter_report(solver.model.unet,               # contains w_embedding
                           reverse_cons_model.unet,   # any LoRA or unfrozen layers
                           forward_cons_model.unet)   # idem


──────────────────────────────────────
UNet2DConditionModel             0
UNet2DConditionModel             0
UNet2DConditionModel             0
──────────────────────────────────────
TOTAL                            0
──────────────────────────────────────


0

In [9]:
optimizer = torch.optim.Adam([solver.w_embedding], lr=3e-4)

In [10]:
from utils.p2p import register_attention_control

In [11]:
register_attention_control(solver.model, None)

In [12]:
torch.manual_seed(42)

clip_scores, ir_scores = [], []
# diff_latents = []
mse_latent_log = []
mse_real_log = []
BATCH_SIZE = 2
step_counter = 0

for start_idx in tqdm(
    range(0, len(dataset_sample), BATCH_SIZE), desc="Processing batches"
):
    batch = dataset_sample[start_idx : start_idx + BATCH_SIZE]
    batch_images = [
        img.convert("RGB").resize((512, 512), Image.Resampling.LANCZOS)
        for img in batch["image"]
    ]
    batch_prompts = [s["raw"] for s in batch["sentences"]]
    solver.init_prompt(batch_prompts)

    # 1) Inversion to get latents from the forward (teacher) direction
    # with torch.no_grad():
    image_rec1, latents1, latent1 = solver.cons_inversion(
        batch_images,
        w_embed_dim=0,
        guidance_scale=0.0,
        seed=0,
        use_reverse_model=False,
    )
    with torch.amp.autocast("cuda"):
        image_rec2, latents2, latent2 = solver.cons_inversion(
            batch_images,
            w_embed_dim=512,
            guidance_scale=-1.0,
            seed=0,
            use_reverse_model=True,
            use_w_embed=True,
        )
    
    a1 = latents1[1] - latents2[1]
    a2 = latents1[2] - latents2[2]
    a3 = latents1[3] - latents2[3]
    a4 = latents1[4] - latents2[4]

    # diff_latents.append((a1.detach().cpu(), a2.detach().cpu(), a3.detach().cpu(), a4.detach().cpu()))

    # Use latents from forward pass vs. reverse pass (as an example)
    latent_forward1 = latents1[0].detach()  # no grad
    latent_reverse1 = latents2[0]  # reverse latent
    latent_forward2 = latents1[1].detach()  # no grad
    latent_reverse2 = latents2[1]  # reverse latent
    latent_forward3 = latents1[2].detach()  # no grad
    latent_reverse3 = latents2[2]  # reverse latent
    latent_forward4 = latents1[3].detach()  # no grad
    latent_reverse4 = latents2[3]  # reverse latent

    # 5) Compute MSE loss for training the embedding
    loss1 = F.mse_loss(latent_forward1, latent_reverse1)
    loss2 = F.mse_loss(latent_forward2, latent_reverse2)
    loss3 = F.mse_loss(latent_forward3, latent_reverse3)
    loss4 = F.mse_loss(latent_forward4, latent_reverse4)

    loss = loss1 + loss2 + loss3 + loss4
    # 6) Optimization step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # 7) Update step counter and log metrics to wandb
    step_counter += 1

    # 8) Periodically log sample images to wandb
    if step_counter % 50 == 0:
        batch_metrics = {
            "step": step_counter,
            "batch_start": start_idx,
            "batch_end": start_idx + BATCH_SIZE,
            "guidance_scale": -1.0,
            "loss": loss.item(),
            "loss_latent_r_1": loss1.item(),
            "loss_latent_r_2": loss2.item(),
            "loss_latent_r_3": loss3.item(),
            "loss_latent_r_4": loss4.item(),

            # "pixel_mse": pixel_mse,
            # "latent_mse": latent_mse,
            "diff_latents_a1": a1.mean().item(),
            "diff_latents_a2": a2.mean().item(),
            "diff_latents_a3": a3.mean().item(),
            "diff_latents_a4": a4.mean().item(),
        }
        # wandb.log(batch_metrics, step=step_counter)
        print("\n--- Batch Metrics ---")
        for key, value in batch_metrics.items():
            print(f"{key:20}: {value}")
        print("-" * 40)
#         try:
#             # with torch.no_grad():
#             gen_images_batch, gen_latents_batch, _ = generate_images_batch(
#                     solver, reverse_cons_model, batch_prompts, latents2[-1],
#             )
#             rec_pil = T.ToPILImage()(image_rec2[0].transpose(1, 2, 0))
#             gen_pil = T.ToPILImage()(gen_images_batch[0].transpose(1, 2, 0))

#             rec_pil.save(os.path.join("images", f"step_{step_counter}_sample_rec.jpg"))
#             gen_pil.save(os.path.join("images", f"step_{step_counter}_sample_gen.jpg"))
#             del (
#                 gen_images_batch, gen_latents_batch, _
#             )
#             gc.collect()
#             torch.cuda.empty_cache()
#         except Exception as e:
#             print(e)
#             break

    # Cleanup
    del (
        a1, 
        a2, 
        a3, 
        a4, 
        image_rec1, 
        latents1, 
        latent1, 
        image_rec2, 
        latents2, 
        latent2, 
        latent_forward1, 
        latent_reverse1, 
        latent_forward2, 
        latent_reverse2, 
        latent_forward3, 
        latent_reverse3, 
        latent_forward4, 
        latent_reverse4
    )
    gc.collect()
    torch.cuda.empty_cache()

Processing batches:  10%|▉         | 49/500 [02:57<26:45,  3.56s/it] 


--- Batch Metrics ---
step                : 50
batch_start         : 98
batch_end           : 100
guidance_scale      : -1.0
loss                : 0.16998666524887085
loss_latent_r_1     : 7.450540806530626e-07
loss_latent_r_2     : 0.0281115360558033
loss_latent_r_3     : 0.06867031753063202
loss_latent_r_4     : 0.07320407032966614
diff_latents_a1     : 0.00017747882520779967
diff_latents_a2     : 0.00236098887398839
diff_latents_a3     : 0.00016356274136342108
diff_latents_a4     : 0.0024823672138154507
----------------------------------------


Processing batches:  20%|█▉        | 99/500 [05:28<18:56,  2.83s/it]


--- Batch Metrics ---
step                : 100
batch_start         : 198
batch_end           : 200
guidance_scale      : -1.0
loss                : 0.16600432991981506
loss_latent_r_1     : 5.879029458810692e-07
loss_latent_r_2     : 0.02899692766368389
loss_latent_r_3     : 0.06601914763450623
loss_latent_r_4     : 0.07098766416311264
diff_latents_a1     : -0.0003876169503200799
diff_latents_a2     : 0.0018087518401443958
diff_latents_a3     : -0.0006995809962972999
diff_latents_a4     : 0.0007289305794984102
----------------------------------------


Processing batches:  30%|██▉       | 149/500 [08:06<20:42,  3.54s/it]


--- Batch Metrics ---
step                : 150
batch_start         : 298
batch_end           : 300
guidance_scale      : -1.0
loss                : 0.16854453086853027
loss_latent_r_1     : 9.478814035901451e-07
loss_latent_r_2     : 0.026401493698358536
loss_latent_r_3     : 0.0669550895690918
loss_latent_r_4     : 0.07518700510263443
diff_latents_a1     : 0.00010483618825674057
diff_latents_a2     : 0.001457225065678358
diff_latents_a3     : -0.0009447133634239435
diff_latents_a4     : 0.001242521102540195
----------------------------------------


Processing batches:  40%|███▉      | 199/500 [11:03<18:21,  3.66s/it]


--- Batch Metrics ---
step                : 200
batch_start         : 398
batch_end           : 400
guidance_scale      : -1.0
loss                : 0.1383953094482422
loss_latent_r_1     : 7.775688573019579e-07
loss_latent_r_2     : 0.02527574822306633
loss_latent_r_3     : 0.052658144384622574
loss_latent_r_4     : 0.0604606494307518
diff_latents_a1     : 0.0001659216359257698
diff_latents_a2     : 0.002881997497752309
diff_latents_a3     : 0.002658013254404068
diff_latents_a4     : 0.004113596398383379
----------------------------------------


Processing batches:  50%|████▉     | 249/500 [13:43<13:19,  3.19s/it]


--- Batch Metrics ---
step                : 250
batch_start         : 498
batch_end           : 500
guidance_scale      : -1.0
loss                : 0.131059929728508
loss_latent_r_1     : 1.227371967615909e-06
loss_latent_r_2     : 0.022724082693457603
loss_latent_r_3     : 0.05129890888929367
loss_latent_r_4     : 0.05703571438789368
diff_latents_a1     : 0.0007219183025881648
diff_latents_a2     : 0.0015598334139212966
diff_latents_a3     : 0.0006340585532598197
diff_latents_a4     : 0.004471641965210438
----------------------------------------


Processing batches:  60%|█████▉    | 299/500 [16:25<10:00,  2.99s/it]


--- Batch Metrics ---
step                : 300
batch_start         : 598
batch_end           : 600
guidance_scale      : -1.0
loss                : 0.16027449071407318
loss_latent_r_1     : 5.162271463632351e-07
loss_latent_r_2     : 0.02746780775487423
loss_latent_r_3     : 0.06329695880413055
loss_latent_r_4     : 0.06950920820236206
diff_latents_a1     : -0.000737191759981215
diff_latents_a2     : 0.0014383954694494605
diff_latents_a3     : 0.0009117266163229942
diff_latents_a4     : 0.0026650296058505774
----------------------------------------


Processing batches:  70%|██████▉   | 349/500 [19:12<08:17,  3.29s/it]


--- Batch Metrics ---
step                : 350
batch_start         : 698
batch_end           : 700
guidance_scale      : -1.0
loss                : 0.12836427986621857
loss_latent_r_1     : 6.714975029353809e-07
loss_latent_r_2     : 0.026609841734170914
loss_latent_r_3     : 0.04867025464773178
loss_latent_r_4     : 0.05308350920677185
diff_latents_a1     : 0.0007436739979311824
diff_latents_a2     : 0.002547305077314377
diff_latents_a3     : -0.0006382087012752891
diff_latents_a4     : 0.000897866440936923
----------------------------------------


Processing batches:  80%|███████▉  | 399/500 [22:03<05:24,  3.21s/it]


--- Batch Metrics ---
step                : 400
batch_start         : 798
batch_end           : 800
guidance_scale      : -1.0
loss                : 0.1448948234319687
loss_latent_r_1     : 8.483624469590723e-07
loss_latent_r_2     : 0.024445462971925735
loss_latent_r_3     : 0.055543653666973114
loss_latent_r_4     : 0.06490486115217209
diff_latents_a1     : -6.769521860405803e-06
diff_latents_a2     : 0.0023037190549075603
diff_latents_a3     : 0.0017033027252182364
diff_latents_a4     : 0.0016092198202386498
----------------------------------------


Processing batches:  90%|████████▉ | 449/500 [24:42<02:46,  3.27s/it]


--- Batch Metrics ---
step                : 450
batch_start         : 898
batch_end           : 900
guidance_scale      : -1.0
loss                : 0.15207573771476746
loss_latent_r_1     : 1.2298562523938017e-06
loss_latent_r_2     : 0.025367118418216705
loss_latent_r_3     : 0.061676423996686935
loss_latent_r_4     : 0.06503095477819443
diff_latents_a1     : 0.00040045345667749643
diff_latents_a2     : 0.0028753923252224922
diff_latents_a3     : 0.0006061262101866305
diff_latents_a4     : 0.0010991034796461463
----------------------------------------


Processing batches: 100%|█████████▉| 499/500 [27:15<00:03,  3.11s/it]


--- Batch Metrics ---
step                : 500
batch_start         : 998
batch_end           : 1000
guidance_scale      : -1.0
loss                : 0.13874554634094238
loss_latent_r_1     : 9.518190040580521e-07
loss_latent_r_2     : 0.024412106722593307
loss_latent_r_3     : 0.052048925310373306
loss_latent_r_4     : 0.06228356808423996
diff_latents_a1     : 5.621431046165526e-05
diff_latents_a2     : 0.0018030592473223805
diff_latents_a3     : 0.002017838880419731
diff_latents_a4     : 0.0018467081245034933
----------------------------------------


Processing batches: 100%|██████████| 500/500 [27:18<00:00,  3.28s/it]


## Infer

In [13]:
with torch.amp.autocast("cuda"):
    image_rec2, latents2, latent2 = solver.cons_inversion(
        batch_images,
        w_embed_dim=512,
        guidance_scale=-1.0,
        seed=0,
        use_reverse_model=True,
    )

In [14]:
gen_images_batch, gen_latents_batch, _ = generate_images_batch(
        solver, reverse_cons_model, batch_prompts, latents2[-1],
)

  batch_size, model.unet.in_channels, height // 8, width // 8


In [15]:
rec_pil = T.ToPILImage()(image_rec2[0].transpose(1, 2, 0))
gen_pil = T.ToPILImage()(gen_images_batch[0].transpose(1, 2, 0))

rec_pil.save(os.path.join("images", f"step_{step_counter}_sample_rec.jpg"))
gen_pil.save(os.path.join("images", f"step_{step_counter}_sample_gen.jpg"))

In [16]:
del (
    image_rec2, latents2, latent2, gen_images_batch, gen_latents_batch, _
)
gc.collect()
torch.cuda.empty_cache()

In [17]:
solver.w_embedding

Parameter containing:
tensor([[-1.3167e+00,  1.0265e+00,  5.8795e-01,  4.3646e-01,  1.0810e+00,
         -3.9021e-01,  9.2430e-01,  3.5386e-01,  1.8032e-01,  1.4343e+00,
         -8.3853e-01, -9.2224e-01, -1.3401e+00, -1.3697e+00, -2.0591e+00,
         -2.5465e-01, -1.7250e-01,  6.6381e-02, -4.0814e-01,  2.5360e-01,
          9.9003e-02, -1.9166e-01,  2.3489e+00,  4.2554e-01, -1.7141e+00,
          6.9852e-01, -5.3022e-01,  1.6999e+00, -7.4754e-01, -1.0702e-01,
          2.4658e-01,  9.6704e-01, -1.4666e+00, -8.1104e-01,  5.7378e-01,
          4.7321e-01, -7.2670e-04,  1.4495e+00,  1.0640e+00,  1.1645e+00,
         -3.4560e-01,  2.0137e+00, -5.1223e-01,  2.0068e+00, -7.0276e-01,
         -5.1158e-01, -1.7230e+00,  2.3604e-01,  1.1305e+00, -6.5812e-01,
         -3.6016e-01, -1.2838e+00,  1.2412e+00,  1.2128e+00, -3.2235e-01,
          7.5720e-01, -3.0015e-01, -1.7045e+00, -4.4538e-01, -7.6546e-01,
         -1.9902e-01,  4.5017e-01,  9.9657e-01, -2.0427e+00, -1.1009e+00,
          1.0661