Merge pull request #6 from maltob/slack_bot_helper
Slack bot helper
maltob committed Oct 22, 2022
2 parents 41d5927 + e083889 commit 9738d11
Showing 4 changed files with 171 additions and 103 deletions.
14 changes: 14 additions & 0 deletions README.md
@@ -2,6 +2,20 @@
This bot puts [Stable Diffusion](https://huggingface.co/blog/stable_diffusion#:~:text=Stable%20Diffusion%20%F0%9F%8E%A8...using%20%F0%9F%A7%A8%20Diffusers%20Stable%20Diffusion%20is,images%20from%20a%20subset%20of%20the%20LAION-5B%20database.) at your fingertips!


## Use

1. @ the bot with the text you would like to generate an image of
2. It will reply with the generated image shortly after
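
For example, assuming the bot was installed in your workspace as **@sd-bot** (use whatever name you chose): **@sd-bot a bowl of ramen, studio lighting**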

### Advanced use
1. You can add a negative prompt with -= or -- at the end of your message to exclude unwanted elements, as shown below

##### -= Additive negative prompt
**soup -= chicken, turkey, tomato** tells Stable Diffusion to avoid generating soups with chicken, turkey, or tomato. This style still includes any default negative prompt from the .env file.

##### -- Exclusive negative prompt
**soup -- chicken, turkey, tomato** tells Stable Diffusion to avoid generating soups with chicken, turkey, or tomato. This style overrides any default negative prompt from the .env file.
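
This split is implemented by `get_prompts` in `bot_slack_helper.py`. A minimal sketch of the behavior (the `split_prompt` name is illustrative, not from the repo):

```
# Sketch: "--" replaces the default negative prompt, "-=" appends to it.
def split_prompt(text, default_negative=""):
    if "--" in text:
        prompt, _, negative = text.partition("--")
        return prompt.strip(), negative.strip()
    if "-=" in text:
        prompt, _, extra = text.partition("-=")
        return prompt.strip(), f"{default_negative},{extra.strip()}"
    return text.strip(), default_negative

# split_prompt("soup -= chicken, turkey, tomato", "lowres")
# -> ("soup", "lowres,chicken, turkey, tomato")
```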

## Installation
1. Clone this repository to a folder with at least 5 GB of free space
```
128 changes: 25 additions & 103 deletions bot.py
@@ -1,3 +1,4 @@
from cmath import log
import os
import logging
import sched
@@ -10,6 +11,8 @@
from slack_bolt import App
from slack_bolt.adapter.socket_mode import SocketModeHandler
from slack_sdk.errors import SlackApiError
from bot_slack_helper import delete_bot_file, delete_old_files, get_prompts
from bot_config_helper import get_generation_time, get_pipe, get_sd_dimensions, get_num_interations, get_negative_prompt, get_guidance_scale, get_scheduler

from dotenv import load_dotenv

@@ -23,15 +26,7 @@
#Todo change to a locking counter
sd_running_jobs = tsCounter()

#Load in config
approved_delete_users = os.environ.get("SLACK_ALLOWED_DELETE").split(",")
img_height= 512
img_width= 512
model_path="CompVis/stable-diffusion-v1-4"
generation_time = 45
guidance_scale = 7.5
num_inference_steps = 50
negative_prompt =""


#setup logger
logger = logging.getLogger(__name__)
@@ -45,73 +40,29 @@
logger.addHandler(logging_file_handler)
logger.addHandler(console_log_handler)


#Load in config
approved_delete_users = os.environ.get("SLACK_ALLOWED_DELETE").split(",")
img_height= 512
img_width= 512
model_path="CompVis/stable-diffusion-v1-4"
generation_time = 45
guidance_scale = get_guidance_scale(logger,7.5)
num_inference_steps = get_num_interations(logger,50)
negative_prompt = get_negative_prompt(logger,"")

#Environment config
try:
if os.environ.get("SD_IMG_HEIGHT") and len(os.environ.get("SD_IMG_HEIGHT")) >0 :
t_env_height = int(os.environ.get("SD_IMG_HEIGHT"))
t_env_width = int(os.environ.get("SD_IMG_WIDTH"))
img_width = t_env_width
img_height = t_env_height
logger.debug(f"Loaded height {img_height} and width {img_width} from environment")
except:
logger.warning(f"Failed to load height and width from environment. Falling back to defaults.")
print("Error loading the height and width from variables")

img_height,img_width = get_sd_dimensions(logger,img_height,img_width)

if os.environ.get("SD_MODEL_PATH") and len(os.environ.get("SD_MODEL_PATH")) > 0:
model_path = os.environ.get("SD_MODEL_PATH")
logger.debug(f"Set model path to {model_path}")

if os.environ.get("SD_NEGATIVE_PROMPT") and len(os.environ.get("SD_NEGATIVE_PROMPT")) > 0:
negative_prompt = os.environ.get("SD_NEGATIVE_PROMPT")
logger.debug(f"Set negative prompt to {negative_prompt}")


if os.environ.get("SD_ITERATIONS") and len(os.environ.get("SD_ITERATIONS")) > 0:
try:
num_inference_steps = int(os.environ.get("SD_ITERATIONS"))
logger.debug(f"Set number of inference steps to {num_inference_steps}")
except:
logger.debug(f"Failed to parse number of inference steps")

if os.environ.get("SD_GUIDANCE_SCALE") and len(os.environ.get("SD_GUIDANCE_SCALE")) > 0:
try:
guidance_scale = float(os.environ.get("SD_GUIDANCE_SCALE"))
logger.debug(f"Set guidance scale to {guidance_scale}")
except:
logger.debug(f"Failed to parse guidance scale")

#Build the StableDiffusionPipeline
pipe = None
scheduler = DDIMScheduler()

if os.environ.get("SD_SCHEDULER") and len(os.environ.get("SD_SCHEDULER")) > 2:
if os.environ.get("SD_SCHEDULER").upper() == "LMS":
scheduler = LMSDiscreteScheduler()
logger.debug(f"Using LMS Scheduler")
if os.environ.get("SD_SCHEDULER").upper() == "PNDM":
scheduler = PNDMScheduler()
logger.debug(f"Using PNDM Scheduler")
if os.environ.get("SD_SCHEDULER").upper() == "KERRASVE":
scheduler = KarrasVeScheduler()
logger.debug(f"Using KerrasVe Scheduler")

if os.environ.get("SD_PRECISION") and len(os.environ.get("SD_PRECISION"))>0 and os.environ.get("SD_PRECISION").lower() == "fp16":
logger.debug(f"Using fp16 precision")
if model_path.startswith(".") :
pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16, revision="fp16")
else:
pipe = StableDiffusionPipeline.from_pretrained(model_path, use_auth_token=os.environ.get("SD_MODEL_AUTH_TOKEN"), torch_dtype=torch.float16, revision="fp16")
else:
if model_path.startswith(".") :
pipe = StableDiffusionPipeline.from_pretrained(model_path)
else:
pipe = StableDiffusionPipeline.from_pretrained(model_path, use_auth_token=os.environ.get("SD_MODEL_AUTH_TOKEN"))

if torch.cuda.is_available() :
logger.debug(f"Using cuda")
pipe = pipe.to("cuda")

scheduler = get_scheduler(logger,DDIMScheduler())
pipe = get_pipe(logger,model_path=model_path)
pipe.enable_attention_slicing()

app = App(
@@ -122,10 +73,10 @@


@app.event("app_mention")
def event_test(event, say,client):
def app_mention(event, say,client):
txt = event["text"]
user = event["user"]
str_txt = txt[(txt.find(" "))+1:]
str_txt,neg_txt = get_prompts(logger,txt,negative_prompt)
channel = event["channel"]
# Get estimated time to generate the image
sd_running_jobs.increment()
@@ -136,7 +87,7 @@ def event_test(event, say,client):

#Generate an image and upload it to slack, then delete the info message
generation_lock.acquire()
image = pipe(txt, height=img_height,width=img_width,guidance_scale=guidance_scale,negative_prompt=negative_prompt,num_inference_steps=num_inference_steps).images[0]
image = pipe(str_txt, height=img_height,width=img_width,guidance_scale=guidance_scale,negative_prompt=neg_txt,num_inference_steps=num_inference_steps).images[0]
generation_lock.release()
sd_running_jobs.decrement()
fp = str_txt.replace(",","_").replace("/","_").replace("\\","_").replace(":","_").replace(".","_")
@@ -168,46 +119,17 @@ def handle_reaction_added_events(body):
item = body["event"]["item"]
ts = item["ts"]
user = body["event"]["user"]
#print(item)
logger.info(f"Searching for images to delete at time {ts} due to reaction by {user}")
delete_bot_file(channel=item["channel"],ts=item["ts"])

def ack_shortcut(ack):
ack()

#Used to cleanup a file uploaded in a channel at timestamp when requested
def delete_bot_file(channel,ts):
files = app.client.files_list(token=os.environ.get("SLACK_BOT_TOKEN"),channel=channel,ts_from=(int(float(ts))-1),ts_to=(int(float(ts)))+1)
myprof =app.client.users_profile_get()
for file in files["files"]:
user_prof = app.client.users_profile_get(token=os.environ.get("SLACK_BOT_TOKEN"),user=file["user"])
if file["name"].find("uf_") == 0 and "bot_id" in user_prof["profile"] and user_prof["profile"]["bot_id"] == myprof["profile"]["bot_id"]:
logger.info("Deleting "+file["name"]+" at user request")
app.client.files_delete(token=os.environ.get("SLACK_BOT_TOKEN"),file=file["id"])

#Used to clean up all files
def delete_old_files():
myprof =app.client.users_profile_get()
files = app.client.files_list(token=os.environ.get("SLACK_BOT_TOKEN"))
for file in files["files"]:
#print(file)
user_prof = app.client.users_profile_get(token=os.environ.get("SLACK_BOT_TOKEN"),user=file["user"])
#print(file["name"])
if file["name"].find("uf_") == 0 and "bot_id" in user_prof["profile"] and user_prof["profile"]["bot_id"] == myprof["profile"]["bot_id"]:
logger.info("File Cleanup - deleting "+file["name"])
#app.client.files_delete(token=os.environ.get("SLACK_BOT_TOKEN"),file=file["id"])
delete_bot_file(app,channel=item["channel"],ts=item["ts"],logger=logger)



if os.environ.get("SD_BENCHMARK") and os.environ.get("SD_BENCHMARK").lower()=="true":
logger.info("Running benchmark")
start_ns = monotonic_ns()
pipe("squid", height=img_height,width=img_width,guidance_scale=guidance_scale,negative_prompt=negative_prompt,num_inference_steps=num_inference_steps,seed=42)
end_ns = monotonic_ns()
generation_time = int((end_ns-start_ns)/1_000_000_000) + 5
generation_time = get_generation_time(logger=logger,pipe=pipe,img_height=img_height,img_width=img_width,guidance_scale=guidance_scale,negative_prompt=negative_prompt,num_inference_steps=num_inference_steps)
logger.info(f"Completed will report {generation_time} seconds of gen time")

# Start your app
if __name__ == "__main__":
#delete_old_files()
#delete_old_files(app,logger)
logger.info("Starting app")
SocketModeHandler(app, os.environ["SLACK_APP_TOKEN"]).start()
87 changes: 87 additions & 0 deletions bot_config_helper.py
@@ -0,0 +1,87 @@
import os
from time import monotonic_ns
from diffusers import StableDiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, KarrasVeScheduler
import torch

def get_sd_dimensions(logger,default_height,default_width):
t_env_height = default_height
t_env_width = default_width
try:
if os.environ.get("SD_IMG_HEIGHT") and len(os.environ.get("SD_IMG_HEIGHT")) >0 :
t_env_height = int(os.environ.get("SD_IMG_HEIGHT"))
t_env_width = int(os.environ.get("SD_IMG_WIDTH"))
logger.debug(f"Loaded height {t_env_height} and width {t_env_width} from environment")
except:
logger.warning(f"Failed to load height and width from environment. Falling back to defaults of {default_height} {default_width}.")
return (t_env_height,t_env_width)


def get_num_interations(logger,default_iterations):
t_iterations = default_iterations
if os.environ.get("SD_ITERATIONS") and len(os.environ.get("SD_ITERATIONS")) > 0:
try:
t_iterations = int(os.environ.get("SD_ITERATIONS"))
logger.debug(f"Set number of inference steps to {t_iterations}")
except:
logger.debug(f"Failed to parse number of inference steps")
return t_iterations

def get_negative_prompt(logger,default_negative_prompt):
t_neg_prompt = default_negative_prompt
if os.environ.get("SD_NEGATIVE_PROMPT") and len(os.environ.get("SD_NEGATIVE_PROMPT")) > 0:
t_neg_prompt = os.environ.get("SD_NEGATIVE_PROMPT")
logger.debug(f"Set negative prompt to {t_neg_prompt}")
return t_neg_prompt

def get_guidance_scale(logger,default_guidance):
t_guidance = default_guidance
if os.environ.get("SD_GUIDANCE_SCALE") and len(os.environ.get("SD_GUIDANCE_SCALE")) > 0:
try:
t_guidance = float(os.environ.get("SD_GUIDANCE_SCALE"))
logger.debug(f"Set guidance scale to {t_guidance}")
except:
logger.debug(f"Failed to parse guidance scale")
return t_guidance

def get_scheduler(logger,default_scheduler):
t_scheduler = default_scheduler
if os.environ.get("SD_SCHEDULER") and len(os.environ.get("SD_SCHEDULER")) > 2:
if os.environ.get("SD_SCHEDULER").upper() == "LMS":
t_scheduler = LMSDiscreteScheduler()
logger.debug(f"Using LMS Scheduler")
if os.environ.get("SD_SCHEDULER").upper() == "PNDM":
t_scheduler = PNDMScheduler()
logger.debug(f"Using PNDM Scheduler")
if os.environ.get("SD_SCHEDULER").upper() == "KERRASVE":
t_scheduler = KarrasVeScheduler()
logger.debug(f"Using KerrasVe Scheduler")
if os.environ.get("SD_SCHEDULER").upper() == "DDIM":
t_scheduler = DDIMScheduler()
logger.debug(f"Using DDIM Scheduler")
return t_scheduler

def get_pipe(logger,model_path):
t_pipe = None
if os.environ.get("SD_PRECISION") and len(os.environ.get("SD_PRECISION"))>0 and os.environ.get("SD_PRECISION").lower() == "fp16":
logger.debug(f"Using fp16 precision")
if model_path.startswith(".") :
t_pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16, revision="fp16")
else:
t_pipe = StableDiffusionPipeline.from_pretrained(model_path, use_auth_token=os.environ.get("SD_MODEL_AUTH_TOKEN"), torch_dtype=torch.float16, revision="fp16")
else:
if model_path.startswith(".") :
t_pipe = StableDiffusionPipeline.from_pretrained(model_path)
else:
t_pipe = StableDiffusionPipeline.from_pretrained(model_path, use_auth_token=os.environ.get("SD_MODEL_AUTH_TOKEN"))
#Always add CUDA if available
if torch.cuda.is_available() :
logger.debug(f"Using cuda")
t_pipe = t_pipe.to("cuda")
return t_pipe

def get_generation_time(logger,pipe,img_height,img_width,guidance_scale,negative_prompt,num_inference_steps,):
logger.info("Running benchmark")
start_ns = monotonic_ns()
pipe("squid", height=img_height,width=img_width,guidance_scale=guidance_scale,negative_prompt=negative_prompt,num_inference_steps=num_inference_steps,seed=42)
end_ns = monotonic_ns()
return (int((end_ns-start_ns)/1_000_000_000) + 5)
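
These helpers keep bot.py declarative: it passes a logger plus defaults and gets back validated settings from the environment. A quick illustration with hypothetical values (the variable names match those read above):

```
import logging
import os

from diffusers import DDIMScheduler
from bot_config_helper import get_sd_dimensions, get_guidance_scale, get_scheduler

logger = logging.getLogger("sd-bot")
os.environ["SD_IMG_HEIGHT"] = "768"      # hypothetical values, not repo defaults
os.environ["SD_IMG_WIDTH"] = "512"
os.environ["SD_GUIDANCE_SCALE"] = "8.5"

img_height, img_width = get_sd_dimensions(logger, 512, 512)  # -> (768, 512)
guidance_scale = get_guidance_scale(logger, 7.5)             # -> 8.5
scheduler = get_scheduler(logger, DDIMScheduler())           # DDIM unless SD_SCHEDULER says otherwise
```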
45 changes: 45 additions & 0 deletions bot_slack_helper.py
@@ -0,0 +1,45 @@

import os
import logging

#Used to clean up all files
def delete_old_files(app,logger):
myprof =app.client.users_profile_get()
files = app.client.files_list(token=os.environ.get("SLACK_BOT_TOKEN"))
for file in files["files"]:
#print(file)
user_prof = app.client.users_profile_get(token=os.environ.get("SLACK_BOT_TOKEN"),user=file["user"])
#print(file["name"])
if file["name"].find("uf_") == 0 and "bot_id" in user_prof["profile"] and user_prof["profile"]["bot_id"] == myprof["profile"]["bot_id"]:
logger.info("File Cleanup - deleting "+file["name"])
#app.client.files_delete(token=os.environ.get("SLACK_BOT_TOKEN"),file=file["id"])

#Used to cleanup a file uploaded in a channel at timestamp when requested
def delete_bot_file(app,channel,ts,logger):
files = app.client.files_list(token=os.environ.get("SLACK_BOT_TOKEN"),channel=channel,ts_from=(int(float(ts))-1),ts_to=(int(float(ts)))+1)
myprof =app.client.users_profile_get()
for file in files["files"]:
user_prof = app.client.users_profile_get(token=os.environ.get("SLACK_BOT_TOKEN"),user=file["user"])
if file["name"].find("uf_") == 0 and "bot_id" in user_prof["profile"] and user_prof["profile"]["bot_id"] == myprof["profile"]["bot_id"]:
logger.info("Deleting "+file["name"]+" at user request")
app.client.files_delete(token=os.environ.get("SLACK_BOT_TOKEN"),file=file["id"])

def get_prompts(logger, txt, default_negative_prompt):
#Assume @ for the bot is first and chop it off
t_str_txt = txt[(txt.find(" "))+1:]
t_neg_txt = default_negative_prompt
#Check for either of the negative prompt styles and set the prompt
try :
indx_neg_prompt = t_str_txt.index("--")
t_neg_txt = t_str_txt[indx_neg_prompt+2:]
t_str_txt = t_str_txt[:indx_neg_prompt]
logger.debug(f"Detected negative prompt of {t_neg_txt} ")
except:
try :
indx_neg_prompt = t_str_txt.index("-=")
t_neg_txt = f"{default_negative_prompt},{t_str_txt[indx_neg_prompt+2:]}"
t_str_txt = t_str_txt[:indx_neg_prompt]
logger.debug(f"Detected negative prompt of {t_neg_txt} ")
except:
logger.debug(f"{t_str_txt} has no negative prompt, using default of {t_neg_txt}")
return t_str_txt,t_neg_txt
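
Traced from the code above, the three parsing cases return the following (assuming a default negative prompt of "lowres"; note the raw slices keep their surrounding spaces):

```
import logging
from bot_slack_helper import get_prompts

logger = logging.getLogger("sd-bot")

get_prompts(logger, "@bot soup -- chicken", "lowres")  # -> ("soup ", " chicken")
get_prompts(logger, "@bot soup -= chicken", "lowres")  # -> ("soup ", "lowres, chicken")
get_prompts(logger, "@bot soup", "lowres")             # -> ("soup", "lowres")
```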
