In [None]:
!pip -q install --upgrade diffusers transformers accelerate safetensors torchvision huggingface_hub

from huggingface_hub import notebook_login
notebook_login()

In [None]:
import os, requests, importlib.util, sys

# Fetch the community FluxCFGPipeline implementation and import it as a regular module.
# NOTE(review): this downloads and *executes* code from the diffusers `main` branch; pin
# `raw_url` to a specific commit SHA so the executed code cannot change underneath the notebook.
raw_url = "https://raw.githubusercontent.com/huggingface/diffusers/main/examples/community/pipeline_flux_with_cfg.py"
module_dir = "/content/flux_cfg"
module_path = os.path.join(module_dir, "pipeline_flux_with_cfg.py")
os.makedirs(module_dir, exist_ok=True)

resp = requests.get(raw_url, timeout=60)
# Fail loudly on HTTP errors instead of writing an error page to disk and exec'ing it.
resp.raise_for_status()
with open(module_path, "w", encoding="utf-8") as fh:
    fh.write(resp.text)

spec = importlib.util.spec_from_file_location("pipeline_flux_with_cfg", module_path)
mod = importlib.util.module_from_spec(spec)
# Register the module under its name before executing it (the importlib-recommended order).
sys.modules[spec.name] = mod
spec.loader.exec_module(mod)
FluxCFGPipeline = mod.FluxCFGPipeline


In [None]:
import math, shutil
from typing import Dict, Any
import torch
import torchvision.utils as vutils
from diffusers import FluxPipeline

@torch.no_grad()
def decode_and_save_grid(pipe: "FluxPipeline", latents: torch.Tensor, path: str, nrow: int = 10):
    """Decode VAE-space latents and save them as a single image grid.

    Follows the usual diffusers latent -> pixel path: undo the VAE's latent
    normalisation (divide by ``scaling_factor``, add ``shift_factor``), decode,
    then map the decoder's [-1, 1] output to [0, 1] before writing.

    Args:
        pipe: pipeline whose ``vae`` is used for decoding (only ``pipe.vae`` is read).
        latents: latents already in VAE layout (e.g. after ``_unpack_latents``).
        path: output image file path.
        nrow: number of images per grid row.
    """
    vae = pipe.vae
    scaling = getattr(vae.config, "scaling_factor", 0.18215)
    # Flux's VAE also normalises latents with a shift; SD-style VAEs have it absent/None.
    shift = getattr(vae.config, "shift_factor", None) or 0.0
    decoded = vae.decode(latents / scaling + shift).sample
    # The decoder output lives in [-1, 1]; clamping it straight to [0, 1] would turn every
    # negative value black, so rescale first (same as VaeImageProcessor.denormalize).
    imgs01 = (decoded / 2 + 0.5).clamp(0, 1).float()
    grid = vutils.make_grid(imgs01, nrow=nrow)
    vutils.save_image(grid, path)

# Base model repo (gated on the Hub: accept the licence and log in first) and the
# dtype its weights are loaded in (bfloat16 = 2 bytes/param, half of float32).
model_id, dtype = "black-forest-labs/FLUX.1-dev", torch.bfloat16

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using:", device, dtype)

pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=dtype)

if device == "cuda":
    # Keep sub-models in CPU RAM and move each onto the GPU only while it runs.
    pipe.enable_model_cpu_offload()
else:
    # Offloading only makes sense with an accelerator; without one, run everything on the CPU
    # (expect FLUX.1-dev to be extremely slow there).
    pipe.to(device)

In [None]:
import re

prompt = "indian institute of science"
steps = 20        # number of denoising steps
panels = 10       # how many intermediate steps to snapshot
guidance = 3.5    # guidance_scale passed to the pipeline
width = 256
height = 256
seed = 42
num_images = 100  # images generated in one batch; every snapshot is a grid of all of them
nrow = 10         # images per grid row

# Keep only filesystem/shell-safe characters: out_dir is later reused in a shell `!zip`
# command, where quotes, `&`, `;` etc. coming from the prompt would break or inject.
slug = re.sub(r"[^A-Za-z0-9_-]+", "_", prompt).strip("_") or "prompt"
out_dir = f"/content/viz_flux_{slug}"
os.makedirs(out_dir, exist_ok=True)

torch.manual_seed(seed)

# Evenly spaced step indices in [0, steps-1] (inclusive) at which latents are snapshotted.
capture_idxs = set(torch.linspace(0, steps - 1, panels, dtype=torch.long).tolist())
saved = {}  # step index -> path of the grid saved at that step

def on_step_end(pipe, step: int, timestep: int, kwargs: dict) -> dict:
    """Step-end callback that snapshots the in-progress batch at selected steps.

    When ``step`` is in ``capture_idxs`` and the pipeline supplied ``latents``,
    the latents are converted with ``pipe._unpack_latents``, decoded and saved as
    a grid at ``out_dir/step_XXX.png``, and that path is recorded in ``saved``.
    ``kwargs`` is always returned unchanged.
    """
    wanted = step in capture_idxs
    if not wanted or "latents" not in kwargs:
        return kwargs

    unpacked = pipe._unpack_latents(kwargs["latents"], height, width, pipe.vae_scale_factor)
    grid_path = os.path.join(out_dir, f"step_{step:03d}.png")
    decode_and_save_grid(pipe, unpacked, grid_path, nrow=nrow)
    saved[step] = grid_path
    return kwargs

# Run the batch; every grid (including the last step, which becomes final.png) is written
# by `on_step_end`. output_type="latent" skips the pipeline's own final VAE decode of all
# `num_images` images, whose result was discarded anyway.
# An explicit CPU generator makes the starting noise reproducible independent of device.
pipe(
    prompt=prompt,
    height=height,
    width=width,
    guidance_scale=guidance,
    num_inference_steps=steps,
    num_images_per_prompt=num_images,
    generator=torch.Generator(device="cpu").manual_seed(seed),
    output_type="latent",
    callback_on_step_end=on_step_end,
    callback_on_step_end_tensor_inputs=["latents"],
)

if saved:
    last = max(saved)
    shutil.copyfile(saved[last], os.path.join(out_dir, "final.png"))
print(f"Saved {len(saved)} panels to {out_dir}")

# Archive in Python instead of `!zip`, so out_dir is never interpolated into a shell command.
archive_path = shutil.make_archive(
    out_dir, "zip",
    root_dir=os.path.dirname(out_dir),
    base_dir=os.path.basename(out_dir),
)
print("Archive:", archive_path)

In [None]:
import os, math, torch, numpy as np, torchvision.utils as vutils
from typing import List
from PIL import Image

import gc

prompt = "umbrella"
neg = ""            # negative prompt passed to the CFG runs
height, width = 256, 256
steps = 16
imgs_per_grid = 8
nrow = 4
w_cfg_max = 10      # true_cfg is swept over 0..w_cfg_max (inclusive)
out_dir = "/content/final_image_cfg_compare"
os.makedirs(out_dir, exist_ok=True)

tag = prompt.replace(" ", "_")

# Release any pipeline loaded earlier before loading a second full copy of the model;
# otherwise both sets of weights stay resident and can exhaust host/GPU memory.
if "pipe" in globals():
    del pipe
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

dtype = torch.bfloat16
pipe = FluxCFGPipeline.from_pretrained(model_id, torch_dtype=dtype)
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=True)

def pil_grid(images: List[Image.Image], path: str, nrow: int):
    """Tile PIL images into one grid image and save it to ``path``.

    Args:
        images: images to tile, in row-major order; all must have the same size
            (they are stacked into one tensor).
        path: output image file path.
        nrow: images per grid row.

    Raises:
        ValueError: if ``images`` is empty.
    """
    if not images:
        raise ValueError("pil_grid: no images to save")
    # convert("RGB") guarantees an HxWx3 array: grayscale images would give a 2-D array
    # (the CHW permute fails) and mixed modes would not stack.
    tens = [
        torch.from_numpy(np.array(im.convert("RGB"))).permute(2, 0, 1).float().div_(255)
        for im in images
    ]
    grid = vutils.make_grid(torch.stack(tens, 0), nrow=nrow)
    vutils.save_image(grid, path)

def gen_imgs(text: str, num: int, seed0: int, **kwargs) -> List[Image.Image]:
    """Generate ``num`` images for ``text`` with one pipeline call per image.

    Image ``i`` gets its own CPU generator seeded with ``seed0 + i``, so calls that
    share ``seed0`` use the same per-image seeds. Extra ``kwargs`` are forwarded to
    the pipeline call; the first returned PIL image of each call is collected.
    """
    def _one(seed: int) -> Image.Image:
        gen = torch.Generator(device="cpu").manual_seed(seed)
        result = pipe(
            prompt=text,
            height=height,
            width=width,
            num_inference_steps=steps,
            generator=gen,
            output_type="pil",
            **kwargs,
        )
        return result.images[0]

    return [_one(seed0 + i) for i in range(num)]

imgs_uncond=gen_imgs("", imgs_per_grid, seed0=0,
                        negative_prompt=None,
                        guidance_scale=1.0,
                        true_cfg=1e-6)

imgs_cond=gen_imgs(prompt, imgs_per_grid, seed0=0,
                      negative_prompt=neg,
                      guidance_scale=1.0,
                      true_cfg=1e-6)

pil_grid(imgs_uncond, os.path.join(out_dir, f"{tag}_uncond.png"), nrow)
pil_grid(imgs_cond, os.path.join(out_dir, f"{tag}_cond.png"), nrow)

del imgs_uncond, imgs_cond

for w_cfg in range(w_cfg_max+1):
    imgs_cfg=gen_imgs(prompt, imgs_per_grid, seed0=0,
                        negative_prompt=neg,
                        guidance_scale=1.0,
                        true_cfg=float(w_cfg))

    pil_grid(imgs_cfg, os.path.join(out_dir, f"{tag}_truecfg_w{str(w_cfg)}.png"), nrow)

    del imgs_cfg

!zip {out_dir}.zip {out_dir}/*