diff --git a/.full-env.example b/.full-env.example index 45f2d9e..50ae9fe 100644 --- a/.full-env.example +++ b/.full-env.example @@ -18,5 +18,7 @@ SYSTEM_PROMPT="You are ChatGPT, a large language model trained by OpenAI. Respo TEMPERATURE=0.8 LC_ADMIN="@admin:xxxxxx.xxx,@admin2:xxxxxx.xxx" IMAGE_GENERATION_ENDPOINT="http://127.0.0.1:7860/sdapi/v1/txt2img" -IMAGE_GENERATION_BACKEND="sdwui" # openai or sdwui +IMAGE_GENERATION_BACKEND="sdwui" # openai or sdwui or localai +IMAGE_GENERATION_SIZE="512x512" +IMAGE_FORMAT="webp" TIMEOUT=120.0 diff --git a/full-config.json.example b/full-config.json.example index 4d7d708..77e6213 100644 --- a/full-config.json.example +++ b/full-config.json.example @@ -19,6 +19,8 @@ "system_prompt": "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally", "lc_admin": ["@admin:xxxxx.org"], "image_generation_endpoint": "http://localai:8080/v1/images/generations", - "image_generation_backend": "openai", + "image_generation_backend": "localai", + "image_generation_size": "512x512", + "image_format": "webp", "timeout": 120.0 } diff --git a/src/bot.py b/src/bot.py index 7ee7590..8218535 100644 --- a/src/bot.py +++ b/src/bot.py @@ -68,22 +68,31 @@ def __init__( lc_admin: Optional[list[str]] = None, image_generation_endpoint: Optional[str] = None, image_generation_backend: Optional[str] = None, + image_generation_size: Optional[str] = None, + image_format: Optional[str] = None, timeout: Union[float, None] = None, ): if homeserver is None or user_id is None or device_id is None: - logger.warning("homeserver && user_id && device_id is required") + logger.error("homeserver && user_id && device_id is required") sys.exit(1) if password is None and access_token is None: - logger.warning("password is required") + logger.error("password is required") sys.exit(1) if image_generation_endpoint and image_generation_backend not in [ "openai", "sdwui", + "localai", None, ]: - logger.warning("image_generation_backend must be openai or sdwui") + logger.error("image_generation_backend must be openai or sdwui or localai") + sys.exit(1) + + if image_format not in ["jpeg", "webp", "png", None]: + logger.error( + "image_format should be jpeg or webp or png, leave blank for jpeg" + ) sys.exit(1) self.homeserver: str = homeserver @@ -115,6 +124,20 @@ def __init__( self.image_generation_endpoint: str = image_generation_endpoint self.image_generation_backend: str = image_generation_backend + if image_format: + self.image_format: str = image_format + else: + self.image_format = "jpeg" + + if image_generation_size is None: + self.image_generation_size = "512x512" + self.image_generation_width = 512 + self.image_generation_height = 512 + else: + self.image_generation_size = image_generation_size + self.image_generation_width = self.image_generation_size.split("x")[0] + self.image_generation_height = self.image_generation_size.split("x")[1] + self.timeout: float = timeout or 120.0 self.base_path = Path(os.path.dirname(__file__)).parent @@ -1333,20 +1356,19 @@ async def pic(self, room_id, prompt, replay_to_event_id, sender_id, user_message if self.image_generation_endpoint is not None: await self.client.room_typing(room_id, timeout=int(self.timeout) * 1000) # generate image - b64_datas = await imagegen.get_images( + image_path_list = await imagegen.get_images( self.httpx_client, self.image_generation_endpoint, prompt, self.image_generation_backend, timeount=self.timeout, api_key=self.openai_api_key, + output_path=self.base_path / "images", n=1, - size="256x256", - ) - image_path_list = await asyncio.to_thread( - imagegen.save_images, - b64_datas, - self.base_path / "images", + size=self.image_generation_size, + width=self.image_generation_width, + height=self.image_generation_height, + image_format=self.image_format, ) # send image for image_path in image_path_list: diff --git a/src/imagegen.py b/src/imagegen.py index 2214eac..8f059d9 100644 --- a/src/imagegen.py +++ b/src/imagegen.py @@ -7,9 +7,14 @@ async def get_images( - aclient: httpx.AsyncClient, url: str, prompt: str, backend_type: str, **kwargs + aclient: httpx.AsyncClient, + url: str, + prompt: str, + backend_type: str, + output_path: str, + **kwargs, ) -> list[str]: - timeout = kwargs.get("timeout", 120.0) + timeout = kwargs.get("timeout", 180.0) if backend_type == "openai": resp = await aclient.post( url, @@ -20,7 +25,7 @@ async def get_images( json={ "prompt": prompt, "n": kwargs.get("n", 1), - "size": kwargs.get("size", "256x256"), + "size": kwargs.get("size", "512x512"), "response_format": "b64_json", }, timeout=timeout, @@ -29,7 +34,7 @@ async def get_images( b64_datas = [] for data in resp.json()["data"]: b64_datas.append(data["b64_json"]) - return b64_datas + return save_images_b64(b64_datas, output_path, **kwargs) else: raise Exception( f"{resp.status_code} {resp.reason_phrase} {resp.text}", @@ -45,25 +50,56 @@ async def get_images( "sampler_name": kwargs.get("sampler_name", "Euler a"), "batch_size": kwargs.get("n", 1), "steps": kwargs.get("steps", 20), - "width": 256 if "256" in kwargs.get("size") else 512, - "height": 256 if "256" in kwargs.get("size") else 512, + "width": kwargs.get("width", 512), + "height": kwargs.get("height", 512), }, timeout=timeout, ) if resp.status_code == 200: b64_datas = resp.json()["images"] - return b64_datas + return save_images_b64(b64_datas, output_path, **kwargs) else: raise Exception( f"{resp.status_code} {resp.reason_phrase} {resp.text}", ) + elif backend_type == "localai": + resp = await aclient.post( + url, + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {kwargs.get('api_key')}", + }, + json={ + "prompt": prompt, + "size": kwargs.get("size", "512x512"), + }, + timeout=timeout, + ) + if resp.status_code == 200: + image_url = resp.json()["data"][0]["url"] + return await save_image_url(image_url, aclient, output_path, **kwargs) -def save_images(b64_datas: list[str], path: Path, **kwargs) -> list[str]: - images = [] +def save_images_b64(b64_datas: list[str], path: Path, **kwargs) -> list[str]: + images_path_list = [] for b64_data in b64_datas: - image_path = path / (str(uuid.uuid4()) + ".jpeg") + image_path = path / ( + str(uuid.uuid4()) + "." + kwargs.get("image_format", "jpeg") + ) img = Image.open(io.BytesIO(base64.decodebytes(bytes(b64_data, "utf-8")))) img.save(image_path) - images.append(image_path) - return images + images_path_list.append(image_path) + return images_path_list + + +async def save_image_url( + url: str, aclient: httpx.AsyncClient, path: Path, **kwargs +) -> list[str]: + images_path_list = [] + r = await aclient.get(url) + image_path = path / (str(uuid.uuid4()) + "." + kwargs.get("image_format", "jpeg")) + if r.status_code == 200: + img = Image.open(io.BytesIO(r.content)) + img.save(image_path) + images_path_list.append(image_path) + return images_path_list diff --git a/src/main.py b/src/main.py index 07e7e9c..48bbceb 100644 --- a/src/main.py +++ b/src/main.py @@ -44,6 +44,8 @@ async def main(): lc_admin=config.get("lc_admin"), image_generation_endpoint=config.get("image_generation_endpoint"), image_generation_backend=config.get("image_generation_backend"), + image_generation_size=config.get("image_generation_size"), + image_format=config.get("image_format"), timeout=config.get("timeout"), ) if ( @@ -75,6 +77,8 @@ async def main(): lc_admin=os.environ.get("LC_ADMIN"), image_generation_endpoint=os.environ.get("IMAGE_GENERATION_ENDPOINT"), image_generation_backend=os.environ.get("IMAGE_GENERATION_BACKEND"), + image_generation_size=os.environ.get("IMAGE_GENERATION_SIZE"), + image_format=os.environ.get("IMAGE_FORMAT"), timeout=float(os.environ.get("TIMEOUT", 120.0)), ) if (