In [None]:
import json
import os

# The pydevd flag is always forced, whereas the Hugging Face Hub settings are
# only defaults: a value already present in the environment is left untouched.
os.environ["PYDEVD_DISABLE_FILE_VALIDATION"] = "1"
for env_name, env_default in (
    ("HF_HUB_ENABLE_HF_TRANSFER", "1"),
    ("HF_HUB_DISABLE_TELEMETRY", "1"),
    ("HF_HUB_DISABLE_PROGRESS_BARS", "1"),
    ("HF_HUB_DOWNLOAD_TIMEOUT", "1800"),
):
    os.environ.setdefault(env_name, env_default)

# Read the access token from the JSON file (it must contain an "HF_TOKEN" key)
# and export it through the environment so later code can pick it up.
HF_TOKEN_PATH = "/kaggle/input/imggenhub-hf-token/hf_token.json"
with open(HF_TOKEN_PATH, encoding="utf-8") as token_file:
    token_payload = json.load(token_file)
HF_TOKEN = token_payload["HF_TOKEN"]

os.environ["HF_TOKEN"] = HF_TOKEN

In [None]:
import json
import torch
import gc
import os
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
from tqdm import tqdm
from datetime import datetime

# Read the token back from the environment and fail fast when it is missing
# or empty, before any model download is attempted.
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN is None or HF_TOKEN == "":
    raise RuntimeError("HF_TOKEN not found. Ensure the Kaggle dataset is attached and synced.")

# -------- SETTINGS ---------
# Repository ids handed to `from_pretrained` for the base and refiner pipelines.
MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
REFINER_MODEL_ID = "stabilityai/stable-diffusion-xl-refiner-1.0"

# -------- PARAMETERS (can be overridden when running via papermill) ---------
# One image is generated and saved per entry.
PROMPTS = ['Fresh test: stable diffusion']
# Passed as `negative_prompt` to both the base and the refiner call.
NEGATIVE_PROMPT = "blurry, distorted"
# Destination directory for the PNG files; created if it does not exist.
OUTPUT_DIR = "."
# (height, width): index 0 is passed as `height`, index 1 as `width`.
IMG_SIZE = (64, 64)
# `guidance_scale` for the base pipeline.
GUIDANCE = 0.8
# Base weights precision: "fp16" or "fp32"; any other value is treated as fp16.
PRECISION = "fp32"
# `num_inference_steps` for the base pipeline.
STEPS = 1
# Passed to `torch.manual_seed`; the returned generator is reused for every
# pipeline call, so later prompts continue the same random stream.
SEED = 42
# Use CUDA when it is available; otherwise (or when False) run on CPU.
USE_GPU = True
# When False the base pipeline's image is saved directly and stage 2 is skipped.
USE_REFINER = True
# `num_inference_steps` for the refiner pipeline.
REFINER_STEPS = 25
# `guidance_scale` for the refiner pipeline.
REFINER_GUIDANCE = 7.0
# Refiner weights precision; accepts the same values as PRECISION.
REFINER_PRECISION = "fp16"
# ----------------------------

def get_vram_gb():
    """Return the CUDA memory currently allocated by torch, in GiB.

    Returns 0.0 when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return 0.0
    allocated_bytes = torch.cuda.memory_allocated()
    return allocated_bytes / 1024**3

def _resolve_precision(label):
    """Map a precision label to ``(torch dtype, checkpoint variant)``.

    ``"fp32"`` selects float32 weights with no variant; ``"fp16"`` and any
    unrecognised label select float16 weights with the ``"fp16"`` variant.
    """
    if label == "fp32":
        return torch.float32, None
    return torch.float16, "fp16"


# CUDA is used only when it is both requested and actually available.
cuda_selected = USE_GPU and torch.cuda.is_available()
device = "cuda" if cuda_selected else "cpu"
print(f"Device: {device}")
if cuda_selected:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# Run summary.
print(f"\nVRAM at start: {get_vram_gb():.2f} GB")
print(f"Prompts: {len(PROMPTS)}")
print(f"Base precision: {PRECISION}, Refiner precision: {REFINER_PRECISION}")
print(f"Steps: {STEPS} (base), {REFINER_STEPS} (refiner)")

os.makedirs(OUTPUT_DIR, exist_ok=True)

# Weight dtype and checkpoint variant for each of the two pipelines.
base_dtype, base_variant = _resolve_precision(PRECISION)
refiner_dtype, refiner_variant = _resolve_precision(REFINER_PRECISION)

# One generator, seeded once, is shared by every pipeline call below.
generator = torch.manual_seed(SEED)

# ============================================================================
# SEQUENTIAL TWO-STAGE GENERATION
# ============================================================================
# Only one pipeline is kept in memory at a time: for every prompt the base
# model is loaded, run and deleted before the refiner is loaded, run and
# deleted in turn.
#
# Fraction of the denoising schedule covered by the base model. The refiner
# resumes from exactly this point, so both stages must use the same value.
DENOISING_SPLIT = 0.8

for i, prompt in enumerate(tqdm(PROMPTS, desc="Generating")):
    print(f"\n{'='*60}")
    print(f"Prompt {i+1}/{len(PROMPTS)}: {prompt[:60]}...")
    print(f"{'='*60}")

    # ========================================================================
    # STAGE 1: LOAD BASE MODEL → GENERATE → DELETE
    # ========================================================================
    print(f"\n[STAGE 1: BASE MODEL]")
    print(f"  Loading base model...")
    print(f"  VRAM before load: {get_vram_gb():.2f} GB")

    pipe = StableDiffusionXLPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=base_dtype,
        variant=base_variant,
        use_safetensors=True,
        token=HF_TOKEN,
    ).to(device)

    print(f"  VRAM after load: {get_vram_gb():.2f} GB")
    print(f"  Running base inference ({STEPS} steps)...")

    # When the refiner follows, the base stops early and must hand over its
    # partly denoised *latents*. Decoding them to a PIL image and re-encoding
    # in the refiner would clip and quantise the remaining noise that
    # `denoising_start` assumes is still present.
    # NOTE(review): the diffusers base+refiner example uses the same
    # num_inference_steps for both stages; STEPS and REFINER_STEPS are
    # independent here - confirm that is intended.
    base_output = pipe(
        prompt,
        negative_prompt=NEGATIVE_PROMPT,
        guidance_scale=GUIDANCE,
        num_inference_steps=STEPS,
        generator=generator,
        height=IMG_SIZE[0],
        width=IMG_SIZE[1],
        denoising_end=DENOISING_SPLIT if USE_REFINER else None,
        output_type="latent" if USE_REFINER else "pil",
    ).images

    print(f"  VRAM after inference: {get_vram_gb():.2f} GB")
    print(f"  Deleting base model from memory...")

    # Drop the only reference to the pipeline, then release cached blocks.
    del pipe
    gc.collect()
    torch.cuda.empty_cache()

    print(f"  VRAM after cleanup: {get_vram_gb():.2f} GB")

    # ========================================================================
    # STAGE 2: LOAD REFINER → REFINE → DELETE
    # ========================================================================
    if USE_REFINER:
        print(f"\n[STAGE 2: REFINER MODEL]")
        print(f"  Loading refiner...")
        print(f"  VRAM before load: {get_vram_gb():.2f} GB")

        refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
            REFINER_MODEL_ID,
            torch_dtype=refiner_dtype,
            variant=refiner_variant,
            use_safetensors=True,
            token=HF_TOKEN,
        ).to(device)

        print(f"  VRAM after load: {get_vram_gb():.2f} GB")
        print(f"  Running refiner inference ({REFINER_STEPS} steps)...")

        # `base_output` is the batch of latents produced by stage 1.
        final_image = refiner(
            prompt,
            negative_prompt=NEGATIVE_PROMPT,
            image=base_output,
            guidance_scale=REFINER_GUIDANCE,
            num_inference_steps=REFINER_STEPS,
            generator=generator,
            denoising_start=DENOISING_SPLIT,
        ).images[0]

        print(f"  VRAM after inference: {get_vram_gb():.2f} GB")
        print(f"  Deleting refiner from memory...")

        del refiner
        gc.collect()
        torch.cuda.empty_cache()

        print(f"  VRAM after cleanup: {get_vram_gb():.2f} GB")
    else:
        # Without the refiner the base pipeline already returned PIL images.
        final_image = base_output[0]

    # Save; the file name carries the 1-based prompt index and a timestamp.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"generated_{i+1}_{timestamp}.png"
    filepath = os.path.join(OUTPUT_DIR, filename)
    final_image.save(filepath)
    print(f"\nSaved: {filepath}")

print(f"\n{'='*60}")
print(f"COMPLETE! {len(PROMPTS)} images saved to {OUTPUT_DIR}")
print(f"Final VRAM: {get_vram_gb():.2f} GB")
print(f"{'='*60}")