Skip to content

Commit

Permalink
add stable diffusion integration (#1111)
Browse files Browse the repository at this point in the history
  • Loading branch information
sudoboi committed Sep 29, 2023
1 parent b12e7ac commit 551f066
Show file tree
Hide file tree
Showing 8 changed files with 642 additions and 0 deletions.
79 changes: 79 additions & 0 deletions evadb/functions/dalle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# coding=utf-8
# Copyright 2018-2023 EvaDB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pandas as pd

from evadb.catalog.catalog_type import NdArrayType
from evadb.configuration.configuration_manager import ConfigurationManager
from evadb.functions.abstract.abstract_function import AbstractFunction
from evadb.functions.decorators.decorators import forward
from evadb.functions.decorators.io_descriptors.data_types import PandasDataframe
from evadb.utils.generic_utils import try_to_import_openai


class DallEFunction(AbstractFunction):
@property
def name(self) -> str:
return "DallE"

def setup(self) -> None:
pass

@forward(
input_signatures=[
PandasDataframe(
columns=["prompt"],
column_types=[
NdArrayType.STR,
],
column_shapes=[(None,)],
)
],
output_signatures=[
PandasDataframe(
columns=["response"],
column_types=[
NdArrayType.STR,
],
column_shapes=[(1,)],
)
],
)
def forward(self, text_df):
try_to_import_openai()
import openai

# Register API key, try configuration manager first
openai.api_key = ConfigurationManager().get_value("third_party", "OPENAI_KEY")
# If not found, try OS Environment Variable
if len(openai.api_key) == 0:
openai.api_key = os.environ.get("OPENAI_KEY", "")
assert (
len(openai.api_key) != 0
), "Please set your OpenAI API key in evadb.yml file (third_party, open_api_key) or environment variable (OPENAI_KEY)"

def generate_image(text_df: PandasDataframe):
results = []
queries = text_df[text_df.columns[0]]
for query in queries:
response = openai.Image.create(prompt=query, n=1, size="1024x1024")
results.append(response["data"][0]["url"])
return results

df = pd.DataFrame({"response": generate_image(text_df=text_df)})

return df
12 changes: 12 additions & 0 deletions evadb/functions/function_bootstrap_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,18 @@
MODEL 'yolov8n.pt';
"""

stablediffusion_function_query = """CREATE FUNCTION IF NOT EXISTS StableDiffusion
IMPL '{}/functions/stable_diffusion.py';
""".format(
EvaDB_INSTALLATION_DIR
)

dalle_function_query = """CREATE FUNCTION IF NOT EXISTS StableDiffusion
IMPL '{}/functions/dalle.py';
""".format(
EvaDB_INSTALLATION_DIR
)


def init_builtin_functions(db: EvaDBDatabase, mode: str = "debug") -> None:
"""Load the built-in functions into the system during system bootstrapping.
Expand Down
88 changes: 88 additions & 0 deletions evadb/functions/stable_diffusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# coding=utf-8
# Copyright 2018-2023 EvaDB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pandas as pd

from evadb.catalog.catalog_type import NdArrayType
from evadb.functions.abstract.abstract_function import AbstractFunction
from evadb.functions.decorators.decorators import forward
from evadb.functions.decorators.io_descriptors.data_types import PandasDataframe
from evadb.utils.generic_utils import try_to_import_replicate

_VALID_STABLE_DIFFUSION_MODEL = [
"sdxl:af1a68a271597604546c09c64aabcd7782c114a63539a4a8d14d1eeda5630c33",
]


class StableDiffusion(AbstractFunction):
@property
def name(self) -> str:
return "StableDiffusion"

def setup(
self,
model="sdxl:af1a68a271597604546c09c64aabcd7782c114a63539a4a8d14d1eeda5630c33",
) -> None:
assert (
model in _VALID_STABLE_DIFFUSION_MODEL
), f"Unsupported Stable Diffusion {model}"
self.model = model

@forward(
input_signatures=[
PandasDataframe(
columns=["prompt"],
column_types=[
NdArrayType.STR,
],
column_shapes=[(None,)],
)
],
output_signatures=[
PandasDataframe(
columns=["response"],
column_types=[
NdArrayType.STR,
],
column_shapes=[(1,)],
)
],
)
def forward(self, text_df):
try_to_import_replicate()
import replicate

if os.environ.get("REPLICATE_API_TOKEN") is None:
replicate_api_key = (
"r8_Q75IAgbaHFvYVfLSMGmjQPcW5uZZoXz0jGalu" # token for testing
)
os.environ["REPLICATE_API_TOKEN"] = replicate_api_key

# @retry(tries=5, delay=20)
def generate_image(text_df: PandasDataframe):
results = []
queries = text_df[text_df.columns[0]]
for query in queries:
output = replicate.run(
"stability-ai/" + self.model, input={"prompt": query}
)
results.append(output[0])
return results

df = pd.DataFrame({"response": generate_image(text_df=text_df)})

return df
18 changes: 18 additions & 0 deletions evadb/utils/generic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,3 +596,21 @@ def string_comparison_case_insensitive(string_1, string_2) -> bool:
return False

return string_1.lower() == string_2.lower()


def try_to_import_replicate():
try:
import replicate # noqa: F401
except ImportError:
raise ValueError(
"""Could not import replicate python package.
Please install it with `pip install replicate`."""
)


def is_replicate_available():
try:
try_to_import_replicate()
return True
except ValueError:
return False
60 changes: 60 additions & 0 deletions test/integration_tests/long/functions/test_dalle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# coding=utf-8
# Copyright 2018-2023 EvaDB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest
from test.util import get_evadb_for_testing
from unittest.mock import patch

from evadb.server.command_handler import execute_query_fetch_all


class DallEFunctionTest(unittest.TestCase):
def setUp(self) -> None:
self.evadb = get_evadb_for_testing()
self.evadb.catalog().reset()
create_table_query = """CREATE TABLE IF NOT EXISTS ImageGen (
prompt TEXT(100));
"""
execute_query_fetch_all(self.evadb, create_table_query)

test_prompts = ["a surreal painting of a cat"]

for prompt in test_prompts:
insert_query = f"""INSERT INTO ImageGen (prompt) VALUES ('{prompt}')"""
execute_query_fetch_all(self.evadb, insert_query)

def tearDown(self) -> None:
execute_query_fetch_all(self.evadb, "DROP TABLE IF EXISTS ImageGen;")

@patch.dict("os.environ", {"OPENAI_KEY": "mocked_openai_key"})
@patch("openai.Image.create", return_value={"data": [{"url": "mocked_url"}]})
def test_dalle_image_generation(self, mock_openai_create):
function_name = "DallE"

execute_query_fetch_all(self.evadb, f"DROP FUNCTION IF EXISTS {function_name};")

create_function_query = f"""CREATE FUNCTION IF NOT EXISTS{function_name}
IMPL 'evadb/functions/dalle.py';
"""
execute_query_fetch_all(self.evadb, create_function_query)

gpt_query = f"SELECT {function_name}(prompt) FROM ImageGen;"
output_batch = execute_query_fetch_all(self.evadb, gpt_query)

self.assertEqual(output_batch.columns, ["dalle.response"])
mock_openai_create.assert_called_once_with(
prompt="a surreal painting of a cat", n=1, size="1024x1024"
)
61 changes: 61 additions & 0 deletions test/integration_tests/long/functions/test_selfdiffusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# coding=utf-8
# Copyright 2018-2023 EvaDB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from test.markers import stable_diffusion_skip_marker
from test.util import get_evadb_for_testing
from unittest.mock import patch

from evadb.server.command_handler import execute_query_fetch_all


class StableDiffusionTest(unittest.TestCase):
def setUp(self) -> None:
self.evadb = get_evadb_for_testing()
self.evadb.catalog().reset()
create_table_query = """CREATE TABLE IF NOT EXISTS ImageGen (
prompt TEXT(100));
"""
execute_query_fetch_all(self.evadb, create_table_query)

test_prompts = ["pink cat riding a rocket to the moon"]

for prompt in test_prompts:
insert_query = f"""INSERT INTO ImageGen (prompt) VALUES ('{prompt}')"""
execute_query_fetch_all(self.evadb, insert_query)

def tearDown(self) -> None:
execute_query_fetch_all(self.evadb, "DROP TABLE IF EXISTS ImageGen;")

@stable_diffusion_skip_marker
@patch("replicate.run", return_value=[{"response": "mocked response"}])
def test_stable_diffusion_image_generation(self, mock_replicate_run):
function_name = "StableDiffusion"

execute_query_fetch_all(self.evadb, f"DROP FUNCTION IF EXISTS {function_name};")

create_function_query = f"""CREATE FUNCTION IF NOT EXISTS{function_name}
IMPL 'evadb/functions/stable_diffusion.py';
"""
execute_query_fetch_all(self.evadb, create_function_query)

gpt_query = f"SELECT {function_name}(prompt) FROM ImageGen;"
output_batch = execute_query_fetch_all(self.evadb, gpt_query)

self.assertEqual(output_batch.columns, ["stablediffusion.response"])
mock_replicate_run.assert_called_once_with(
"stability-ai/sdxl:af1a68a271597604546c09c64aabcd7782c114a63539a4a8d14d1eeda5630c33",
input={"prompt": "pink cat riding a rocket to the moon"},
)
5 changes: 5 additions & 0 deletions test/markers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
is_ludwig_available,
is_pinecone_available,
is_qdrant_available,
is_replicate_available,
is_sklearn_available,
)

Expand Down Expand Up @@ -96,3 +97,7 @@
is_forecast_available() is False,
reason="Run only if forecasting packages available",
)

stable_diffusion_skip_marker = pytest.mark.skipif(
is_replicate_available() is False, reason="requires replicate"
)
Loading

0 comments on commit 551f066

Please sign in to comment.