# Flux Text-to-Image Manual Test

This notebook mirrors the logic in `nodetool-huggingface/src/nodetool/nodes/huggingface/text_to_image.py` so you can interactively validate FLUX pipelines (standard or GGUF-quantized). Run all cells after activating the `nodetool` conda environment.

## Setup Checklist

1. `conda activate nodetool`
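2. `pip install -U torch diffusers transformers accelerate huggingface_hub` if any dependencies are missing. Loading GGUF-quantized weights through diffusers additionally requires the `gguf` package (`pip install -U gguf`).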
3. `huggingface-cli login` (or set `HF_TOKEN`) before downloading gated models.
4. If you plan to test GGUF-quantized weights, you can pre-download them via `huggingface-cli download <repo> <file>`; otherwise the notebook fetches them from the Hub the first time they are needed.
5. Restart the kernel after changing low-level libraries.

In [None]:
import platform
import torch
import diffusers
import huggingface_hub

def _default_device() -> str:
    if torch.cuda.is_available():
        return "cuda"
    if getattr(torch.backends, 'mps', None) and torch.backends.mps.is_available():
        return "mps"
    return "cpu"

# Resolve the device once, then report the interpreter and library versions
# together with the device that was picked.
device = _default_device()
print(
    f'Python: {platform.python_version()}',
    f'Torch: {torch.__version__}',
    f'Diffusers: {diffusers.__version__}',
    f'Hugging Face Hub: {huggingface_hub.__version__}',
    f'Active device: {device}',
    sep='\n',
)

## Configure A Test Run

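Adjust the dataclass below to match the FLUX variant you want to exercise. To load quantized weights, set `use_gguf=True` together with either `gguf_repo_id` and `gguf_filename` (resolved from the local Hugging Face cache, or downloaded if missing) or `custom_gguf_path` (a path handed directly to the loader); otherwise the notebook pulls the full diffusion weights from Hugging Face.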

In [None]:
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional

class FluxVariant(Enum):
    """FLUX model flavours this notebook tells apart.

    Every member's value is the variant's lowercase slug as a string.
    """

    SCHNELL = "schnell"
    DEV = "dev"
    FILL_DEV = "fill-dev"
    CANNY_DEV = "canny-dev"
    DEPTH_DEV = "depth-dev"

@dataclass
class FluxConfig:
    """Settings for one manual FLUX text-to-image run.

    The prompt and sampling fields are forwarded to the pipeline call, the
    model/GGUF fields decide which weights are loaded, and the remaining
    fields toggle memory optimizations, the output file, and the device.
    """

    # --- Prompt and sampling ---
    prompt: str = "A cozy living room lit by warm neon strips, rendered in a cinematic style"
    guidance_scale: float = 3.5
    num_inference_steps: int = 20
    width: int = 1024
    height: int = 1024
    max_sequence_length: int = 512
    # A negative seed means "unseeded"; otherwise a CPU generator is seeded with it.
    seed: int = 0

    # --- Weights ---
    # Repo passed to `FluxPipeline.from_pretrained`; also scanned for variant keywords.
    model_repo_id: Optional[str] = 'black-forest-labs/FLUX.1-dev'
    # Used as a variant-detection hint; a value ending in `.gguf` also selects the GGUF path.
    model_path: Optional[str] = None
    use_gguf: bool = False
    # Hub repo + filename of the quantized transformer (cache lookup, then download).
    gguf_repo_id: Optional[str] = None
    gguf_filename: Optional[str] = None
    # When set, this path is loaded directly and the Hub lookup is skipped.
    custom_gguf_path: Optional[str] = None

    # --- Memory optimizations ---
    enable_cpu_offload: bool = False
    enable_vae_tiling: bool = False
    enable_vae_slicing: bool = False

    # --- Output ---
    # Where the generated image is saved; a falsy value skips saving.
    output_path: Optional[str] = 'flux_test.png'

    # Delegates to the notebook's `_default_device()` helper instead of repeating
    # its cuda -> mps -> cpu probing inline. The lambda defers the name lookup to
    # instantiation time, so re-running the setup cell is picked up here.
    device: str = field(default_factory=lambda: _default_device())

# Instantiate the configuration with its defaults; pass keyword overrides to
# FluxConfig(...) here to exercise a different model or sampling setup.
cfg = FluxConfig()
# Bare expression: the notebook shows the dataclass repr as the cell output.
cfg

## Load The Pipeline

The helpers below reproduce the variant detection, GGUF handling, and memory optimizations from the production node.

In [None]:
from huggingface_hub import hf_hub_download
from huggingface_hub.file_download import try_to_load_from_cache
from diffusers import FluxPipeline
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.quantizers.quantization_config import GGUFQuantizationConfig

def detect_variant(model_repo_id: Optional[str], model_hint: Optional[str]) -> FluxVariant:
    """Guess the FLUX variant from a repo id plus an extra hint string.

    Both inputs (``None`` is treated as empty) are joined with a space,
    lower-cased, and scanned for keywords as plain substrings. The first
    keyword found, in the order listed below, decides the variant.

    Returns:
        The matching ``FluxVariant``; ``FluxVariant.DEV`` when no keyword
        other than (or not even) 'dev' is present.
    """
    haystack = " ".join((model_repo_id or '', model_hint or '')).lower()
    # Checked top to bottom, so earlier keywords win when several are present.
    keyword_table = (
        ('schnell', FluxVariant.SCHNELL),
        ('fill', FluxVariant.FILL_DEV),
        ('canny', FluxVariant.CANNY_DEV),
        ('depth', FluxVariant.DEPTH_DEV),
    )
    for keyword, variant in keyword_table:
        if keyword in haystack:
            return variant
    # An explicit 'dev' match and "nothing matched" both resolve to DEV.
    return FluxVariant.DEV

def _is_gguf_path(path: Optional[str]) -> bool:
    return bool(path and path.lower().endswith('.gguf'))

def _resolve_dtype(variant: FluxVariant):
    """Return the torch dtype used for `variant`.

    SCHNELL and DEV map to ``torch.bfloat16``; every other variant maps to
    ``torch.float16``.
    """
    bfloat16_variants = (FluxVariant.SCHNELL, FluxVariant.DEV)
    if variant in bfloat16_variants:
        return torch.bfloat16
    return torch.float16

def load_flux_pipeline(cfg: FluxConfig):
    """Build a `FluxPipeline` for `cfg` and report the detected variant.

    When `cfg.use_gguf` is set (or `cfg.model_path` ends in `.gguf`), the
    transformer is loaded from a GGUF file and plugged into a pipeline built
    from the base repo; otherwise the whole pipeline comes from
    `cfg.model_repo_id`. Memory optimizations are applied afterwards.

    Args:
        cfg: Run configuration (weights, device, memory flags).

    Returns:
        A ``(pipeline, variant)`` tuple.

    Raises:
        ValueError: If GGUF loading is requested without `custom_gguf_path`
            and without both `gguf_repo_id` and `gguf_filename`, or if
            non-GGUF loading is requested without `model_repo_id`.
    """
    variant_hint = cfg.model_path or cfg.gguf_filename or cfg.custom_gguf_path
    variant = detect_variant(cfg.model_repo_id, variant_hint)
    torch_dtype = _resolve_dtype(variant)

    if cfg.use_gguf or _is_gguf_path(cfg.model_path):
        # NOTE(review): a `.gguf` `model_path` selects this branch but is never
        # used as the weights source itself -- confirm whether that is intended.
        gguf_path = cfg.custom_gguf_path
        if gguf_path is None:
            if not cfg.gguf_repo_id or not cfg.gguf_filename:
                raise ValueError('Set `gguf_repo_id` and `gguf_filename` (or `custom_gguf_path`) to test GGUF weights.')
            cached = try_to_load_from_cache(cfg.gguf_repo_id, cfg.gguf_filename)
            # Only a string is a usable cached path: the lookup can also return
            # None (not cached) or a sentinel object recording that the file was
            # previously found missing. Both cases fall through to a download.
            if isinstance(cached, str):
                gguf_path = cached
            else:
                gguf_path = hf_hub_download(repo_id=cfg.gguf_repo_id, filename=cfg.gguf_filename)
        transformer = FluxTransformer2DModel.from_single_file(
            gguf_path,
            quantization_config=GGUFQuantizationConfig(compute_dtype=torch_dtype),
            torch_dtype=torch_dtype,
        )
        # The GGUF file only provides the transformer; the remaining pipeline
        # components come from the configured repo, or a default per variant.
        base_model_id = cfg.model_repo_id or ('black-forest-labs/FLUX.1-schnell' if variant == FluxVariant.SCHNELL else 'black-forest-labs/FLUX.1-dev')
        pipeline = FluxPipeline.from_pretrained(base_model_id, transformer=transformer, torch_dtype=torch_dtype)
    else:
        if not cfg.model_repo_id:
            raise ValueError('Configure `model_repo_id` when `use_gguf` is False.')
        pipeline = FluxPipeline.from_pretrained(cfg.model_repo_id, torch_dtype=torch_dtype)

    # Explicit device placement only happens when offloading is off.
    if cfg.enable_cpu_offload:
        pipeline.enable_sequential_cpu_offload()
    else:
        pipeline.to(cfg.device)

    # Apply the VAE options whether or not offloading is on, so the flags are
    # never silently ignored when `enable_cpu_offload` is set.
    if cfg.enable_vae_slicing:
        pipeline.enable_vae_slicing()
    if cfg.enable_vae_tiling:
        pipeline.enable_vae_tiling()
    return pipeline, variant

# Load the pipeline for the current config; both names are reused by the
# generation cell further down.
flux_pipeline, selected_variant = load_flux_pipeline(cfg)
# Bare expression: echo the detected variant as the cell output.
selected_variant

## Generate An Image

This cell mirrors the production sampling logic, including variant-specific overrides and the same callback signature used for streaming progress updates.

In [None]:
import math
from time import perf_counter

def _derive_sampling_params(cfg: FluxConfig, variant: FluxVariant):
    """Return the ``(guidance, steps, max_seq)`` triple to sample with.

    For every variant except SCHNELL the values come straight from `cfg`.
    SCHNELL overrides them: guidance 0.0, 4 steps, and a sequence length
    capped at 256.
    """
    if variant != FluxVariant.SCHNELL:
        return cfg.guidance_scale, cfg.num_inference_steps, cfg.max_sequence_length
    # SCHNELL ignores the configured guidance and step count entirely.
    schnell_max_seq = min(256, cfg.max_sequence_length)
    return 0.0, 4, schnell_max_seq

# Effective sampling parameters after any variant-specific overrides.
guidance_scale, num_steps, max_seq_len = _derive_sampling_params(cfg, selected_variant)
# A negative seed means "unseeded": the pipeline then receives no generator.
if cfg.seed < 0:
    generator = None
else:
    generator = torch.Generator(device='cpu')
    generator.manual_seed(cfg.seed)

def progress_callback(_, step: int, timestep: int, callback_kwargs: dict):
    """Print occasional progress lines and hand `callback_kwargs` back unchanged.

    A line is printed on the first step, on the last step, and on every
    step that is a multiple of roughly one fifth of the notebook-level
    `num_steps`. The first positional argument is ignored.
    """
    stride = max(1, num_steps // 5)
    is_boundary_step = step in (0, num_steps - 1)
    if is_boundary_step or step % stride == 0:
        # `step` is zero-based; show it one-based.
        print(f'[Flux] step {step + 1}/{num_steps} | timestep {timestep}')
    return callback_kwargs

# Resolve the out-of-memory exception type up front. An `except` clause's type
# expression is only evaluated while an exception is propagating, so a torch
# build without the top-level `torch.OutOfMemoryError` name would raise
# AttributeError there and hide the real failure. Fall back to the CUDA alias.
oom_error_type = getattr(torch, 'OutOfMemoryError', torch.cuda.OutOfMemoryError)

start = perf_counter()
try:
    flux_output = flux_pipeline(
        prompt=cfg.prompt,
        guidance_scale=guidance_scale,
        height=cfg.height,
        width=cfg.width,
        num_inference_steps=num_steps,
        max_sequence_length=max_seq_len,
        generator=generator,
        callback_on_step_end=progress_callback,
        # Tensors handed to the callback through `callback_kwargs`.
        callback_on_step_end_tensor_inputs=['latents'],
    )
except oom_error_type as exc:
    # Re-raise with actionable advice while keeping the original error chained.
    raise RuntimeError('VRAM out of memory while running Flux. Reduce the resolution/steps or enable cfg.enable_cpu_offload.') from exc
finally:
    # Report wall-clock time whether the run succeeded or failed.
    print(f'Elapsed: {perf_counter() - start:.2f}s')

flux_image = flux_output.images[0]
# Bare expression: show the generated image as the cell output.
flux_image

## Preview & Persist

Display the generated image inline and optionally save it for quick comparisons.

In [None]:
from PIL import Image
from pathlib import Path

# `display` is not imported in this file; it is expected to be provided by the
# notebook environment.
display(flux_image)
# Saving is optional: a falsy `cfg.output_path` skips it entirely.
if cfg.output_path:
    output_path = Path(cfg.output_path)
    # Create missing parent directories so the save below cannot fail on them.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    flux_image.save(output_path)
    print(f'Saved image to {output_path.resolve()}')