-
Notifications
You must be signed in to change notification settings - Fork 61
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Diffusers] Add text-guided image to image #223
Changes from all commits
009ee02
c1e8bca
13656ec
2523d3a
b921893
8228bab
14404b8
07c9495
1692bc1
b996948
d3ad2b2
a7c64ee
6a764c2
de62de4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from typing import TYPE_CHECKING, Optional | ||
|
||
from app.pipelines import Pipeline | ||
|
||
|
||
if TYPE_CHECKING: | ||
from PIL import Image | ||
|
||
|
||
class ImageToImagePipeline(Pipeline):
    """Template for a text-guided image-to-image pipeline.

    Concrete implementations must load the model in ``__init__`` and run
    inference in ``__call__``.
    """

    def __init__(self, model_id: str):
        # IMPLEMENT_THIS
        # Preload all the elements you are going to need for inference.
        # For instance your model, processors, tokenizer that might be needed.
        # This function is only called once, so do all the heavy processing I/O here
        raise NotImplementedError(
            "Please implement ImageToImagePipeline.__init__ function"
        )

    # NOTE: the parameter annotation must be the string "Image.Image" —
    # PIL is imported only under TYPE_CHECKING, so an unquoted Image.Image
    # raises NameError when this module is imported at runtime.
    def __call__(self, image: "Image.Image", inputs: Optional[str] = "") -> "Image.Image":
        """
        Args:
            image (:obj:`PIL.Image.Image`):
                a condition image
            inputs (:obj:`str`, *optional*):
                a string containing some text
        Return:
            A :obj:`PIL.Image` with the raw image representation as PIL.
        """
        # IMPLEMENT_THIS
        raise NotImplementedError(
            "Please implement ImageToImagePipeline.__call__ function"
        )
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import os | ||
from io import BytesIO | ||
from unittest import TestCase, skipIf | ||
|
||
import PIL | ||
from app.main import ALLOWED_TASKS | ||
from starlette.testclient import TestClient | ||
from tests.test_api import TESTABLE_MODELS | ||
|
||
|
||
@skipIf(
    "image-to-image" not in ALLOWED_TASKS,
    "image-to-image not implemented",
)
class ImageToImageTestCase(TestCase):
    """API-level tests for the image-to-image task."""

    @classmethod
    def setUpClass(cls):
        # Drop any pipeline cached by a previously-run test module so the
        # app rebuilds it for the model/task configured in setUp().
        from app.main import get_pipeline

        get_pipeline.cache_clear()

    def setUp(self):
        # Remember the current environment so tearDown() can restore it.
        self.old_model_id = os.getenv("MODEL_ID")
        self.old_task = os.getenv("TASK")
        os.environ["MODEL_ID"] = TESTABLE_MODELS["image-to-image"]
        os.environ["TASK"] = "image-to-image"
        from app.main import app

        self.app = app

    def tearDown(self):
        # Restore (or remove) the environment variables touched in setUp().
        for var, old in (("MODEL_ID", self.old_model_id), ("TASK", self.old_task)):
            if old is None:
                del os.environ[var]
            else:
                os.environ[var] = old

    def test_simple(self):
        prompt = "soap bubble"
        source = PIL.Image.new("RGB", (64, 64))

        with TestClient(self.app) as client:
            response = client.post("/", json={"inputs": (source, prompt)})

        self.assertEqual(response.status_code, 200)

        result = PIL.Image.open(BytesIO(response.content))
        self.assertTrue(isinstance(result, PIL.Image.Image))

    def test_malformed_input(self):
        # An invalid UTF-8 body must be rejected with a 400 and a decode error.
        with TestClient(self.app) as client:
            response = client.post("/", data=b"\xc3\x28")

        self.assertEqual(response.status_code, 400)
        self.assertEqual(
            response.content,
            b'{"error":"\'utf-8\' codec can\'t decode byte 0xc3 in position 0: invalid continuation byte"}',
        )
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from app.pipelines.base import Pipeline, PipelineException # isort:skip | ||
|
||
from app.pipelines.image_to_image import ImageToImagePipeline | ||
from app.pipelines.text_to_image import TextToImagePipeline |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
import json | ||
osanseviero marked this conversation as resolved.
Show resolved
Hide resolved
|
||
import os | ||
|
||
import torch | ||
from app.pipelines import Pipeline | ||
from diffusers import ( | ||
AltDiffusionImg2ImgPipeline, | ||
ControlNetModel, | ||
DiffusionPipeline, | ||
DPMSolverMultistepScheduler, | ||
StableDiffusionControlNetPipeline, | ||
StableDiffusionImg2ImgPipeline, | ||
) | ||
from huggingface_hub import hf_hub_download, model_info | ||
from PIL import Image | ||
|
||
|
||
class ImageToImagePipeline(Pipeline):
    """Text-guided image-to-image pipeline backed by `diffusers`.

    Handles plain img2img checkpoints (Stable Diffusion / AltDiffusion, via
    ``DiffusionPipeline`` auto-dispatch) as well as ControlNet checkpoints,
    which are attached to the base model declared in their model card.
    """

    def __init__(self, model_id: str):
        model_data = model_info(model_id, token=os.getenv("HF_API_TOKEN"))

        # Internal test checkpoints ship without a safety checker.
        kwargs = (
            {"safety_checker": None}
            if model_id.startswith("hf-internal-testing/")
            else {}
        )
        if torch.cuda.is_available():
            kwargs["torch_dtype"] = torch.float16

        # A top-level config.json whose _class_name is "ControlNetModel"
        # marks a bare ControlNet checkpoint rather than a full pipeline.
        # BUG FIX: default to False — without this, repos that have no
        # config.json left `is_controlnet` unbound and raised NameError below.
        is_controlnet = False
        has_config = any(
            file.rfilename == "config.json" for file in model_data.siblings
        )
        if has_config:
            config_file = hf_hub_download(
                model_id, "config.json", token=os.getenv("HF_API_TOKEN")
            )
            with open(config_file, "r") as f:
                config_dict = json.load(f)
            is_controlnet = config_dict.get("_class_name", None) == "ControlNetModel"

        if is_controlnet:
            # Load the ControlNet weights and attach them to the base model
            # named in the checkpoint's model card.
            model_to_load = model_data.cardData["base_model"]
            controlnet = ControlNetModel.from_pretrained(
                model_id, use_auth_token=os.getenv("HF_API_TOKEN"), **kwargs
            )

            self.ldm = StableDiffusionControlNetPipeline.from_pretrained(
                model_to_load,
                controlnet=controlnet,
                use_auth_token=os.getenv("HF_API_TOKEN"),
                **kwargs,
            )
        else:
            # DiffusionPipeline dispatches to the concrete pipeline class
            # recorded in the checkpoint (SD img2img, AltDiffusion, ...).
            self.ldm = DiffusionPipeline.from_pretrained(
                model_id, use_auth_token=os.getenv("HF_API_TOKEN"), **kwargs
            )

        if torch.cuda.is_available():
            self.ldm.to("cuda")
            self.ldm.enable_xformers_memory_efficient_attention()

        if isinstance(
            self.ldm,
            (
                StableDiffusionImg2ImgPipeline,
                AltDiffusionImg2ImgPipeline,
                StableDiffusionControlNetPipeline,
            ),
        ):
            # DPM-Solver++ converges in far fewer steps than the default
            # scheduler, which keeps API latency reasonable.
            self.ldm.scheduler = DPMSolverMultistepScheduler.from_config(
                self.ldm.scheduler.config
            )

    def __call__(self, image: Image.Image, inputs: str = "", **kwargs) -> "Image.Image":
        """
        Args:
            image (:obj:`PIL.Image.Image`):
                a condition image
            inputs (:obj:`str`, *optional*):
                a string containing some text
        Return:
            A :obj:`PIL.Image.Image` with the raw image representation as PIL.
        """
        if isinstance(
            self.ldm,
            (
                StableDiffusionImg2ImgPipeline,
                AltDiffusionImg2ImgPipeline,
                StableDiffusionControlNetPipeline,
            ),
        ):
            # Short run is enough for DPM-Solver++ (set in __init__); only
            # applied when the caller did not pick a step count themselves.
            kwargs.setdefault("num_inference_steps", 25)
        # The two original branches issued this identical call; collapsed.
        images = self.ldm(inputs, image, **kwargs)["images"]
        return images[0]
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import os | ||
osanseviero marked this conversation as resolved.
Show resolved
Hide resolved
|
||
from io import BytesIO | ||
from unittest import TestCase, skipIf | ||
|
||
import PIL | ||
from app.main import ALLOWED_TASKS | ||
from parameterized import parameterized_class | ||
from starlette.testclient import TestClient | ||
from tests.test_api import TESTABLE_MODELS | ||
|
||
|
||
@skipIf(
    "image-to-image" not in ALLOWED_TASKS,
    "image-to-image not implemented",
)
@parameterized_class(
    [{"model_id": model_id} for model_id in TESTABLE_MODELS["image-to-image"]]
)
class ImageToImageTestCase(TestCase):
    """API-level tests for the image-to-image task, run once per testable model."""

    @classmethod
    def setUpClass(cls):
        # Clear the pipeline cached by any previously-run test module so the
        # app rebuilds it for this class's model/task.
        from app.main import get_pipeline

        get_pipeline.cache_clear()

    def setUp(self):
        # Save the current environment so tearDown() can put it back.
        self.old_model_id = os.getenv("MODEL_ID")
        self.old_task = os.getenv("TASK")
        os.environ["MODEL_ID"] = self.model_id
        os.environ["TASK"] = "image-to-image"
        from app.main import app

        self.app = app

    def tearDown(self):
        # Restore (or remove) the environment variables touched in setUp().
        for var, old in (("MODEL_ID", self.old_model_id), ("TASK", self.old_task)):
            if old is None:
                del os.environ[var]
            else:
                os.environ[var] = old

    def test_simple(self):
        prompt = "soap bubble"
        source = PIL.Image.new("RGB", (64, 64))

        with TestClient(self.app) as client:
            response = client.post("/", json={"inputs": (source, prompt)})

        self.assertEqual(response.status_code, 200)

        result = PIL.Image.open(BytesIO(response.content))
        self.assertTrue(isinstance(result, PIL.Image.Image))

    def test_malformed_input(self):
        # An invalid UTF-8 body must be rejected with a 400 and a decode error.
        with TestClient(self.app) as client:
            response = client.post("/", data=b"\xc3\x28")

        self.assertEqual(response.status_code, 400)
        self.assertEqual(
            response.content,
            b'{"error":"\'utf-8\' codec can\'t decode byte 0xc3 in position 0: invalid continuation byte"}',
        )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
image-to-image
hereThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They are duplicates "text-to-image" exists in there already. Cleaning up ;-)