In [1]:
# Setup: paths and imports for FLUX.1-dev + LoRA
import os, sys
from pathlib import Path
import torch

# Make sd-scripts importable
root = Path.cwd().parent  # repository root: .../fluxgym
sys.path.append(str(root / "sd-scripts"))

from library.device_utils import get_preferred_device, init_ipex
init_ipex()
device = get_preferred_device()

# Paths (edit these to your actual checkpoints if different)
flux_ckpt = root / "models" / "unet" / "flux1-dev.sft"                 # FLUX.1-dev (UNet/DiT)
ae_ckpt   = root / "models" / "vae" / "ae.sft"                         # AutoEncoder
clip_ckpt = root / "models" / "clip" / "clip_l.safetensors"           # CLIP-L
t5_ckpt   = root / "models" / "clip" / "t5xxl_fp16.safetensors"       # T5 XXL
lora_ckpt = root / "outputs" / "tabtab-food-demo-style-01" / "tabtab-food-demo-style-01.safetensors"  # Your LoRA

# Dtypes (bf16 helps avoid OOM; keep CLIP and LoRA consistent later)
dtype = torch.bfloat16
clip_l_dtype = dtype
t5xxl_dtype = dtype
ae_dtype = dtype
flux_dtype = dtype

print("Using device:", device)
for p in [flux_ckpt, ae_ckpt, clip_ckpt, t5_ckpt, lora_ckpt]:
    print(p, "exists:", p.exists())

get_preferred_device() -> cuda
Using device: cuda
/home/huqianghui/fluxgym/models/unet/flux1-dev.sft exists: True
/home/huqianghui/fluxgym/models/vae/ae.sft exists: True
/home/huqianghui/fluxgym/models/clip/clip_l.safetensors exists: True
/home/huqianghui/fluxgym/models/clip/t5xxl_fp16.safetensors exists: True
/home/huqianghui/fluxgym/outputs/tabtab-food-demo-style-01/tabtab-food-demo-style-01.safetensors exists: True


In [2]:
# Load base FLUX components
from safetensors.torch import load_file
from library import flux_utils, strategy_flux

# Load text encoders
clip_l = flux_utils.load_clip_l(str(clip_ckpt), clip_l_dtype, device)
clip_l.eval()

t5xxl = flux_utils.load_t5xxl(str(t5_ckpt), t5xxl_dtype, device)
t5xxl.eval()

# Load the DiT (FLUX.1 dev) and tell if it's schnell or dev
is_schnell, flux_model = flux_utils.load_flow_model(str(flux_ckpt), flux_dtype, device, model_type="flux")
flux_model.eval()
print("is_schnell:", is_schnell)

# Load AE
ae = flux_utils.load_ae(str(ae_ckpt), ae_dtype, device)
ae.eval()

# Tokenization/encoding strategies
from library import strategy_flux as strat
max_len = 256 if is_schnell else 512
tokenize_strategy = strat.FluxTokenizeStrategy(max_len)
encoding_strategy = strat.FluxTextEncodingStrategy()

print("Loaded base components.")

  from .autonotebook import tqdm as notebook_tqdm


is_schnell: False


You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


Loaded base components.


In [3]:
# Apply your LoRA to the loaded base
import sys
from pathlib import Path
# Ensure sd-scripts is importable
root = Path.cwd().parent
if str(root / "sd-scripts") not in sys.path:
    sys.path.append(str(root / "sd-scripts"))

from networks import lora_flux, oft_flux
from safetensors.torch import load_file as load_safetensors

# Multiplier controls LoRA strength
def load_and_apply_lora(lora_path: str, multiplier: float = 1.0, merge: bool = False):
    sd = load_safetensors(lora_path)
    # Detect LoRA vs OFT
    is_lora = any(k.startswith("lora") for k in sd.keys())
    module = lora_flux if is_lora else oft_flux

    lora_model, _ = module.create_network_from_weights(
        multiplier, None, ae, [clip_l, t5xxl], flux_model, sd, True
    )

    if merge:
        lora_model.merge_to([clip_l, t5xxl], flux_model, sd)
        print("LoRA merged into base model.")
        return None
    else:
        lora_model.apply_to([clip_l, t5xxl], flux_model)
        info = lora_model.load_state_dict(sd, strict=True)
        lora_model.eval()
        # Keep CLIP and LoRA on GPU and same dtype to avoid matmul/device mismatch
        clip_l.to(device=device, dtype=clip_l_dtype)
        lora_model.to(device=device, dtype=clip_l_dtype)
        print(f"LoRA attached. Casted CLIP-L & LoRA to {clip_l_dtype} on {device}.", info)
        return lora_model

lora_model = load_and_apply_lora(str(lora_ckpt), multiplier=1.0, merge=False)

LoRA attached. Casted CLIP-L & LoRA to torch.bfloat16 on cuda. <All keys matched successfully>


In [None]:
# schedule + generate helpers
import torch, math, einops
from PIL import Image
import numpy as np
from typing import Optional


def time_shift(mu: float, sigma: float, t: torch.Tensor):
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


def _lin(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15):
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b


def get_schedule(num_steps: int, image_seq_len: int, shift: bool = True) -> list[float]:
    ts = torch.linspace(1.0, 0.0, num_steps + 1, device=device)
    if shift and not is_schnell:
        mu = _lin()(image_seq_len)
        ts = time_shift(mu, 1.0, ts)
    return ts.tolist()


def _denoise(
    model,
    img: torch.Tensor,
    img_ids: torch.Tensor,
    t5_out: Optional[torch.Tensor],
    txt_ids: Optional[torch.Tensor],
    l_pooled: Optional[torch.Tensor],
    timesteps: list[float],
    guidance: float,
    t5_attn_mask: Optional[torch.Tensor] = None,
    neg_t5_out: Optional[torch.Tensor] = None,
    neg_l_pooled: Optional[torch.Tensor] = None,
    neg_t5_attn_mask: Optional[torch.Tensor] = None,
    cfg_scale: Optional[float] = None,
):
    do_cfg = neg_t5_out is not None and (cfg_scale is not None and cfg_scale != 1.0)
    guidance_vec = torch.full((img.shape[0] * (2 if do_cfg else 1),), guidance, device=img.device, dtype=img.dtype)

    if do_cfg:
        b_img_ids = torch.cat([img_ids, img_ids], dim=0)
        b_txt_ids = torch.cat([txt_ids, txt_ids], dim=0) if txt_ids is not None else None
        b_txt = torch.cat([neg_t5_out, t5_out], dim=0)
        b_vec = torch.cat([neg_l_pooled, l_pooled], dim=0) if l_pooled is not None else None
        b_t5_attn_mask = (
            torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0) if (t5_attn_mask is not None and neg_t5_attn_mask is not None) else None
        )
    else:
        b_img_ids = img_ids
        b_txt_ids = txt_ids
        b_txt = t5_out
        b_vec = l_pooled
        b_t5_attn_mask = t5_attn_mask

    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
        t_vec = torch.full((b_img_ids.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        b_img = torch.cat([img, img], dim=0) if do_cfg else img

        mod_vectors = model.get_mod_vectors(timesteps=t_vec, guidance=guidance_vec, batch_size=b_img.shape[0])

        pred = model(
            img=b_img,
            img_ids=b_img_ids,
            txt=b_txt,
            txt_ids=b_txt_ids,
            y=b_vec,
            timesteps=t_vec,
            guidance=guidance_vec,
            txt_attention_mask=b_t5_attn_mask,
            mod_vectors=mod_vectors,
        )

        if do_cfg:
            pred_uncond, pred = torch.chunk(pred, 2, dim=0)
            pred = pred_uncond + cfg_scale * (pred - pred_uncond)

        img = img + (t_prev - t_curr) * pred

    return img


essential_channels, ph, pw = 16, 2, 2


def generate(
    prompt: str,
    negative_prompt: str = "",
    width: int = 768,
    height: int = 768,
    steps: int = 20,
    seed: Optional[int] = 1234,
    guidance: float = 3.5,
    cfg_scale: float = 1.0,
):
    torch.cuda.empty_cache()
    # Ensure multiples of 16 for Flux AE packing
    height = max(64, height - height % 16)
    width = max(64, width - width % 16)

    g = torch.Generator(device=device)
    if seed is not None:
        g.manual_seed(seed)

    # Encode text using strategies (encoders may be on CPU); move outputs to Flux device/dtype
    tokens = tokenize_strategy.tokenize(prompt)
    l_pooled, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, t5xxl], tokens)

    # Move encoded outputs to Flux device/dtype
    if l_pooled is not None:
        l_pooled = l_pooled.to(device=device, dtype=flux_dtype)
    t5_out = t5_out.to(device=device, dtype=flux_dtype)
    # Some impls return zeros for txt_ids; cast to flux dtype on device for safety
    if txt_ids is not None:
        try:
            txt_ids = txt_ids.to(device=device, dtype=flux_dtype)
        except Exception:
            txt_ids = txt_ids.to(device=device)

    use_mask = bool(getattr(encoding_strategy, "apply_t5_attn_mask", False))

    if cfg_scale != 1.0 and negative_prompt is not None:
        neg_tokens = tokenize_strategy.tokenize(negative_prompt)
        neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encoding_strategy.encode_tokens(
            tokenize_strategy, [clip_l, t5xxl], neg_tokens
        )
        if neg_l_pooled is not None:
            neg_l_pooled = neg_l_pooled.to(device=device, dtype=flux_dtype)
        if neg_t5_out is not None:
            neg_t5_out = neg_t5_out.to(device=device, dtype=flux_dtype)
    else:
        neg_l_pooled = neg_t5_out = neg_t5_attn_mask = None

    # Move masks to device and proper dtype (bool works with SDPA)
    if use_mask and t5_attn_mask is not None:
        t5_attn_mask = t5_attn_mask.to(device, non_blocking=True).bool()
    else:
        t5_attn_mask = None
    if use_mask and neg_t5_attn_mask is not None:
        neg_t5_attn_mask = neg_t5_attn_mask.to(device, non_blocking=True).bool()
    else:
        neg_t5_attn_mask = None

    # Prepare latents
    H, W = height // 16, width // 16
    with torch.autocast("cuda", dtype=flux_dtype), torch.no_grad():
        x = torch.randn(1, H * W, essential_channels * ph * pw, device=device, dtype=flux_dtype, generator=g)
        img_ids = flux_utils.prepare_img_ids(1, H, W).to(device=device, dtype=flux_dtype)
        timesteps = get_schedule(steps, x.shape[1], shift=True)

        x = _denoise(
            flux_model,
            x,
            img_ids,
            t5_out,
            txt_ids,
            l_pooled,
            timesteps,
            guidance,
            t5_attn_mask,
            neg_t5_out,
            neg_l_pooled,
            neg_t5_attn_mask,
            cfg_scale,
        )

        # Rearrange to image space for AE
        x = einops.rearrange(x, "b (h w) (cc ph pw) -> b cc (h ph) (w pw)", h=H, w=W, ph=ph, pw=pw)

    # Decode with AE; fallback to CPU on OOM
    ae.to(device)
    ae_dtype = next(ae.parameters()).dtype
    x = x.to(device=device, dtype=ae_dtype)
    decode_dtype = ae_dtype if ae_dtype in (torch.float16, torch.bfloat16) else None
    try:
        if decode_dtype is not None:
            with torch.no_grad(), torch.autocast("cuda", dtype=decode_dtype):
                x = ae.decode(x)
        else:
            with torch.no_grad():
                x = ae.decode(x)
    except RuntimeError as e:
        if "CUDA out of memory" in str(e):
            # free some memory and try CPU decode
            torch.cuda.empty_cache()
            x_cpu = x.float().cpu()
            ae_cpu = ae.to("cpu", dtype=torch.float32)
            with torch.inference_mode():
                x = ae_cpu.decode(x_cpu)
            # move back to CUDA tensor for consistency (small tensor after clamp/permute handled on CPU anyway)
        else:
            raise

    x = x.clamp(-1, 1).permute(0, 2, 3, 1)
    img = Image.fromarray((127.5 * (x + 1.0)).float().detach().cpu().numpy().astype(np.uint8)[0])
    return img


In [5]:
# Offload text encoders to CPU to save VRAM before generation
import torch, gc

# Move encoders to CPU
if 'clip_l' in globals() and clip_l is not None:
    clip_l.to('cpu', dtype=torch.float32)
    clip_l.eval()
if 't5xxl' in globals() and t5xxl is not None:
    t5xxl.to('cpu', dtype=torch.float32)
    t5xxl.eval()

# Also move any text-encoder LoRA modules to CPU/float32 so devices match
try:
    moved = 0
    if 'lora_model' in globals() and lora_model is not None:
        for l in getattr(lora_model, 'text_encoder_loras', []):
            l.to('cpu', dtype=torch.float32).eval()
            for p in l.parameters():
                p.data = p.data.to(dtype=torch.float32)
            moved += 1
    if moved:
        print(f"Moved {moved} TE LoRA modules to CPU/float32 to match encoders.")
except Exception as e:
    print('warn: failed moving TE LoRA to CPU:', e)

# Optionally reduce T5 token length to speed up CPU encoding
try:
    from library.strategy_flux import FluxTokenizeStrategy
    tokenize_strategy = FluxTokenizeStrategy(t5xxl_max_length=256)
except Exception:
    pass

# Free up GPU memory held by encoders
gc.collect()
torch.cuda.empty_cache()

print('Encoders (and TE LoRA) moved to CPU. CUDA memory cleared.')

Moved 72 TE LoRA modules to CPU/float32 to match encoders.
Encoders (and TE LoRA) moved to CPU. CUDA memory cleared.
Encoders (and TE LoRA) moved to CPU. CUDA memory cleared.


In [6]:
# Quick test: call generate() and display the result
from IPython.display import display

prompt = "a colorful bowl of noodles, food photography, 50mm, depth of field"
img = generate(prompt, width=512, height=512, steps=12, guidance=3.5, cfg_scale=1.0, seed=1234)
display(img)

RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.

In [None]:
# (Optional) Move everything to CUDA (bf16) for single-device execution
import torch

# Target device/dtype
_cuda = torch.device('cuda')
_dt = torch.bfloat16

# Core models
if 'flux_model' in globals() and flux_model is not None:
    flux_model.to(device=_cuda, dtype=_dt).eval()
if 'ae' in globals() and ae is not None:
    ae.to(device=_cuda, dtype=_dt).eval()
if 'clip_l' in globals() and clip_l is not None:
    clip_l.to(device=_cuda, dtype=_dt).eval()
if 't5xxl' in globals() and t5xxl is not None:
    t5xxl.to(device=_cuda, dtype=_dt).eval()

# LoRA modules
if 'lora_model' in globals() and lora_model is not None:
    lora_model.to(device=_cuda, dtype=_dt).eval()
    # Ensure nested LoRA submodules are on the same device/dtype
    for p in lora_model.parameters():
        p.data = p.data.to(device=_cuda, dtype=_dt)

# Update globals
device = _cuda
clip_l_dtype = _dt
t5xxl_dtype = _dt
ae_dtype = _dt
flux_dtype = _dt

import gc
gc.collect(); torch.cuda.empty_cache()
print('All models and LoRA moved to CUDA with bfloat16 (optional step).')

All models and LoRA moved to CUDA with bfloat16.
