In [None]:
import os
import gc

# Presumably set to silence pydevd (debugger) file-validation warnings in the kernel — confirm.
os.environ["PYDEVD_DISABLE_FILE_VALIDATION"] = "1"

In [None]:
from diffusers import FluxPipeline
from datetime import datetime
import torch

In [None]:
import json
from typing import Optional

# -------- LOAD HF TOKEN ---------
# Primary source: a Kaggle dataset file shaped like {"HF_TOKEN": "..."}.
HF_TOKEN_PATH = "/kaggle/input/imggenhub-hf-token/hf_token.json"


def load_hf_token(path: str = HF_TOKEN_PATH) -> Optional[str]:
    """Return a Hugging Face access token, or None if none is configured.

    Lookup order:
      1. The JSON file at ``path`` (must contain an ``"HF_TOKEN"`` key).
      2. The ``HF_TOKEN`` environment variable (an empty value counts as unset).

    Args:
        path: JSON file to read the token from.

    Returns:
        The token string, or None when neither source provides one.

    Raises:
        KeyError: if ``path`` exists but has no ``"HF_TOKEN"`` key.
    """
    if os.path.isfile(path):
        with open(path, "r") as f:
            token = json.load(f)["HF_TOKEN"]
        print("HF_TOKEN loaded from dataset")
        return token
    # Off Kaggle (or without the dataset attached) fall back to the environment
    # rather than crashing with FileNotFoundError.
    token = os.environ.get("HF_TOKEN") or None
    if token:
        print("HF_TOKEN loaded from environment")
    else:
        print(f"WARNING: no HF token found ({path} missing, HF_TOKEN unset); "
              "gated model downloads may fail")
    return token


HF_TOKEN = load_hf_token()
if HF_TOKEN:
    # os.environ only accepts str values, so export only a real token.
    os.environ["HF_TOKEN"] = HF_TOKEN

In [None]:
# ---- Model / numeric settings ----
MODEL_ID = "black-forest-labs/FLUX.1-schnell"  # Hugging Face repo id
PRECISION = "bf16"                              # one of "bf16" | "fp16" | "fp32"
SEED = 42                                       # base seed; prompt i is generated with SEED + i

# ---- What to generate and where ----
PROMPTS = [
    'Fresh flux bf16 test',
]
OUTPUT_DIR = "."
IMG_SIZE = (1024, 1024)                         # (height, width) as passed to the pipeline call

# ---- Sampler settings (presumably per the schnell model card — confirm) ----
GUIDANCE = 0.0
STEPS = 4


In [None]:
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Map the PRECISION config string to a torch dtype.
dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
# Fail fast on a typo instead of silently falling back to bf16.
if PRECISION not in dtype_map:
    raise ValueError(
        f"Unknown PRECISION {PRECISION!r}; expected one of {sorted(dtype_map)}"
    )
torch_dtype = dtype_map[PRECISION]
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

def get_vram_gb():
    """Return CUDA memory currently allocated by PyTorch, in GiB.

    Returns 0.0 when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return 0.0
    allocated_bytes = torch.cuda.memory_allocated()
    return allocated_bytes / (1 << 30)

print("Loading model...")
pipe = FluxPipeline.from_pretrained(MODEL_ID, torch_dtype=torch_dtype, token=HF_TOKEN)
pipe.enable_vae_tiling()
pipe.enable_attention_slicing()
pipe.set_progress_bar_config(disable=False)
if device == "cuda":
    # Sequential CPU offload streams weights onto the accelerator on demand and
    # requires one. On a CPU-only kernel the pipeline already lives on the CPU,
    # so skip it.
    pipe.enable_sequential_cpu_offload()

print(f"Model loaded (VRAM: {get_vram_gb():.2f} GB)")

for i, prompt in enumerate(PROMPTS):
    print(f"[{i+1}/{len(PROMPTS)}] {prompt}")
    # Seed each prompt separately so every image is reproducible on its own.
    generator = torch.Generator(device=device).manual_seed(SEED + i)
    image = pipe(
        prompt,
        height=IMG_SIZE[0],
        width=IMG_SIZE[1],
        guidance_scale=GUIDANCE,
        num_inference_steps=STEPS,
        generator=generator,
        max_sequence_length=256,  # text-encoder token budget for the prompt
    ).images[0]
    # Filename encodes the 1-based prompt number plus a timestamp.
    filename = f"flux_{i+1}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
    image.save(os.path.join(OUTPUT_DIR, filename))
    print(f"Saved: {filename}")
    if device == "cuda":
        # Collect unreachable objects first so their CUDA tensors are freed,
        # then hand the now-unused cached blocks back to the driver.
        gc.collect()
        torch.cuda.empty_cache()

print(f"Complete! {len(PROMPTS)} images in {OUTPUT_DIR}")

In [None]:
from IPython.display import display, Markdown
from PIL import Image
import glob
import re
import natsort

# The generation cell writes files named flux_<prompt number>_<YYYYmmdd_HHMMSS>.png.
FLUX_FILENAME_RE = re.compile(r"^flux_(\d+)_\d{8}_\d{6}\.png$")

# Collect all generated PNGs (the old "generated_*.png" pattern never matched them).
image_paths = natsort.natsorted(glob.glob(os.path.join(OUTPUT_DIR, "flux_*.png")))

print(f"Displaying {len(image_paths)} generated images with prompts:")

for path in image_paths:
    name = os.path.basename(path)
    # Take the prompt number from the filename rather than the list position, so
    # leftover files from earlier runs don't shift every label. NOTE(review): files
    # from a run with a different PROMPTS list will still be labelled with the
    # current list's text.
    match = FLUX_FILENAME_RE.match(name)
    prompt_num = int(match.group(1)) if match else None
    if prompt_num is not None and 1 <= prompt_num <= len(PROMPTS):
        prompt = PROMPTS[prompt_num - 1]
    else:
        prompt = "Unknown prompt"
    label = prompt_num if prompt_num is not None else "?"
    display(Markdown(f"**Prompt {label}** (`{name}`): {prompt}"))
    # display() renders the image immediately, so the file handle can be closed afterwards.
    with Image.open(path) as img:
        display(img)
    print("-"*50)