# In [None]:
import os
import json
import torch
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline
from tqdm import tqdm
from datetime import datetime
import papermill as pm

# -------- SETTINGS ---------
MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
REFINER_MODEL_ID = "stabilityai/stable-diffusion-xl-refiner-1.0"
# -------- PARAMETERS (can be overridden when running via papermill) ---------
# NOTE(review): papermill injects overrides as a new cell placed right after the
# cell tagged "parameters" (or at the top of the notebook if no cell is tagged).
# Overrides therefore only take effect if these assignments sit in their own
# tagged cell, ahead of the generation code -- confirm the notebook's cell layout.
PROMPTS = ['A highly detailed photorealistic portrait of a young Russian woman visible from chest upwards, with long flowing hair, professional studio lighting, sharp focus on eyes, cinematic composition, 8K resolution, masterpiece quality, with Russian flag in the background on one side and some famous Russian buildings in afar on the left side. Make sure the flag and the Russian buildings are noticeable.']
NEGATIVE_PROMPT = "blurry, low quality, distorted, watermark, duplicate, multiple identical people, clones, repetition, cartoon, anime, painting, drawing, sketch, low resolution, pixelated, noisy, grainy, artifacts, overexposed, underexposed, bad anatomy, deformed, ugly, disfigured, poorly lit, bad composition"
OUTPUT_DIR = "output/20251114_124830"  # created if missing
# (height, width) -> a 1920x1080 landscape image, ~16:9. SDXL requires both
# dimensions to be multiples of 8.
IMG_SIZE = (1080, 1920)
GUIDANCE = 12.0  # classifier-free guidance scale
STEPS = 75  # denoising steps for the base model
# num_inference_steps passed to the refiner. The refiner resumes at 80% of its
# schedule (denoising_start=0.8), so only the final ~20% of these steps actually
# run (about 4 with the default of 20).
REFINER_STEPS = 20
SEED = 42
USE_GPU = True  # falls back to CPU when CUDA is unavailable
USE_REFINER = False
# One of "fp32", "fp16", "int8", "int4". No quantization is applied for
# "int8"/"int4": they run unquantized in fp16 on CUDA (fp32 on CPU).
PRECISION = "int8"
# ----------------------------

# Echo the run configuration so the run's output records exactly which
# settings produced the saved images (one "NAME: value" line per setting).
print("Parameters:")
print(
    "\n".join(
        f"{label}: {value}"
        for label, value in (
            ("PROMPTS", PROMPTS),
            ("NEGATIVE_PROMPT", NEGATIVE_PROMPT),
            ("OUTPUT_DIR", OUTPUT_DIR),
            ("IMG_SIZE", IMG_SIZE),
            ("GUIDANCE", GUIDANCE),
            ("STEPS", STEPS),
            ("REFINER_STEPS", REFINER_STEPS),
            ("SEED", SEED),
            ("USE_GPU", USE_GPU),
            ("USE_REFINER", USE_REFINER),
            ("PRECISION", PRECISION),
        )
    )
)
print()

# Fail fast on an invalid image size: SDXL rejects heights/widths that are not
# multiples of 8, but only at generation time, after the models have loaded.
if IMG_SIZE[0] % 8 or IMG_SIZE[1] % 8:
    raise ValueError(
        f"IMG_SIZE (height, width) must both be multiples of 8 for SDXL, got {IMG_SIZE}"
    )

# Pick the compute device; USE_GPU is only honoured when CUDA is actually present.
device = "cuda" if USE_GPU and torch.cuda.is_available() else "cpu"
if USE_GPU and device == "cpu":
    print("Warning: USE_GPU is True but CUDA is not available; falling back to CPU.")
print("GPU available:", torch.cuda.get_device_name(0) if device == "cuda" else "Running on CPU.")

# Map PRECISION to the torch dtype used to load the pipelines and the Hub weight
# variant to request. Half precision is only used on CUDA; the CPU path stays fp32.
if PRECISION == "fp32":
    torch_dtype = torch.float32
    variant = None
elif PRECISION == "fp16":
    torch_dtype = torch.float16 if device == "cuda" else torch.float32
    variant = "fp16" if device == "cuda" else None
elif PRECISION in ("int8", "int4"):
    torch_dtype = torch.float16 if device == "cuda" else torch.float32
    variant = None
    # Nothing between model loading and generation applies quantization, so
    # these modes run unquantized at torch_dtype -- say so instead of silently
    # pretending otherwise.
    print(
        f"Warning: PRECISION={PRECISION!r} is not quantized by this notebook; "
        f"running in {torch_dtype} instead."
    )
else:
    raise ValueError(f"Unsupported precision: {PRECISION}")

# Create the output directory only once the configuration has been validated.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Load the SDXL base pipeline. SDXL pipelines have no safety-checker component,
# so no `safety_checker` argument is passed (diffusers would ignore it and warn).
pipe = StableDiffusionXLPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=torch_dtype,
    variant=variant,
).to(device)

# Load the refiner only when enabled. Following the diffusers SDXL
# "ensemble of expert denoisers" example, it reuses the base pipeline's second
# text encoder and VAE, so those weights are loaded and kept in memory once.
if USE_REFINER:
    refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
        REFINER_MODEL_ID,
        text_encoder_2=pipe.text_encoder_2,
        vae=pipe.vae,
        torch_dtype=torch_dtype,
        variant=variant,
    ).to(device)
    print("Refiner model loaded for enhanced photorealism")

# Fraction of the denoising schedule handled by the base model when the refiner
# is enabled; the refiner resumes from this point (denoising_start).
HIGH_NOISE_FRAC = 0.8

# Seed the global RNGs as well, for any randomness that does not go through the
# explicit per-image generator below.
torch.manual_seed(SEED)

# Generate images (1 per prompt)
for i, prompt in enumerate(tqdm(PROMPTS, desc="Generating")):
    # A fresh generator per image (seeded SEED + i) makes every image reproducible
    # on its own, regardless of how many prompts come before it. Image 0 draws the
    # same noise as seeding the default CPU generator with SEED.
    generator = torch.Generator(device="cpu").manual_seed(SEED + i)

    # Base generation. When refining, the base stops at HIGH_NOISE_FRAC and must
    # hand the refiner its raw latents (output_type="latent"), as in the diffusers
    # ensemble-of-experts example; decoding the still-noisy latents to a PIL image
    # and having the refiner VAE-encode it again is a lossy round trip.
    base_images = pipe(
        prompt,
        negative_prompt=NEGATIVE_PROMPT,
        guidance_scale=GUIDANCE,
        num_inference_steps=STEPS,
        generator=generator,
        height=IMG_SIZE[0],
        width=IMG_SIZE[1],
        denoising_end=HIGH_NOISE_FRAC if USE_REFINER else None,
        output_type="latent" if USE_REFINER else "pil",
    ).images

    # Refinement step for photorealism
    if USE_REFINER:
        # Only the final (1 - HIGH_NOISE_FRAC) part of the REFINER_STEPS schedule
        # is executed here (about 4 steps for REFINER_STEPS=20).
        final_image = refiner(
            prompt=prompt,
            negative_prompt=NEGATIVE_PROMPT,
            image=base_images,
            guidance_scale=GUIDANCE,
            num_inference_steps=REFINER_STEPS,
            generator=generator,
            denoising_start=HIGH_NOISE_FRAC,
        ).images[0]
    else:
        final_image = base_images[0]

    # Save the image; the 1-based index keeps names unique within one run.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"image_{i+1:03d}_{timestamp}.png"
    filepath = os.path.join(OUTPUT_DIR, filename)
    final_image.save(filepath)
    # tqdm.write prints above the progress bar instead of breaking it up.
    tqdm.write(f"Saved: {filepath}")

print(f"\nGeneration complete! Images saved to: {OUTPUT_DIR}")
print(f"Total images generated: {len(PROMPTS)}")