<a href="https://colab.research.google.com/github/dil27/HuggingFace-Diffusers-with-API-using-Flask-on-Google-Colab/blob/main/Diffusion_with_API.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@markdown #Installation.
#@markdown Run this cell. No need to configure, I'm serious.

#@markdown ---
#@markdown This notebook has the **EasyNegative** embedding built in and uses **DPM++ 2M Karras** as the default sampling method.

#@markdown ---
#@markdown If you accidentally stop the cells, just rerun the **second** cell. You don't need to run this cell again.
# Install the notebook's dependencies. Versions are not pinned, so a fresh
# runtime gets whatever is current on PyPI / GitHub at run time.
# NOTE(review): diffusers is installed twice — from PyPI (with the "torch" extra)
# here, and from the GitHub main branch two lines below; presumably the git build
# is the one meant to end up installed. Confirm and keep only one of the two.
!pip install diffusers["torch"] transformers
!pip install accelerate
!pip install git+https://github.com/huggingface/diffusers
# Flask serves the HTTP API; pyngrok exposes it through a public ngrok tunnel.
!pip install Flask pyngrok
# flask-cors lets browser front-ends hosted on other origins call the API.
!pip install flask-cors
# peft is presumably required by diffusers' LoRA adapter API (see the commented
# load_lora_weights / set_adapters example in the server cell); safetensors reads
# .safetensors weight files such as EasyNegative.safetensors.
!pip install peft safetensors

In [None]:
import base64
import random
import threading
from io import BytesIO

import torch
from diffusers import DPMSolverMultistepScheduler
from diffusers import StableDiffusionPipeline
from flask import Flask, request, jsonify
from flask_cors import CORS
from PIL import Image
from pyngrok import ngrok

#@markdown #Setup
#@markdown --- Get your _ngrok_ token [here](https://dashboard.ngrok.com/get-started/your-authtoken)

NGROK_TOKEN = "<your_ngrok_token_here>" #@param {type:"string"}

# Register the token through pyngrok's API instead of `!ngrok authtoken $NGROK_TOKEN`.
# The raw form value is no longer interpolated into a shell command: the unedited
# placeholder's "<...>" used to be parsed as an input redirection, so the cell only
# failed later inside ngrok.connect(). An unset token now fails here, clearly.
ngrok_token = NGROK_TOKEN.strip()  # pasted tokens often carry stray whitespace
if not ngrok_token or ngrok_token == "<your_ngrok_token_here>":
    raise ValueError(
        "Set NGROK_TOKEN to your ngrok authtoken (see the link above) before running this cell."
    )
ngrok.set_auth_token(ngrok_token)

#@markdown --- After running this cell, your runtime's public URL is printed in the console.

#@markdown example: `https://8fc2-34-125-49-157.ngrok-free.app/`

# The HTTP API. CORS(app) applies flask-cors' default policy so that a browser
# front-end served from another origin can call these routes.
app = Flask(__name__)
CORS(app)

# Model state shared by the routes below: /loadcheckpoint sets both, and the
# other routes use them. Both stay None until a checkpoint has been loaded.
pipe = checkpoint = None

@app.route('/connect', methods=['POST'])
def connect():
    """Connectivity probe: always answers ``{"msg": "Connected"}``."""
    payload = {"msg": "Connected"}
    return jsonify(payload)

def _build_pipeline(checkpoint_id):
    """Load ``checkpoint_id`` as an fp16 StableDiffusionPipeline on CUDA, configured for this notebook.

    The safety checker is disabled, the scheduler is replaced with
    DPMSolverMultistepScheduler using Karras sigmas (diffusers' equivalent of the
    "DPM++ 2M Karras" sampler advertised at the top of the notebook), and the
    EasyNegative textual-inversion embedding is registered under the token
    ``EasyNegative`` so prompts can reference it.

    Raises whatever ``from_pretrained`` / ``load_textual_inversion`` raise
    (e.g. for an unknown model id).
    """
    new_pipe = StableDiffusionPipeline.from_pretrained(checkpoint_id, torch_dtype=torch.float16)
    new_pipe = new_pipe.to('cuda')
    new_pipe.safety_checker = None
    # Without use_karras_sigmas=True this would be plain DPM++ 2M, not "2M Karras".
    new_pipe.scheduler = DPMSolverMultistepScheduler.from_config(
        new_pipe.scheduler.config,
        use_karras_sigmas=True,
    )
    new_pipe.load_textual_inversion(
        "embed/EasyNegative",
        weight_name="EasyNegative.safetensors",
        token="EasyNegative",
    )

    # Optional LoRA example (disabled).
    # NOTE(review): the weight file name contains "XLPony", which suggests an SDXL
    # LoRA; it likely will not load into this SD1.x StableDiffusionPipeline — confirm
    # before enabling.
    # new_pipe.load_lora_weights(
    #     "Shalie/GenshinImpactSigewinne",
    #     weight_name="spgiSigewinneXLPony.safetensors",
    #     adapter_name="sigewinnedef"
    # )
    # new_pipe.set_adapters(["sigewinnedef"], adapter_weights=[0.7])

    return new_pipe


@app.route('/loadcheckpoint', methods=['POST'])
def loadCheckpoint():
    """Load the checkpoint named in the JSON body and make it the active pipeline.

    Request JSON: ``{"checkpoint": "<model id or path accepted by from_pretrained>"}``.

    On success the module-level ``pipe`` / ``checkpoint`` are replaced, and the
    response is ``{"msg": "Checkpoint successfully loaded", "checkpoint": ...}``.
    On failure, a JSON ``{"error": ...}`` is returned with HTTP 400, and the
    previously loaded pipeline (if any) stays active and correctly reported.
    """
    global pipe
    global checkpoint

    data = request.get_json(silent=True)  # None for a missing / non-JSON body
    if not isinstance(data, dict):
        data = {}
    requested = data.get('checkpoint')
    if not requested:
        return jsonify({"error": "Missing 'checkpoint' in request body"}), 400

    try:
        new_pipe = _build_pipeline(requested)
    except Exception as e:
        return jsonify({"error": str(e)}), 400

    # Publish only after a fully successful load, so a failure never leaves the
    # globals pointing at a half-configured pipeline or the wrong checkpoint name.
    pipe = new_pipe
    checkpoint = requested

    return jsonify({
        "msg": "Checkpoint successfully loaded",
        "checkpoint": checkpoint
    })

@app.route('/checkpipe', methods=['POST'])
def checkPipe():
    """Report which checkpoint is currently loaded.

    Returns ``{"pipe": <checkpoint>}`` once a pipeline has been loaded, otherwise
    ``{"error": ...}``; both are sent with HTTP 200.
    """
    # Only reads the module-level state, so no `global` declaration is required.
    if pipe:
        return jsonify({"pipe": checkpoint})
    return jsonify({"error": "Pipeline is not loaded. Please load checkpoint first"})

# Diffusers pipelines are not safe to call concurrently: the scheduler keeps
# per-run state (timesteps, step index). app.run() serves requests on multiple
# threads (Flask's default), so generations are serialized through this lock.
_PIPE_LOCK = threading.Lock()


def _resolve_seed(raw_seed):
    """Return an int seed suitable for ``torch.Generator.manual_seed``.

    ``None`` (seed omitted) draws a fresh random 32-bit seed, which the caller
    echoes back so the image stays reproducible; the 32-bit range also survives
    JavaScript number precision. Any other value is coerced with ``int()``, so
    numeric strings such as ``"42"`` are accepted.

    Raises:
        ValueError / TypeError: if ``raw_seed`` cannot be converted to an int.
    """
    if raw_seed is None:
        return random.randint(0, 2**32 - 1)
    return int(raw_seed)


@app.route('/txt2img', methods=['POST'])
def generate():
    """Generate one image from a text prompt with the loaded pipeline.

    Request JSON fields (only ``prompt`` is required):
        prompt, neg (negative prompt), seed, width, height,
        sampling (number of inference steps), guidance (guidance scale).
        A ``checkpoint`` field is ignored; the response reports the checkpoint
        that is actually loaded.
    Omitted fields fall back to the pipeline's own defaults. An omitted seed is
    replaced with a random one, which is returned in the response.

    Returns ``{"data": {...}, "image": <base64-encoded PNG>}``. Errors are
    ``{"error": ...}`` with HTTP 400, except "pipeline not loaded", which keeps
    its HTTP 200 response for existing clients.
    """
    try:
        if not pipe:
            return jsonify({"error":"Pipeline is not loaded. Please load checkpoint first"})

        data = request.get_json(silent=True)  # None for a missing / non-JSON body
        if not isinstance(data, dict):
            data = {}

        prompt   = data.get('prompt')
        neg      = data.get('neg')
        width    = data.get('width')
        height   = data.get('height')
        denoise  = data.get('sampling')
        guidance = data.get('guidance')

        if prompt is None:
            return jsonify({"error": "Missing 'prompt' in request body"}), 400

        seed = _resolve_seed(data.get('seed'))

        # Forward only the fields the client actually sent. An explicit None
        # overrides the pipeline's defaults, and e.g. guidance_scale=None can
        # fail inside the pipeline.
        requested = {
            "negative_prompt": neg,
            "width": width,
            "height": height,
            "num_inference_steps": denoise,
            "guidance_scale": guidance,
        }
        pipe_kwargs = {name: value for name, value in requested.items() if value is not None}

        with _PIPE_LOCK:
            # Snapshot the globals once so the whole request uses a single pipeline
            # object, and the reported checkpoint is the loaded one, not whatever
            # the client claimed.
            active_pipe, active_checkpoint = pipe, checkpoint
            generator = torch.Generator(device="cuda").manual_seed(seed)
            image = active_pipe(prompt=prompt, generator=generator, **pipe_kwargs).images[0]

        buffered = BytesIO()
        image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()

        return jsonify({
            "data": {
                "checkpoint": active_checkpoint,
                "prompt": prompt,
                "negative_prompt": neg,
                "width": width,
                "height": height,
                "denoise": denoise,
                "guidance": guidance,
                "seed": seed
            },
            "image": img_str
        })

    except Exception as e:
        return jsonify({"error": str(e)}), 400

# Local port of the Flask server; the ngrok tunnel forwards public traffic to it.
PORT = 5000

if __name__ == '__main__':
    # Close tunnels left over from a previous run of this cell (e.g. after app.run
    # failed), so re-running the cell does not accumulate public endpoints.
    for stale_tunnel in ngrok.get_tunnels():
        ngrok.disconnect(stale_tunnel.public_url)

    ngrok_tunnel = ngrok.connect(PORT)
    print(' * Ngrok URL:', ngrok_tunnel.public_url)
    try:
        app.run(port=PORT)  # blocks until the cell is interrupted
    finally:
        # Take the public URL down together with the server.
        ngrok.disconnect(ngrok_tunnel.public_url)
