In [None]:
!pip install --upgrade accelerate sentencepiece 
# hf_xet kagglehub

In [None]:
import os
import json
import torch
import gc
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
from tqdm import tqdm
from datetime import datetime

# -------- SETTINGS ---------
# Model identifiers handed to `from_pretrained` for the base pipeline and for
# the optional refiner pipeline.
MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
REFINER_MODEL_ID = "stabilityai/stable-diffusion-xl-refiner-1.0"

# -------- PARAMETERS (can be overridden when running via papermill) ---------
# Text prompts; they are split into batches and passed to the base pipeline.
PROMPTS = ['bombed out high rise soviet apartment in kharkiv 2022 ukraine war']
# Negative prompt, repeated for every prompt in both the base and refiner calls.
NEGATIVE_PROMPT = "blurry, low quality, distorted, watermark, duplicate, multiple identical people, clones, repetition, cartoon, anime, painting, drawing, sketch, low resolution, pixelated, noisy, grainy, artifacts, overexposed, underexposed, bad anatomy, deformed, ugly, disfigured, poorly lit, bad composition"
# Directory that receives the PNG files; created if it does not exist.
OUTPUT_DIR = "kharkiv_war_apartment_20251120_160600"
# (height, width): index 0 is passed as `height`, index 1 as `width`.
IMG_SIZE = (1080, 1920)
# `guidance_scale` for the base pipeline.
GUIDANCE = 10.0
# Base precision: "fp16" or "fp32"; any other value is treated like "fp16".
PRECISION = "fp16"
# `num_inference_steps` for the base pipeline.
STEPS = 75
# Seed given to `torch.manual_seed`; the resulting generator is shared by both stages.
SEED = 42
# Run on CUDA when it is available; otherwise the script falls back to CPU.
USE_GPU = True
# When True the base run uses denoising_end=0.8 and its output is passed
# through the refiner with denoising_start=0.8.
USE_REFINER = False
# `num_inference_steps` and `guidance_scale` for the refiner pipeline.
REFINER_STEPS = 15
REFINER_GUIDANCE = 7.0
# Refiner precision: same accepted values and fallback as PRECISION.
REFINER_PRECISION = "fp16"
# ----------------------------

def get_vram_gb():
    """Return the CUDA memory currently allocated by PyTorch, in GiB.

    Returns 0.0 when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return 0.0
    bytes_per_gib = 1024**3
    return torch.cuda.memory_allocated() / bytes_per_gib

# Pick the execution device: CUDA only when requested AND actually present.
if USE_GPU and torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device: {device}")
if device == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# Collect the run summary first, then print it line by line; refiner details
# are only included when the refiner stage is enabled.
summary_lines = [
    f"\nVRAM at start: {get_vram_gb():.2f} GB",
    f"Prompts: {len(PROMPTS)}",
    f"Base precision: {PRECISION}",
]
if USE_REFINER:
    summary_lines.append(f"Refiner precision: {REFINER_PRECISION}")
summary_lines.append(f"Steps: {STEPS} (base)")
if USE_REFINER:
    summary_lines.append(f"Refiner steps: {REFINER_STEPS} (refiner)")
for summary_line in summary_lines:
    print(summary_line)

# Make sure the destination directory exists before any image is generated.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Resolve dtype / checkpoint variant for each pipeline.
# Only "fp32" selects full precision (and the default, un-suffixed weights);
# "fp16" and any unrecognised value both resolve to half precision.
base_is_fp32 = PRECISION == "fp32"
base_dtype = torch.float32 if base_is_fp32 else torch.float16
base_variant = None if base_is_fp32 else "fp16"

refiner_is_fp32 = REFINER_PRECISION == "fp32"
refiner_dtype = torch.float32 if refiner_is_fp32 else torch.float16
refiner_variant = None if refiner_is_fp32 else "fp16"

# Seed the global RNG; the returned generator is reused by every pipeline call.
generator = torch.manual_seed(SEED)

BATCH_SIZE = 1  # tune

# Split the prompt list into consecutive chunks of at most BATCH_SIZE prompts.
batched_prompts = []
for chunk_start in range(0, len(PROMPTS), BATCH_SIZE):
    batched_prompts.append(PROMPTS[chunk_start:chunk_start + BATCH_SIZE])

# Accumulates the base-stage output for every prompt, in prompt order.
all_base_images = []

# ============================================================
# STAGE 1 — LOAD BASE ONCE → RUN ALL BATCHES → DELETE
# ============================================================
pipe = StableDiffusionXLPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=base_dtype,
    variant=base_variant,
    use_safetensors=True,
    low_cpu_mem_usage=True,
).to(device)

# When the refiner finishes the denoising, the base must hand over its
# partially denoised *latents*. Decoding to PIL at denoising_end=0.8 would
# give the refiner a decoded, still-noisy picture instead of the latent
# state it is supposed to continue from.
base_output_type = "latent" if USE_REFINER else "pil"

for batch in batched_prompts:
    out = pipe(
        batch,
        negative_prompt=[NEGATIVE_PROMPT] * len(batch),
        guidance_scale=GUIDANCE,
        num_inference_steps=STEPS,
        generator=generator,
        height=IMG_SIZE[0],
        width=IMG_SIZE[1],
        denoising_end=0.8 if USE_REFINER else None,
        output_type=base_output_type,
    ).images
    if USE_REFINER:
        # `out` is a latent tensor here: park it on the CPU so nothing from
        # the base stage keeps VRAM pinned, and store one tensor per prompt.
        out = list(out.cpu())
    all_base_images.extend(out)

# Release the base model before the (optional) refiner is loaded.
pipe.to("cpu")
del pipe
gc.collect()
if device == "cuda":
    # CUDA-only calls: synchronize() raises on hosts without CUDA.
    torch.cuda.synchronize()
    torch.cuda.empty_cache()

# ============================================================
# STAGE 2 — OPTIONAL REFINER: LOAD ONCE → RUN ALL BATCHES → DELETE
# ============================================================
if USE_REFINER:
    refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
        REFINER_MODEL_ID,
        torch_dtype=refiner_dtype,
        variant=refiner_variant,
        use_safetensors=True,
        low_cpu_mem_usage=True,
    ).to(device)

    refined_images = []
    offset = 0  # position of the current batch inside all_base_images
    for batch in batched_prompts:
        # Stage-1 outputs are stored in prompt order, so the slice that
        # belongs to this batch starts at `offset`.
        imgs = all_base_images[offset:offset + len(batch)]
        offset += len(batch)

        out = refiner(
            batch,
            negative_prompt=[NEGATIVE_PROMPT] * len(batch),
            image=imgs,
            guidance_scale=REFINER_GUIDANCE,
            num_inference_steps=REFINER_STEPS,
            generator=generator,
            denoising_start=0.8,
        ).images
        refined_images.extend(out)

    # Release the refiner once every batch has been processed.
    refiner.to("cpu")
    del refiner
    gc.collect()
    if device == "cuda":
        # CUDA-only calls: synchronize() raises on hosts without CUDA.
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    final_images = refined_images
else:
    final_images = all_base_images


# ============================================================
# SAVE
# ============================================================
# Files are numbered from 1 and stamped with the time at which each is written.
for position, image in enumerate(final_images, start=1):
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    target_path = os.path.join(OUTPUT_DIR, f"generated_{position}_{stamp}.png")
    image.save(target_path)
    print(f"Saved: {target_path}")

# Closing report framed by a 60-character rule.
rule = "=" * 60
print(f"\n{rule}")
print(f"COMPLETE! {len(PROMPTS)} images saved to {OUTPUT_DIR}")
print(f"Final VRAM: {get_vram_gb():.2f} GB")
print(f"{rule}")