# Copy the following code into Google Colab and run it with a GPU runtime.

In [None]:
!pip install -qU pyngrok nest_asyncio fastapi uvicorn

import nest_asyncio
from pyngrok import ngrok
import uvicorn

# Patch asyncio so an event loop can be started from inside the notebook,
# which already has one running.
nest_asyncio.apply()

port = 7860

# Stop any ngrok process left over from a previous run of this cell, then
# authenticate. Replace the placeholder below with your own ngrok authtoken.
ngrok.kill()
ngrok.set_auth_token("Your Private Key from NGROK")

# ngrok.connect() returns a tunnel object, not a plain string: formatting the
# object itself yields a description of the tunnel rather than a URL, so the
# link to the interactive docs is built from its ``public_url`` attribute.
public_url = ngrok.connect(port, bind_tls=True)
print(f"🔗 Public FastAPI URL: {public_url.public_url}/docs")

In [None]:
from fastapi import FastAPI, Form, UploadFile, File
from pydantic import BaseModel
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
import uuid, os, torch, traceback, asyncio
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
from PIL import Image
import nest_asyncio
import uvicorn
from typing import Optional
import base64
from io import BytesIO
# ========== FastAPI setup ==========
app = FastAPI()

OUTPUT_DIR = "stable_diffusion_output"
EDITED_DIR = "edited_images"
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(EDITED_DIR, exist_ok=True)

# Expose both directories as static files under /generated and /edited.
app.mount("/generated", StaticFiles(directory=OUTPUT_DIR), name="generated")
app.mount("/edited", StaticFiles(directory=EDITED_DIR), name="edited")

device = "cuda" if torch.cuda.is_available() else "cpu"

# Half precision on GPU, full precision on CPU. Both pipelines use the same
# dtype so the text-to-image model is not held in full precision on the GPU
# while the img2img model runs in half precision.
model_dtype = torch.float16 if device == "cuda" else torch.float32

# ========== Load models ==========
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=model_dtype,
).to(device)
pipe2 = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=model_dtype,
).to(device)

# Optional optimisation: use xformers attention when it is available.
# This stays best-effort, but the reason is reported instead of being
# silently discarded.
try:
    pipe.enable_xformers_memory_efficient_attention()
    pipe2.enable_xformers_memory_efficient_attention()
except Exception as exc:
    print(f"xformers memory-efficient attention not enabled: {exc}")

# ========== Concurrency control ==========
# At most 2 handlers may be inside a pipeline section at once; adjust the
# number to match the machine's resources.
pipe_semaphore = asyncio.Semaphore(2)

# ========== Request schemas ==========
# Request body for the text-to-image endpoints (/generate_image and
# /multi_image_generation).
class Prompt(BaseModel):
    text: str  # prompt text handed to the text-to-image pipeline

# Request body for /edit_image.
class EditImageRequest(BaseModel):
    prompt: str  # text prompt passed to the img2img pipeline
    image_base64: str  # source image, base64-encoded

# ========== Generate a single image (for testing) ==========
@app.post("/generate_image")
async def generate_image_endpoint(prompt: Prompt):
    """Generate one image from a text prompt and return it base64-encoded.

    Returns ``{"image_base64": <PNG bytes as base64 text>}`` on success.
    A blank prompt yields a 400 JSON error; any exception raised during
    generation or encoding yields a 500 JSON error carrying its message.
    """
    # Reject blank prompts up front, as /multi_image_generation does, instead
    # of spending a full generation on them.
    if not prompt.text.strip():
        return JSONResponse(content={"error": "Prompt is required."}, status_code=400)

    steps = 30
    try:
        async with pipe_semaphore:
            # NOTE(review): the pipeline call below also receives
            # num_inference_steps, so this explicit scheduler setup may be
            # redundant; confirm before removing it.
            pipe.scheduler.set_timesteps(steps)
            result = pipe(prompt.text, num_inference_steps=steps)
            image = result.images[0]

        # Encode outside the semaphore so the slot is released as soon as
        # the pipeline has finished.
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

        return {"image_base64": img_base64}
    except Exception as e:
        # Top-level boundary of the request: log the traceback and report.
        traceback.print_exc()
        return JSONResponse(content={"error": str(e)}, status_code=500)

# ========== Generate several images in one batch ==========
async def generate_image_batch_async(prompt_text: str, count: int = 4) -> list:
    """Generate ``count`` images for one prompt and save them as PNG files.

    Each image is written to ``OUTPUT_DIR`` under a random hex file name.

    Args:
        prompt_text: Prompt passed to the text-to-image pipeline.
        count: Number of images requested for the prompt.

    Returns:
        A list of relative URLs of the form ``generated/<name>.png``, one per
        saved image.

    Raises:
        Exception: Whatever the pipeline or the file write raises. Failures
            propagate to the caller instead of being returned as a
            placeholder entry in the URL list, so they cannot be mistaken
            for a successful result.
    """
    steps = 30
    async with pipe_semaphore:
        # NOTE(review): the pipeline call also receives num_inference_steps,
        # so this explicit scheduler setup may be redundant; confirm before
        # removing it.
        pipe.scheduler.set_timesteps(steps)
        result = pipe(prompt_text, num_inference_steps=steps, num_images_per_prompt=count)

    # Saving does not need the pipeline, so it runs after the slot is released.
    urls = []
    for image in result.images:
        filename = f"{uuid.uuid4().hex}.png"
        image.save(os.path.join(OUTPUT_DIR, filename))
        urls.append(f"generated/{filename}")

    return urls

@app.post("/multi_image_generation")
async def multi_image_generation(prompt: Prompt, count: int = 4):
    """Generate several images for one prompt and return their URLs.

    ``count`` is the number of images to generate (default 4).

    Returns ``{"urls": [...]}`` with the list produced by
    ``generate_image_batch_async``. A blank prompt or a ``count`` below 1
    yields a 400 JSON error; an exception raised while handling the request
    yields a 500 JSON error carrying its message.
    """
    try:
        if not prompt.text.strip():
            return JSONResponse(content={"error": "Prompt is required."}, status_code=400)

        # A zero or negative image count cannot be satisfied; reject it here
        # rather than handing it to the pipeline.
        if count < 1:
            return JSONResponse(content={"error": "count must be at least 1."}, status_code=400)

        urls = await generate_image_batch_async(prompt.text, count)
        return {"urls": urls}
    except Exception as e:
        # Top-level boundary of the request: log the traceback and report.
        traceback.print_exc()
        return JSONResponse(content={"error": str(e)}, status_code=500)

# ========== Edit an image with img2img ==========
def _decode_input_image(image_base64: str):
    """Decode a base64 string into a 512x512 RGB PIL image.

    Accepts either bare base64 text or a data URL
    (``data:<mime>;base64,<payload>``); for a data URL only the part after
    the first comma is decoded.

    Raises:
        ValueError: If the text is not valid base64.
        OSError: If the decoded bytes cannot be read as an image.
    """
    encoded = image_base64
    if encoded.startswith("data:"):
        # "," is not part of the base64 alphabet, so the first comma always
        # marks the end of the data-URL header.
        encoded = encoded.partition(",")[2]
    image_data = base64.b64decode(encoded)
    return Image.open(BytesIO(image_data)).convert("RGB").resize((512, 512))


@app.post("/edit_image")
async def edit_image(request: EditImageRequest):
    """Edit an uploaded image with the img2img pipeline.

    The source image is decoded, converted to RGB, resized to 512x512 and
    passed to the pipeline together with the prompt.

    Returns ``{"image_base64": <PNG bytes as base64 text>}`` on success.
    Input that is not valid base64 or not a readable image yields a 400
    JSON error; any other exception yields a 500 JSON error carrying its
    message.
    """
    # Bad input is the client's error, so it is reported as 400, not 500.
    try:
        img = _decode_input_image(request.image_base64)
    except (ValueError, OSError) as e:
        return JSONResponse(content={"error": f"Invalid image data: {e}"}, status_code=400)
    except Exception as e:
        traceback.print_exc()
        return JSONResponse(content={"error": str(e)}, status_code=500)

    steps = 50
    try:
        async with pipe_semaphore:
            # NOTE(review): the pipeline call also receives
            # num_inference_steps, so this explicit scheduler setup may be
            # redundant; confirm before removing it.
            pipe2.scheduler.set_timesteps(steps)
            result = pipe2(
                prompt=request.prompt,
                image=img,
                strength=0.75,
                guidance_scale=7.5,
                num_inference_steps=steps,
            )

        edited_image = result.images[0]

        buffered = BytesIO()
        edited_image.save(buffered, format="PNG")
        img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

        return {"image_base64": img_base64}
    except Exception as e:
        # Top-level boundary of the request: log the traceback and report.
        traceback.print_exc()
        return JSONResponse(content={"error": str(e)}, status_code=500)

# Availability probe: always answers with a fixed "ready" payload.
@app.post("/is_available")
async def check():
    availability = {"ready": True}
    return availability

# ========== Start the server ==========
# Patch asyncio first so the server can be started from inside the notebook.
nest_asyncio.apply()

# Listen on every interface, on port 7860.
server_options = {"host": "0.0.0.0", "port": 7860}
uvicorn.run(app, **server_options)