Commit

dalle3 production grade ready
kyegomez committed Nov 10, 2023
1 parent d26531f commit 41e5f17
Showing 18 changed files with 204 additions and 9 deletions.
(12 files in this commit could not be displayed.)
24 changes: 24 additions & 0 deletions playground/models/dalle3_concurrent.py
@@ -0,0 +1,24 @@
"""
User task ->> GPT4 for prompt enrichment ->> Dalle3V for image generation
->> GPT4Vision for image captioning ->> Dalle3 better image
"""
from swarms.models.dalle3 import Dalle3
import os

api_key = os.environ["OPENAI_API_KEY"]

dalle3 = Dalle3(openai_api_key=api_key, n=1)

# task = "Swarm of robots working super industrial ambience concept art"

# image_url = dalle3(task)

tasks = ["A painting of a dog", "A painting of a cat"]
results = dalle3.process_batch_concurrently(tasks)

# print(results)



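The module docstring above describes a longer loop (GPT-4 prompt enrichment before Dalle3, then GPT4Vision captioning). A minimal sketch of the enrichment step is below; the OpenAIChat import path and constructor arguments are assumptions about the surrounding swarms API, not part of this commit.

# Hypothetical sketch of the enrichment step described in the docstring.
# Assumes `OpenAIChat` is exported by swarms.models and accepts `openai_api_key`;
# adjust the import and constructor to match the installed swarms version.
import os

from swarms.models import OpenAIChat
from swarms.models.dalle3 import Dalle3

api_key = os.environ["OPENAI_API_KEY"]

llm = OpenAIChat(openai_api_key=api_key)
dalle3 = Dalle3(openai_api_key=api_key, n=1)

task = "Swarm of robots working, super industrial ambience, concept art"
enriched = llm(f"Rewrite this as a rich, detailed image prompt: {task}")
print(dalle3(enriched))  # local path of the saved image, e.g. images/<uuid>.png
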
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "swarms"
version = "2.1.6"
version = "2.1.7"
description = "Swarms - Pytorch"
license = "MIT"
authors = ["Kye Gomez <kye@apac.ai>"]
@@ -35,6 +35,7 @@ langchain-experimental = "*"
playwright = "*"
duckduckgo-search = "*"
faiss-cpu = "*"
backoff = "*"
datasets = "*"
diffusers = "*"
accelerate = "*"
1 change: 1 addition & 0 deletions requirements.txt
@@ -35,6 +35,7 @@ tabulate
colored
griptape
addict
backoff
ratelimit
albumentations
basicsr
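
backoff is the new dependency behind the retry decorators added to swarms/models/dalle3.py below. A standalone sketch of the pattern follows; the function name and wrapped OpenAI call are illustrative only, not part of the commit.

# Illustrative use of the backoff library: retry with exponential backoff and
# give up after 60 seconds, mirroring max_time_seconds in the Dalle3 class.
import backoff
import openai

client = openai.OpenAI()  # reads OPENAI_API_KEY from the environment


@backoff.on_exception(backoff.expo, openai.OpenAIError, max_time=60)
def generate_image(prompt: str) -> str:
    # Any OpenAIError raised here triggers an exponentially spaced retry.
    response = client.images.generate(model="dall-e-3", prompt=prompt, n=1)
    return response.data[0].url
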
2 changes: 2 additions & 0 deletions swarms/models/autotemp.py
@@ -2,6 +2,7 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
from swarms.models.auto_temp import OpenAIChat


class AutoTempAgent:
"""
AutoTemp is a tool for automatically selecting the best temperature setting for a given task.
@@ -31,6 +32,7 @@ class AutoTempAgent:
Generate a 10,000 word blog on mental clarity and the benefits of meditation.
"""

def __init__(
self,
temperature: float = 0.5,
178 changes: 171 additions & 7 deletions swarms/models/dalle3.py
@@ -1,9 +1,15 @@
import concurrent.futures
import logging
import os
import uuid
from dataclasses import dataclass
from io import BytesIO
from typing import List, Optional

import backoff
import openai
import requests
from cachetools import TTLCache
from dotenv import load_dotenv
from openai import OpenAI
from PIL import Image
@@ -19,6 +25,17 @@
logger = logging.getLogger(__name__)



def handle_errors(function):
def wrapper(*args, **kwargs):
try:
return function(*args, **kwargs)
except Exception as error:
logger.error(error)
raise
return wrapper


@dataclass
class Dalle3:
"""
@@ -49,12 +66,26 @@ class Dalle3:
size: str = "1024x1024"
max_retries: int = 3
quality: str = "standard"
api_key: str = None
n: int = 4
openai_api_key: str = None
n: int = 1
save_path: str = "images"
max_time_seconds: int = 60
save_folder: str = "images"
image_format: str = "png"
client = OpenAI(
api_key=api_key,
max_retries=max_retries,
api_key=openai_api_key,
)
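    # Generated image paths are cached per task for up to an hour (maxsize=100, ttl=3600 s).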
cache = TTLCache(maxsize=100, ttl=3600)
dashboard: bool = False

def __post_init__(self):
"""Post init method"""
if self.openai_api_key is None:
raise ValueError("Please provide an openai api key")
if self.img is not None:
self.img = self.convert_to_bytesio(self.img)

os.makedirs(self.save_path, exist_ok=True)

class Config:
"""Config class for the Dalle3 model"""
@@ -84,8 +115,8 @@ def convert_to_bytesio(self, img: str, format: str = "PNG"):
img.save(byte_stream, format=format)
byte_array = byte_stream.getvalue()
return byte_array

# @lru_cache(maxsize=32)
@backoff.on_exception(backoff.expo, Exception, max_time=max_time_seconds)
def __call__(self, task: str):
"""
Text to image conversion using the Dalle3 API
@@ -108,6 +139,10 @@ def __call__(self, task: str):
>>> print(image_url)
https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png
"""
if self.dashboard:
self.print_dashboard()
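        # Serve repeated tasks from the in-memory TTL cache instead of calling the API again.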
if task in self.cache:
return self.cache[task]
try:
# Making a call to the the Dalle3 API
response = self.client.images.generate(
@@ -119,7 +154,16 @@
)
# Extracting the image url from the response
img = response.data[0].url
return img

filename = f"{self._generate_uuid()}.{self.image_format}"

# Download and save the image
self._download_image(img, filename)

img_path = os.path.join(self.save_path, filename)
self.cache[task] = img_path

return img_path
except openai.OpenAIError as error:
# Handling exceptions and printing the errors details
print(
@@ -133,6 +177,29 @@
)
raise error

def _generate_image_name(self, task: str):
"""Generate a sanitized file name based on the task"""
sanitized_task = "".join(
char for char in task if char.isalnum() or char in " _ -"
).rstrip()
return f"{sanitized_task}.{self.image_format}"

def _download_image(self, img_url: str, filename: str):
"""
Download the image from the given URL and save it to a specified filename within self.save_path.
Args:
img_url (str): URL of the image to download.
filename (str): Filename to save the image.
"""
full_path = os.path.join(self.save_path, filename)
response = requests.get(img_url)
if response.status_code == 200:
with open(full_path, 'wb') as file:
file.write(response.content)
else:
raise ValueError(f"Failed to download image from {img_url}")

def create_variations(self, img: str):
"""
Create variations of an image using the Dalle3 API
@@ -176,3 +243,100 @@ def create_variations(self, img: str):
print(colored(f"Error running Dalle3: {error.http_status}", "red"))
print(colored(f"Error running Dalle3: {error.error}", "red"))
raise error

def print_dashboard(
self
):
"""Print the Dalle3 dashboard"""
print(
colored(
(
f"""Dalle3 Dashboard:
--------------------
Model: {self.model}
Image: {self.img}
Size: {self.size}
Max Retries: {self.max_retries}
Quality: {self.quality}
N: {self.n}
Save Path: {self.save_path}
                    Max Time Seconds: {self.max_time_seconds}
Save Folder: {self.save_folder}
Image Format: {self.image_format}
--------------------
"""
),
"green",
)
)

def process_batch_concurrently(
self,
tasks: List[str],
max_workers: int = 5
):
"""
Process a batch of tasks concurrently
Args:
tasks (List[str]): A list of tasks to be processed
max_workers (int): The maximum number of workers to use for the concurrent processing
Returns:
--------
            results (List[str]): A list of saved image file paths generated by the Dalle3 API
Example:
--------
>>> dalle3 = Dalle3()
>>> tasks = ["A painting of a dog", "A painting of a cat"]
>>> results = dalle3.process_batch_concurrently(tasks)
>>> print(results)
        ['images/<uuid>.png', 'images/<uuid>.png']
"""
with concurrent.futures.ThreadPoolExecutor(
max_workers=max_workers
) as executor:
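            # Fan tasks out across the thread pool; each worker runs the full __call__ pipeline.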
future_to_task = {executor.submit(self, task): task for task in tasks}
results = []
for future in concurrent.futures.as_completed(future_to_task):
task = future_to_task[future]
try:
img = future.result()
results.append(img)

print(f"Task {task} completed: {img}")
except Exception as error:
                    print(
                        colored(
                            (
                                f"Error running Dalle3 for task {task!r}: {error}. Check"
                                " your API key or try again."
                            ),
                            "red",
                        )
                    )
                    raise error

        return results

    def _generate_uuid(self):
"""Generate a uuid"""
return str(uuid.uuid4())

def __repr__(self):
"""Repr method for the Dalle3 class"""
return f"Dalle3(image_url={self.image_url})"

def __str__(self):
"""Str method for the Dalle3 class"""
return f"Dalle3(image_url={self.image_url})"

@backoff.on_exception(backoff.expo, Exception, max_tries=max_retries)
def rate_limited_call(self, task: str):
"""Rate limited call to the Dalle3 API"""
return self.__call__(task)
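
Taken together, the new attributes and methods give the Dalle3 class roughly the usage surface below. A quick sketch, assuming OPENAI_API_KEY is exported in the environment and the module is used as committed above:

import os

from swarms.models.dalle3 import Dalle3

dalle3 = Dalle3(
    openai_api_key=os.environ["OPENAI_API_KEY"],
    n=1,
    save_path="images",  # generated files are saved here as <uuid>.png
    dashboard=True,  # print the configuration dashboard before each call
)

# Single task: the image is downloaded, saved locally, and the path is cached
# for an hour, so repeating the same task returns the cached path.
path = dalle3("A watercolor painting of a lighthouse at dawn")
print(path)

# Batch: tasks are fanned out across a thread pool and the saved paths returned.
paths = dalle3.process_batch_concurrently(
    ["A painting of a dog", "A painting of a cat"], max_workers=2
)
print(paths)
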
5 changes: 4 additions & 1 deletion tests/models/auto_temp.py
@@ -11,6 +11,7 @@

load_dotenv()


@pytest.fixture
def auto_temp_agent():
return AutoTempAgent(api_key=api_key)
@@ -47,7 +48,9 @@ def test_run_no_scores(auto_temp_agent):
task = "Invalid task."
temperature_string = "0.4,0.6,0.8,1.0,1.2,1.4"
with ThreadPoolExecutor(max_workers=auto_temp_agent.max_workers) as executor:
with patch.object(executor, "submit", side_effect=[None, None, None, None, None, None]):
with patch.object(
executor, "submit", side_effect=[None, None, None, None, None, None]
):
result = auto_temp_agent.run(task, temperature_string)
assert result == "No valid outputs generated."

