In [None]:
!mkdir templates static 
!touch templates/index.html static/script.js static/style.css

In [None]:
!curl -o templates/index.html https://raw.githubusercontent.com/lucas-wa/flask-server/main/templates/index.html
!curl -o static/script.js https://raw.githubusercontent.com/lucas-wa/flask-server/main/static/script.js
!curl -o static/style.css https://raw.githubusercontent.com/lucas-wa/flask-server/main/static/style.css

In [None]:
!pip install flask-ngrok
!pip install pyngrok
!ngrok config add-authtoken YOUR_NGROK_KEY
!pip install -U flask-cors

In [None]:
!pip install diffusers==0.11.1
!pip install transformers scipy ftfy accelerate

In [None]:
import torch
from diffusers import StableDiffusionPipeline

# Stable Diffusion v1.4 weights in half precision, moved to the GPU.
# The module-level `pipe` is what the image-generation helper calls.
MODEL_ID = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.float16).to("cuda")

In [None]:
import os
import io
import base64
import queue
import threading
from PIL import Image

from flask import Flask, render_template, request, jsonify
from werkzeug.utils import secure_filename
from pyngrok import ngrok
from flask_cors import CORS

from inference_realesrgan import upscale_image


# Jobs submitted by the Flask routes and consumed by the background worker thread.
request_queue = queue.Queue()

app = Flask(__name__)
# Allow cross-origin requests from any origin on every route.
cors = CORS(app, resources={r"/*": {"origins": "*"}})
port = 5000

# Open a public ngrok tunnel to the local Flask port and report where it points.
public_url = ngrok.connect(port).public_url
print(f' * ngrok tunnel "{public_url}" -> "http://127.0.0.1:{port}"')

# Store the public ngrok URL in the app config, alongside an upload folder setting
# (NOTE(review): UPLOAD_FOLDER is not read anywhere in this cell).
app.config.update(BASE_URL=public_url, UPLOAD_FOLDER='images/')

def generate_image(prompt):
  """Run the Stable Diffusion pipeline on ``prompt`` and return a PNG data URI.

  Args:
    prompt: Text prompt passed straight to the module-level ``pipe``.

  Returns:
    A ``"data:image/png;base64,..."`` string holding the first generated image.

  Raises:
    Exception: "Image could not be generated!" if the pipeline call or the PNG
      encoding fails. The original error is chained as ``__cause__``.
  """
  try:
    image = pipe(prompt).images[0]
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
  except Exception as e:
    # Chain the original error so the real cause is visible in logs/tracebacks.
    raise Exception("Image could not be generated!") from e
  # b64encode returns ASCII bytes; decode them rather than slicing the bytes repr.
  encoded = base64.b64encode(buffered.getvalue()).decode("ascii")
  return "data:image/png;base64," + encoded


def process_requests():
  """Worker loop: take jobs off ``request_queue`` one at a time, forever.

  Each queued item is a ``(req, service, response_queue, status_code)`` tuple put
  there by a Flask route. For every job whose tuple unpacks, exactly one
  ``(response, status_code)`` tuple is put on that job's ``response_queue``.
  The waiting route therefore never blocks forever:

  * ``"generator"``: ``generate_image(req["prompt"])`` with status 200.
  * ``"upscale"``: ``req["file"]`` is base64-decoded, passed to
    ``upscale_image`` and returned as a PNG data URI, with the caller's status.
  * any other service: an error message with status 400.
  * any exception while handling the job: a generic message with status 500.
  """
  while True:
    # Reset per job so an error is never reported on a previous job's queue.
    response_queue = None
    try:
      req, service, response_queue, status_code = request_queue.get()

      if service == "generator":
        response = generate_image(req["prompt"])
        status_code = 200

      elif service == "upscale":
        file_bytes = base64.b64decode(req['file'])
        # NOTE(review): upscale_image is always given the same filename; what it
        # uses the name for is not visible here.
        upscaled = Image.fromarray(upscale_image(file_bytes, "image.png"))
        buffered = io.BytesIO()
        upscaled.save(buffered, format="PNG")
        encoded = base64.b64encode(buffered.getvalue()).decode("ascii")
        response = "data:image/png;base64," + encoded

      else:
        # Unknown services used to get no reply, leaving the route hanging.
        response = "Unknown service: {}".format(service)
        status_code = 400

      response_queue.put((response, status_code))

    except Exception as e:
      print(e)
      if response_queue is not None:
        response_queue.put(("Internal server error. Image couldn't be generated", 500))


@app.route("/", methods=["GET", "POST"])
def index():
  """Serve the single-page UI rendered from templates/index.html (GET or POST)."""
  page = render_template("index.html")
  return page


@app.route("/generator", methods = ['POST'])
def image_generator():
  """POST /generator: queue a text-to-image job and return its result as JSON.

  Expects a JSON object body with a non-empty ``"prompt"`` and a ``"service"`` key.

  Returns:
    ``{"image_raw": <data URI>}`` with 200 on success.
    ``{"error": "Prompt or service is missing"}`` with 400 when the body is not a
      JSON object, or the prompt is missing/empty, or the service is missing.
    ``{"error": "Internal server error"}`` with 500 when the worker or this
      handler fails.
    ``{"error": <worker message>}`` with the worker's status for any other non-200.
  """
  try:
    # silent=True: a missing or malformed JSON body yields None instead of raising.
    req = request.get_json(silent=True)
    if not isinstance(req, dict):
      req = {}
    prompt = req.get("prompt")
    # Validate before indexing; the old code hit KeyError and never returned 400.
    if not prompt or "service" not in req:
      return jsonify({"error": "Prompt or service is missing"}), 400
    service = req["service"]

    response_queue = queue.Queue()
    request_queue.put((req, service, response_queue, 200))
    response, status_code = response_queue.get()
    if status_code == 500:
      return jsonify({"error": "Internal server error"}), 500
    if status_code != 200:
      return jsonify({"error": response}), status_code
    return jsonify({"image_raw": response})
  except Exception as e:
    print(e)
    # Exception objects are not JSON-serializable; report a string with a 500.
    return jsonify({"error": "Internal server error"}), 500
    

@app.route("/upscale", methods = ['POST'])
def upload_image():
  """POST /upscale: queue an upscale job for an uploaded image and return the result.

  Expects a JSON object body with ``"file"`` (a base64-encoded image, as decoded
  by the worker) and ``"service"``.

  Returns:
    ``{"image_raw": <data URI>}`` with 200 on success.
    ``{"error": "No file or service"}`` with 400 for a missing/non-JSON body or key.
    ``{"error": "Internal server error"}`` with 500 when the worker or this
      handler fails.
    ``{"error": <worker message>}`` with the worker's status for any other non-200.
  """
  try:
    # silent=True: a missing or malformed JSON body yields None instead of raising.
    req = request.get_json(silent=True)
    if not isinstance(req, dict) or 'file' not in req or 'service' not in req:
      return jsonify({"error": "No file or service"}), 400

    service = req["service"]
    response_queue = queue.Queue()
    request_queue.put((req, service, response_queue, 200))
    # The worker reports failures through the status code; don't mask them as 200.
    response, status_code = response_queue.get()
    if status_code == 500:
      return jsonify({"error": "Internal server error"}), 500
    if status_code != 200:
      return jsonify({"error": response}), status_code
    return jsonify({"image_raw": response}), 200

  except Exception as e:
    print(e)
    return jsonify({"error": "Internal server error"}), 500


# Start the background worker that serves queued jobs. daemon=True so the
# thread does not keep the Python process alive after the server stops.
tr_stable = threading.Thread(target=process_requests, daemon=True)
tr_stable.start()

# Serve on the same port the ngrok tunnel forwards to; this call blocks the cell.
app.run(port=port)