### License:
Copyright 2023, Jozsef Szalma <br>
Creative Commons Attribution-NonCommercial 4.0 International Public License<br>
https://creativecommons.org/licenses/by-nc/4.0/legalcode

### Set up: 
1. Environment variables on the backend (e.g. in a .env file) 
- HF_KEY: Your Hugging Face API key 
- IMG_DIR_WIN and IMG_DIR_DOCKER: Location to store the generated images
- PROMPT_PREFIX and PROMPT_SUFFIX: Optional, if you want to prefix or suffix the prompt with anything (e.g. cartoonish, kid-friendly)
- NEGATIVE_PROMPT: Optional, but should be used for parental controls (e.g. add "scary" to prevent convergence on scary images, the same with NSFW concepts, etc.)
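- MODEL_ID: Optional, Hugging Face model ID, using SD 2.1 if not defined
- OPENAI_KEY and PROMPT_INHIBITOR: OpenAI API key and the instruction text that is prepended to each user prompt before it is sent to the OpenAI "inhibitor" model (text-davinci-003), which returns a revised prompt for parental control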

2. Set a fixed LAN IP address on the machine running the backend and expose port 5000 to your **intra**net

3. Enter the backend's IP address in the mobile app under the kebab menu (look for ⋮ in the upper-right corner)

4. As of now, to get the mobile app running, you need to set up a React Native development environment, compile the app from source and load the .apk onto an Android device using developer mode.<br>
Here is a handy guide: https://reactnative.dev/docs/environment-setup?guide=native


### Known issues and Disclaimers:
- This is a hobby prototype that takes quite a bit of tech skills to get to work and is not production ready. You shouldn't use it if you don't understand the technology involved. Read the license terms, especially Section 5 – Disclaimer of Warranties and Limitation of Liability.
- I haven't been able to test whether the Docker setup works at all, because my NVIDIA drivers don't cooperate with Docker under Windows Subsystem for Linux (WSL)
- The mobile app still has the default Android icon and is named "mobile_client"
- Minimal security: no attempt is made to sanitize inputs or authenticate clients. The backend is only intended to be used behind a NAT router for demo purposes and is not ready to be exposed to the Internet. 
- I recommend setting up an extensive negative prompt as a parental control, in addition to using the Stable Diffusion safety checker, and not letting kids play with diffusion models without adult supervision, as **most of these models will produce age-inappropriate content** with minimal effort and curiosity. 

In [None]:
import os
import uuid
from flask import Flask, request, send_from_directory
from flask_restful import Api, Resource

from huggingface_hub import login
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from transformers import CLIPFeatureExtractor

import torch

import openai
import json

#huggingface key
hf_key = os.getenv("HF_KEY")

#output directory for generated images: IMG_DIR_WIN (native run), falling back to
#IMG_DIR_DOCKER (container run); both are listed in the set-up notes
image_dir = os.getenv("IMG_DIR_WIN") or os.getenv("IMG_DIR_DOCKER")
if not image_dir:
    # fail at start-up instead of on the first request (os.path.join(None, ...) raises TypeError)
    raise RuntimeError("Set IMG_DIR_WIN or IMG_DIR_DOCKER to the directory where generated images are stored")

#optional, if you want to prefix or suffix the prompt with anything (e.g. cartoonish, kid-friendly)
prompt_prefix = os.getenv("PROMPT_PREFIX", "")
prompt_suffix = os.getenv("PROMPT_SUFFIX", "")

#optional, can be used for parental controls (e.g. add "scary" to prevent convergence on scary images, et cetera)
negative_prompt = os.getenv("NEGATIVE_PROMPT", "")

#log in to the Hugging Face Hub; add_to_git_credential=True also stores the key as a git credential
login(token=hf_key, add_to_git_credential=True)

#Hugging Face model ID, using SD 2.1 if not defined in env
model_id = os.getenv("MODEL_ID", "stabilityai/stable-diffusion-2-1")


#parental control model's parameters (OpenAI completion model used as a prompt "inhibitor")
openai.api_key = os.getenv("OPENAI_KEY")
#instruction text prepended to the user's prompt when querying the inhibitor model
default_inhibitor = os.getenv("PROMPT_INHIBITOR")
inhibitor_model = "text-davinci-003"
inhibitor_temp = 0.0  #temperature 0 keeps the rewrite as repeatable as possible
inhibitor_token_limit = 200

#prompt substituted when the user submits an empty prompt
blank_prompt = "clouds in the sky"

In [None]:
#API calls to OpenAI

import re

def get_gpt(prompt):
    """Rewrite *prompt* with the OpenAI "inhibitor" model for parental control.

    ``default_inhibitor`` (PROMPT_INHIBITOR) followed by the user's prompt is sent
    to ``inhibitor_model``. The reply is expected to contain a JSON object with a
    string ``revised_prompt`` field; any other fields are ignored.

    Args:
        prompt: The user's image prompt.

    Returns:
        The inhibitor's revised prompt, or *prompt* unchanged when no inhibitor
        instruction is configured or the reply has no usable ``revised_prompt``.
        OpenAI API errors are not caught and propagate to the caller.
    """
    if not default_inhibitor:
        # PROMPT_INHIBITOR unset: `None + " "` below would raise TypeError on every request.
        print('no inhibitor configured, using prompt as-is: ', prompt)
        return prompt

    inhibitor_response = openai.Completion.create(
        model=inhibitor_model,
        prompt=default_inhibitor + " " + prompt,
        temperature=inhibitor_temp,
        max_tokens=inhibitor_token_limit,
        n=1,
    )
    inhibitor_message = inhibitor_response.choices[0].text
    print('user input: ', prompt)
    print("inhibitor's response", inhibitor_message)

    # re.DOTALL lets the match span newlines, so pretty-printed (multi-line) JSON
    # is found; the greedy .* runs from the first "{" to the last "}".
    match = re.search(r"\{.*\}", inhibitor_message, re.DOTALL)
    revised_prompt = None
    if match is not None:
        try:
            # Text delimited by braces can only decode to a dict when it is valid JSON.
            revised_prompt = json.loads(match.group(0)).get('revised_prompt')
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            revised_prompt = None

    if not isinstance(revised_prompt, str):
        print('inhibitor returned malformed response: ', inhibitor_message)
        return prompt

    return revised_prompt

In [None]:
# Extra components handed to from_pretrained. SD 2.1 does not have the safety
# checker by default, so for that model a CompVis safety checker plus a CLIP
# feature extractor are attached explicitly; any other model is loaded with
# whatever components it ships with.
_pipe_components = {}
if model_id == "stabilityai/stable-diffusion-2-1":
    _pipe_components["safety_checker"] = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker", torch_dtype=torch.float16
    )
    _pipe_components["feature_extractor"] = CLIPFeatureExtractor.from_pretrained(
        "openai/clip-vit-base-patch32", torch_dtype=torch.float16
    )

# Half-precision weights, moved onto the CUDA device for inference.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    use_auth_token=True,
    torch_dtype=torch.float16,
    **_pipe_components,
).to("cuda")

In [None]:
# Flask application plus the flask_restful Api wrapper that the Resource classes are registered on.
app = Flask(__name__)
api = Api(app)

#image generator api, generates the image and returns a GUID that acts as key for image retrieval
class ImageGeneration(Resource):
    """POST endpoint that renders one image and returns its GUID.

    Expects a JSON body ``{"prompt": "<text>"}``. A missing or blank prompt is
    replaced by ``blank_prompt``. The prompt is then rewritten by the inhibitor
    (``get_gpt``), wrapped with ``prompt_prefix``/``prompt_suffix`` and rendered
    by ``pipe`` with ``negative_prompt``. While the pipeline flags the image as
    NSFW, it is regenerated at most ``max_nsfw_retries`` times, after which
    ``blank_prompt`` is rendered instead. The PNG is written to
    ``image_dir/<guid>.png`` and ``{"guid": <guid>}`` is returned.
    """

    #regenerations attempted after a flagged image before falling back to blank_prompt
    max_nsfw_retries = 10

    @staticmethod
    def _full_prompt(core):
        """Join prefix, *core* and suffix with single spaces, skipping empty parts."""
        return " ".join(part for part in (prompt_prefix, core, prompt_suffix) if part)

    @staticmethod
    def _render(text):
        """Run the diffusion pipeline once for *text* with the configured negative prompt."""
        return pipe(prompt=text, negative_prompt=negative_prompt)

    @staticmethod
    def _nsfw_detected(result):
        """Return True if the pipeline flagged the first generated image.

        diffusers reports ``nsfw_content_detected`` as None when the pipeline has
        no safety checker; that is treated as "not flagged" instead of crashing
        on ``None[0]``.
        """
        flags = result.nsfw_content_detected
        return bool(flags) and bool(flags[0])

    def post(self):
        """Handle POST /generate/ and return ``{'guid': <image id>}``."""
        # silent=True: a missing or unparsable JSON body yields None instead of an error
        data = request.get_json(silent=True)
        prompt = data.get('prompt') if isinstance(data, dict) else None
        if not isinstance(prompt, str) or not prompt.strip():
            prompt = blank_prompt
        prompt = get_gpt(prompt)
        image_id = str(uuid.uuid4())

        full_prompt = self._full_prompt(prompt)
        result = self._render(full_prompt)

        # Every regeneration is checked; after max_nsfw_retries flagged attempts
        # the user's prompt is abandoned for the benign default.
        retries = 0
        while self._nsfw_detected(result):
            if retries >= self.max_nsfw_retries:
                result = self._render(self._full_prompt(blank_prompt))
                break
            result = self._render(full_prompt)
            retries += 1

        image = result.images[0]
        image.save(os.path.join(image_dir, f"{image_id}.png"))

        print("nsfw? ", self._nsfw_detected(result))
        print("prompt: ", prompt)
        print("image id: ", image_id)

        return {'guid': image_id}

#image retrieval api, serves the image that matches the GUID provided
class ImageRetrieval(Resource):
    """GET endpoint serving a PNG previously saved under ``image_dir``.

    Query parameter ``guid`` must be a UUID. It is normalized to the canonical
    lowercase hyphenated form, which is how image file names are generated
    (``str(uuid.uuid4())``). A missing or malformed guid yields HTTP 400.
    """

    def get(self):
        """Handle GET /image/?guid=<id> and return the matching PNG."""
        image_id = request.args.get('guid')
        print(image_id)
        try:
            # uuid.UUID raises TypeError for None and ValueError for malformed text;
            # accepting only UUIDs keeps arbitrary client strings out of the file path.
            canonical_id = str(uuid.UUID(image_id))
        except (TypeError, ValueError):
            return {'message': 'query parameter "guid" must be a valid UUID'}, 400
        return send_from_directory(image_dir, f"{canonical_id}.png")

# Route table: POST /generate/ renders an image and returns its GUID;
# GET /image/?guid=<id> serves the saved PNG for that GUID.
api.add_resource(ImageGeneration, '/generate/')
api.add_resource(ImageRetrieval, '/image/')




In [None]:
# Start the development server on all interfaces (host 0.0.0.0) so the mobile app on
# the LAN can reach it. The reloader is disabled because it re-launches the process,
# which does not work from a notebook kernel.
# Debug mode enables Werkzeug's interactive debugger, which can execute arbitrary
# Python code for anyone who can reach the server. It is therefore opt-in via
# FLASK_DEBUG=1 instead of always on.
_debug = os.getenv("FLASK_DEBUG", "").strip().lower() in ("1", "true", "yes", "on")
app.run(debug=_debug, use_reloader=False, host='0.0.0.0')