Skip to content

Commit

Permalink
minor improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
WaelKarkoub committed May 4, 2024
1 parent 4afab28 commit f5b6fd8
Showing 1 changed file with 18 additions and 9 deletions.
27 changes: 18 additions & 9 deletions autogen/agentchat/contrib/capabilities/generate_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
EXAMPLE: Blue background, 3D shapes, ...
"""

VALID_DALLE_MODELS = ["dall-e-2", "dall-e-3"]


class ImageGenerator(Protocol):
"""This class defines an interface for image generators.
Expand Down Expand Up @@ -81,14 +83,17 @@ def __init__(
num_images (int): The number of images to generate.
"""
config_list = llm_config["config_list"]
_validate_dalle_model(config_list[0]["model"])
dalle_configs = _find_valid_dalle_config(config_list)
assert len(dalle_configs) > 0, "Invalid DALL-E config. Must contain a valid DALL-E model."

_validate_dalle_model(dalle_configs[0]["model"])
_validate_resolution_format(resolution)

self._model = config_list[0]["model"]
self._model = dalle_configs[0]["model"]
self._resolution = resolution
self._quality = quality
self._num_images = num_images
self._dalle_client = self._dalle_client_factory(config_list)
self._dalle_client = self._dalle_client_factory(dalle_configs[0])

def generate_image(self, prompt: str) -> Image:
response = self._dalle_client.images.generate(
Expand All @@ -109,11 +114,11 @@ def cache_key(self, prompt: str) -> str:
keys = (prompt, self._model, self._resolution, self._quality, self._num_images)
return ",".join([str(k) for k in keys])

def _dalle_client_factory(self, config_list: List[Dict]) -> Union[OpenAI, AzureOpenAI]:
if config_list[0]["api_type"] == "azure":
return AzureOpenAI(api_key=config_list[0]["api_key"])
def _dalle_client_factory(self, dalle_config: Dict) -> Union[OpenAI, AzureOpenAI]:
if dalle_config.get("api_type") == "azure":
return AzureOpenAI(api_key=dalle_config["api_key"])
else:
return OpenAI(api_key=config_list[0]["api_key"])
return OpenAI(api_key=dalle_config["api_key"])


class ImageGeneration(AgentCapability):
Expand Down Expand Up @@ -294,5 +299,9 @@ def _validate_resolution_format(resolution: str):


def _validate_dalle_model(model: str):
if model not in ["dall-e-3", "dall-e-2"]:
raise ValueError(f"Invalid DALL-E model: {model}. Must be 'dall-e-3' or 'dall-e-2'")
if model not in VALID_DALLE_MODELS:
raise ValueError(f"Invalid DALL-E model: {model}. Must be in {VALID_DALLE_MODELS}")


def _find_valid_dalle_config(config_list: List[Dict]) -> List[Dict]:
return list(filter(lambda config: config["model"] in VALID_DALLE_MODELS, config_list))

0 comments on commit f5b6fd8

Please sign in to comment.