In [1]:
import os
import gc
import random
import argparse
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image
from tqdm import tqdm
import numpy as np
from datasets import load_dataset
from diffusers import LCMScheduler, DDPMScheduler, StableDiffusionPipeline
# import ImageReward as RM
# from torchmetrics.functional.multimodal import clip_score
# from torchmetrics.image.fid import FrechetInceptionDistance
from torchvision.transforms.functional import to_tensor, resize
from diffusers.models.attention_processor import AttnProcessor2_0

from utils.loading import load_models
from utils import p2p, generation, inversion

torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True

NUM_REVERSE_CONS_STEPS = 4
REVERSE_TIMESTEPS = [259, 519, 779, 999]
NUM_FORWARD_CONS_STEPS = 4
FORWARD_TIMESTEPS = [19, 259, 519, 779]
NUM_DDIM_STEPS = 50
START_TIMESTEP = 19

def generate_images_batch(solver, reverse_cons_model, prompts, latent):
    images = []
    generator = torch.Generator(device="cuda:0").manual_seed(42)
    controller = p2p.AttentionStore()
    images, gen_latent, latents = generation.runner(
        guidance_scale=0.0,
        tau1=1.0,
        tau2=1.0,
        is_cons_forward=True,
        model=reverse_cons_model,
        dynamic_guidance=False,
        w_embed_dim=512,
        start_time=50,
        solver=solver,
        prompt=prompts,
        controller=controller,
        generator=generator,
        latent=latent,
        return_type="image",
        # num_inference_steps=50,
    )
    return images, gen_latent, latents


def invert_images_batch(solver, prompts, images, guidance_scale, use_reverse_model=False):
    (image_gt, image_rec), latents, uncond_embeddings, latent_orig = inversion.invert(
        is_cons_inversion=True,
        # do_npi=False,
        # do_nti=True,
        w_embed_dim=512,
        stop_step=50,  # from [0, NUM_DDIM_STEPS]
        inv_guidance_scale=guidance_scale,
        dynamic_guidance=False,
        tau1=0.0,
        tau2=0.0,
        solver=solver,
        images=images,
        prompt=prompts,
        # num_inner_steps=10,
        # early_stop_epsilon=1e-5,
        seed=42,
        use_reverse_model=use_reverse_model
    )

    return image_gt, image_rec, latents[-1], latents, latent_orig


2025-04-20 12:31:25.385301: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1745141485.415792  226143 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1745141485.425195  226143 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [48]:
ldm_stable, reverse_cons_model, forward_cons_model = load_models(
    model_id="sd-legacy/stable-diffusion-v1-5",
    device="cuda:0",
    forward_checkpoint="checkpoints/iCD-SD1.5_19_259_519_779.safetensors",
    reverse_checkpoint="checkpoints/iCD-SD1.5_259_519_779_999.safetensors",
    r=64,
    w_embed_dim=512,
    teacher_checkpoint="checkpoints/sd15_cfg_distill.pt",
    dtype="fp16",
)
# ldm_stable.unet.set_attn_processor(AttnProcessor2_0())
# reverse_cons_model.unet.set_attn_processor(AttnProcessor2_0())
# forward_cons_model.unet.set_attn_processor(AttnProcessor2_0())

ldm_stable.set_progress_bar_config(disable=True)
reverse_cons_model.set_progress_bar_config(disable=True)
forward_cons_model.set_progress_bar_config(disable=True)

ldm_stable.safety_checker = None
reverse_cons_model.safety_checker = None
forward_cons_model.safety_checker = None

noise_scheduler = DDPMScheduler.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="scheduler",
)

solver = generation.Generator(
    model=ldm_stable,
    noise_scheduler=noise_scheduler,
    n_steps=NUM_DDIM_STEPS,
    forward_cons_model=forward_cons_model,
    forward_timesteps=FORWARD_TIMESTEPS,
    reverse_cons_model=reverse_cons_model,
    reverse_timesteps=REVERSE_TIMESTEPS,
    num_endpoints=NUM_REVERSE_CONS_STEPS,
    num_forward_endpoints=NUM_FORWARD_CONS_STEPS,
    max_forward_timestep_index=49,
    start_timestep=START_TIMESTEP,
)

# Configure P2P components
p2p.NUM_DDIM_STEPS = NUM_DDIM_STEPS
p2p.tokenizer = ldm_stable.tokenizer
p2p.device = "cuda:0"


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

Forward CD is initialized with guidance embedding, dim 512


Some weights of UNet2DConditionModel were not initialized from the model checkpoint at sd-legacy/stable-diffusion-v1-5 and are newly initialized: ['time_embedding.cond_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Embedded model is loading from checkpoints/sd15_cfg_distill.pt
Reverse CD is loading from checkpoints/iCD-SD1.5_259_519_779_999.safetensors
Forward CD is loading from checkpoints/iCD-SD1.5_19_259_519_779.safetensors
Endpoints reverse CTM: tensor([259, 519, 779, 999]), tensor([519, 779, 999,   0])
Endpoints forward CTM: tensor([ 19, 259, 519, 779]), tensor([259, 519, 779, 999])


In [47]:
del ldm_stable, reverse_cons_model, forward_cons_model
torch.cuda.empty_cache()
gc.collect()

0

In [3]:
torch.manual_seed(42)

# optimizer = torch.optim.Adam([solver.w_embedding], lr=1e-4)

data_files = {
    "test": "data/test-*-of-*.parquet",
}
dataset = load_dataset(
    "bitmind/MS-COCO",
    data_files=data_files,
    split="test",
    verification_mode="no_checks",
)
dataset_sample = dataset.select(
    random.sample(range(len(dataset)), 1000)
)

mse_latent, mse_real = [], []
diff_latents = []

In [4]:
del dataset
torch.cuda.empty_cache()
gc.collect()

38

In [49]:
for p in solver.model.unet.parameters():
    p.requires_grad = False
for p in reverse_cons_model.unet.parameters():
    p.requires_grad = False
for p in forward_cons_model.unet.parameters():
    p.requires_grad = False

In [50]:
for name, param in reverse_cons_model.unet.time_embedding.cond_proj.named_parameters():
    param.requires_grad = True

In [44]:
proj = reverse_cons_model.unet.time_embedding.cond_proj
torch.nn.init.normal_(proj.weight, mean=0.0, std=1e-4)

In [41]:
import itertools

def pretty_count(n):
    """utility — 1 234 567 → '1.23 M' """
    if n < 1e3:
        return str(n)
    elif n < 1e6:
        return f"{n/1e3:,.2f} K"
    elif n < 1e9:
        return f"{n/1e6:,.2f} M"
    else:
        return f"{n/1e9:,.2f} B"

def trainable_parameter_report(*modules):
    """
    Print a small table with #trainable parameters for each module.
    Pass any mix of nn.Module objects (or objects that expose .parameters()).
    """
    rows = []
    grand_total = 0
    for mod in modules:
        n = sum(p.numel() for p in mod.parameters() if p.requires_grad)
        rows.append((mod.__class__.__name__, pretty_count(n)))
        grand_total += n

    width = max(len(name) for name, _ in rows) + 3
    print("─" * (width + 15))
    for name, cnt in rows:
        print(f"{name:<{width}} {cnt:>10}")
    print("─" * (width + 15))
    print(f"{'TOTAL':<{width}} {pretty_count(grand_total):>10}")
    print("─" * (width + 15))
    return grand_total

# --- call it ------------------------------------------------
trainable_parameter_report(solver.model.unet,               # contains w_embedding
                           reverse_cons_model.unet,   # any LoRA or unfrozen layers
                           forward_cons_model.unet)   # idem


──────────────────────────────────────
UNet2DConditionModel             0
UNet2DConditionModel      163.84 K
UNet2DConditionModel             0
──────────────────────────────────────
TOTAL                     163.84 K
──────────────────────────────────────


163840

In [30]:
reverse_cons_model.unet.time_embedding.cond_proj

Linear(in_features=512, out_features=320, bias=False)

In [31]:
512 * 320 / 1000

163.84

In [51]:
optimizer = torch.optim.Adam(
    reverse_cons_model.unet.time_embedding.cond_proj.parameters(),
    lr=1e-7
)

In [33]:
from utils.p2p import register_attention_control

In [34]:
register_attention_control(solver.model, None)

In [35]:
def get_grad_norm(model, norm_type=2):
    """
    Calculates the gradient norm for logging.

    Args:
        norm_type (float | str | None): the order of the norm.
    Returns:
        total_norm (float): the calculated norm.
    """
    parameters = model.parameters()
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    total_norm = torch.norm(
        torch.stack([torch.norm(p.grad.detach(), norm_type) for p in parameters]),
        norm_type,
    )
    return total_norm.item()

In [52]:
torch.manual_seed(42)

clip_scores, ir_scores = [], []
# diff_latents = []
mse_latent_log = []
mse_real_log = []
BATCH_SIZE = 1
step_counter = 0

for start_idx in tqdm(
    range(0, len(dataset_sample), BATCH_SIZE), desc="Processing batches"
):
    batch = dataset_sample[start_idx : start_idx + BATCH_SIZE]
    batch_images = [
        img.convert("RGB").resize((512, 512), Image.Resampling.LANCZOS)
        for img in batch["image"]
    ]
    batch_prompts = [s["raw"] for s in batch["sentences"]]
    solver.init_prompt(batch_prompts)

    # 1) Inversion to get latents from the forward (teacher) direction
    # with torch.no_grad():
    image_rec1, latents1, latent1 = solver.cons_inversion(
        batch_images,
        w_embed_dim=512,
        guidance_scale=0.0,
        seed=0,
        use_reverse_model=False,
    )
    # with torch.amp.autocast("cuda"):
    image_rec2, latents2, latent2 = solver.cons_inversion(
        batch_images,
        w_embed_dim=512,
        guidance_scale=0.0,
        seed=0,
        use_reverse_model=True,
    )
    
    a1 = latents1[1] - latents2[1]
    a2 = latents1[2] - latents2[2]
    a3 = latents1[3] - latents2[3]
    a4 = latents1[4] - latents2[4]

    # diff_latents.append((a1.detach().cpu(), a2.detach().cpu(), a3.detach().cpu(), a4.detach().cpu()))

    # Use latents from forward pass vs. reverse pass (as an example)
    latent_forward1 = latents1[1].detach()  # no grad
    latent_reverse1 = latents2[1]  # reverse latent
    latent_forward2 = latents1[2].detach()  # no grad
    latent_reverse2 = latents2[2]  # reverse latent
    latent_forward3 = latents1[3].detach()  # no grad
    latent_reverse3 = latents2[3]  # reverse latent
    latent_forward4 = latents1[4].detach()  # no grad
    latent_reverse4 = latents2[4]  # reverse latent

    # 5) Compute MSE loss for training the embedding
    loss1 = F.mse_loss(latent_forward1, latent_reverse1)
    loss2 = F.mse_loss(latent_forward2, latent_reverse2)
    loss3 = F.mse_loss(latent_forward3, latent_reverse3)
    loss4 = F.mse_loss(latent_forward4, latent_reverse4)

    loss = loss2 + loss3 + loss4 + loss1
    # 6) Optimization step
    optimizer.zero_grad()
    loss.backward()
    # grad_norm = get_grad_norm(reverse_cons_model.unet.time_embedding.cond_proj)

    # torch.nn.utils.clip_grad_norm_(reverse_cons_model.unet.time_embedding.cond_proj.parameters(), 1.0)
    optimizer.step()

    # 7) Update step counter and log metrics to wandb
    step_counter += 1

    # 8) Periodically log sample images to wandb
    if step_counter % 1 == 0:
        batch_metrics = {
            "step": step_counter,
            "batch_start": start_idx,
            "batch_end": start_idx + BATCH_SIZE,
            "guidance_scale": 0.0,
            "loss": loss.item(),
            "loss_latent_r_1": loss1.item(),
            "loss_latent_r_2": loss2.item(),
            "loss_latent_r_3": loss3.item(),
            "loss_latent_r_4": loss4.item(),

            # "pixel_mse": pixel_mse,
            # "latent_mse": latent_mse,
            "diff_latents_a1": (a1.abs().mean().item(), (a1**2).mean().item()),
            "diff_latents_a2": (a2.abs().mean().item(), (a2**2).mean().item()),
            "diff_latents_a3": (a3.abs().mean().item(), (a3**2).mean().item()),
            "diff_latents_a4": (a4.abs().mean().item(), (a4**2).mean().item()), 
            "grad_norm": grad_norm,
        }
        # wandb.log(batch_metrics, step=step_counter)
        print("\n--- Batch Metrics ---")
        for key, value in batch_metrics.items():
            print(f"{key:20}: {value}")
        print("-" * 40)
#         try:
#             # with torch.no_grad():
#             gen_images_batch, gen_latents_batch, _ = generate_images_batch(
#                     solver, reverse_cons_model, batch_prompts, latents2[-1],
#             )
#             rec_pil = T.ToPILImage()(image_rec2[0].transpose(1, 2, 0))
#             gen_pil = T.ToPILImage()(gen_images_batch[0].transpose(1, 2, 0))

#             rec_pil.save(os.path.join("images", f"step_{step_counter}_sample_rec.jpg"))
#             gen_pil.save(os.path.join("images", f"step_{step_counter}_sample_gen.jpg"))
#             del (
#                 gen_images_batch, gen_latents_batch, _
#             )
#             gc.collect()
#             torch.cuda.empty_cache()
#         except Exception as e:
#             print(e)
#             break

    # Cleanup
    del (
        a1, 
        a2, 
        a3, 
        a4, 
        image_rec1, 
        latents1, 
        latent1, 
        image_rec2, 
        latents2, 
        latent2, 
        latent_forward1, 
        latent_reverse1, 
        latent_forward2, 
        latent_reverse2, 
        latent_forward3, 
        latent_reverse3, 
        latent_forward4, 
        latent_reverse4
    )
    gc.collect()
    torch.cuda.empty_cache()
    if step_counter == 2:
        break

Processing batches:   0%|          | 0/1000 [00:00<?, ?it/s]


--- Batch Metrics ---
step                : 1
batch_start         : 0
batch_end           : 1
guidance_scale      : 0.0
loss                : 0.1994294971227646
loss_latent_r_1     : 0.02583257295191288
loss_latent_r_2     : 0.06283518671989441
loss_latent_r_3     : 0.05859541893005371
loss_latent_r_4     : 0.05216630548238754
diff_latents_a1     : (0.1251124143600464, 0.02583257295191288)
diff_latents_a2     : (0.19543014466762543, 0.06283518671989441)
diff_latents_a3     : (0.18874910473823547, 0.05859541893005371)
diff_latents_a4     : (0.1778959035873413, 0.05216630548238754)
grad_norm           : nan
----------------------------------------


Processing batches:   0%|          | 1/1000 [00:01<27:53,  1.68s/it]

[NaN‑probe] after get_noise_pred: shape=torch.Size([1, 4, 64, 64]) min=nan max=nan
[NaN‑probe] after predicted_origin: shape=torch.Size([1, 4, 64, 64]) min=nan max=nan
[NaN‑probe] after get_noise_pred: shape=torch.Size([1, 4, 64, 64]) min=nan max=nan
[NaN‑probe] after predicted_origin: shape=torch.Size([1, 4, 64, 64]) min=nan max=nan
[NaN‑probe] after get_noise_pred: shape=torch.Size([1, 4, 64, 64]) min=nan max=nan
[NaN‑probe] after predicted_origin: shape=torch.Size([1, 4, 64, 64]) min=nan max=nan
[NaN‑probe] after get_noise_pred: shape=torch.Size([1, 4, 64, 64]) min=nan max=nan
[NaN‑probe] after predicted_origin: shape=torch.Size([1, 4, 64, 64]) min=nan max=nan

--- Batch Metrics ---
step                : 2
batch_start         : 1
batch_end           : 2
guidance_scale      : 0.0
loss                : nan
loss_latent_r_1     : nan
loss_latent_r_2     : nan
loss_latent_r_3     : nan
loss_latent_r_4     : nan
diff_latents_a1     : (nan, nan)
diff_latents_a2     : (nan, nan)
diff_latent

Processing batches:   0%|          | 1/1000 [00:03<55:27,  3.33s/it]


In [16]:
# image_rec1, latents1, latent1 = solver.cons_inversion(
#     batch_images,
#     w_embed_dim=512,
#     guidance_scale=0.0,
#     seed=0,
#     use_reverse_model=False,
# )
# with torch.amp.autocast("cuda"):
image_rec2, latents2, latent2 = solver.cons_inversion(
    batch_images,
    w_embed_dim=512,
    guidance_scale=-1.0,
    seed=0,
    use_reverse_model=True,
    use_w_embed=False
)

[NaN‑probe] after get_noise_pred: shape=torch.Size([1, 4, 64, 64]) min=nan max=nan
[NaN‑probe] after predicted_origin: shape=torch.Size([1, 4, 64, 64]) min=nan max=nan
[NaN‑probe] after get_noise_pred: shape=torch.Size([1, 4, 64, 64]) min=nan max=nan
[NaN‑probe] after predicted_origin: shape=torch.Size([1, 4, 64, 64]) min=nan max=nan
[NaN‑probe] after get_noise_pred: shape=torch.Size([1, 4, 64, 64]) min=nan max=nan
[NaN‑probe] after predicted_origin: shape=torch.Size([1, 4, 64, 64]) min=nan max=nan
[NaN‑probe] after get_noise_pred: shape=torch.Size([1, 4, 64, 64]) min=nan max=nan
[NaN‑probe] after predicted_origin: shape=torch.Size([1, 4, 64, 64]) min=nan max=nan


In [17]:
for i, (l1, l2) in enumerate(zip(latents1, latents2)):
    if torch.isnan(l2).any():
        print(f"NaNs in reverse latent {i}: max={l2.abs().max().item():.3e}")
        break

NameError: name 'latents1' is not defined

## Infer

In [None]:
# with torch.amp.autocast("cuda"):
image_rec2, latents2, latent2 = solver.cons_inversion(
    batch_images,
    w_embed_dim=512,
    guidance_scale=-1.0,
    seed=0,
    use_reverse_model=True,
)

In [None]:
gen_images_batch, gen_latents_batch, _ = generate_images_batch(
        solver, reverse_cons_model, batch_prompts, latents2[-1],
)

In [None]:
rec_pil = T.ToPILImage()(image_rec2[0].transpose(1, 2, 0))
gen_pil = T.ToPILImage()(gen_images_batch[0].transpose(1, 2, 0))

rec_pil.save(os.path.join("images", f"step_{step_counter}_sample_rec.jpg"))
gen_pil.save(os.path.join("images", f"step_{step_counter}_sample_gen.jpg"))

In [None]:
# del (
#     image_rec2, latents2, latent2, gen_images_batch, gen_latents_batch, _
# )
gc.collect()
torch.cuda.empty_cache()