[Diffusers] Add text-guided image to image #223

Merged · 14 commits · Apr 19, 2023
2 changes: 1 addition & 1 deletion api_inference_community/routes.py
@@ -117,7 +117,7 @@ def emit(self, record):
}
)
return JSONResponse(items, headers=headers, status_code=status_code)
elif task == "text-to-image":
elif task in ["text-to-image", "image-to-image"]:
buf = io.BytesIO()
outputs.save(buf, format="JPEG")
buf.seek(0)
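For reference, the JPEG serialization this route change now applies to both "text-to-image" and "image-to-image" outputs boils down to the following standalone sketch (the buffer's bytes become the HTTP response body):

import io

from PIL import Image

# Stand-in for a pipeline output image; the route receives the real one.
output = Image.new("RGB", (64, 64))

# Render the PIL image into an in-memory JPEG buffer and read its bytes,
# which are then returned with an image/jpeg content type.
buf = io.BytesIO()
output.save(buf, format="JPEG")
buf.seek(0)
jpeg_bytes = buf.read()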
33 changes: 33 additions & 0 deletions docker_images/common/app/pipelines/image_to_image.py
@@ -0,0 +1,33 @@
from typing import TYPE_CHECKING, Optional

from app.pipelines import Pipeline


if TYPE_CHECKING:
from PIL import Image


class ImageToImagePipeline(Pipeline):
def __init__(self, model_id: str):
# IMPLEMENT_THIS
# Preload all the elements you are going to need for inference.
# For instance your model, processors, tokenizer that might be needed.
# This function is only called once, so do all the heavy processing and I/O here
raise NotImplementedError(
"Please implement ImageToImagePipeline.__init__ function"
)

def __call__(self, image: "Image.Image", inputs: Optional[str] = "") -> "Image.Image":
"""
Args:
image (:obj:`PIL.Image.Image`):
a condition image
inputs (:obj:`str`, *optional*):
a string containing some text
Return:
A :obj:`PIL.Image.Image` with the raw image representation as PIL.
"""
# IMPLEMENT_THIS
raise NotImplementedError(
"Please implement ImageToImagePipeline.__call__ function"
)
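As an illustration of the contract this template defines, a trivial and purely hypothetical implementation could look like the identity pipeline below, which ignores the prompt and returns the conditioning image unchanged; real frameworks such as the diffusers image-to-image pipeline further down replace this with actual model loading and inference:

from typing import Optional

from PIL import Image

from app.pipelines import Pipeline


class DummyImageToImagePipeline(Pipeline):
    def __init__(self, model_id: str):
        # Nothing heavy to preload in this dummy example; a real pipeline
        # would download and instantiate the model for `model_id` here.
        self.model_id = model_id

    def __call__(self, image: Image.Image, inputs: Optional[str] = "") -> Image.Image:
        # Identity transform: ignore the text prompt and echo the image back.
        return image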
5 changes: 0 additions & 5 deletions docker_images/common/tests/test_api.py
@@ -26,22 +26,17 @@
"speech-segmentation",
"tabular-classification",
"tabular-regression",
"text-classification",
"text-to-image",
"text-to-speech",
"token-classification",
"conversational",
"feature-extraction",
"question-answering",
"sentence-similarity",
"fill-mask",
"table-question-answering",
"summarization",
"text2text-generation",
"text-classification",
"text-to-image",
"text-to-speech",
"token-classification",
Member:

  • The removed ones should be reverted
  • We should also add image-to-image here

Contributor Author:

They are duplicates; "text-to-image" already exists in there. Cleaning up ;-)

"zero-shot-classification",
}

70 changes: 70 additions & 0 deletions docker_images/common/tests/test_api_image_to_image.py
@@ -0,0 +1,70 @@
import os
from io import BytesIO
from unittest import TestCase, skipIf

import PIL
from app.main import ALLOWED_TASKS
from starlette.testclient import TestClient
from tests.test_api import TESTABLE_MODELS


@skipIf(
"image-to-image" not in ALLOWED_TASKS,
"image-to-image not implemented",
)
class ImageToImageTestCase(TestCase):
def setUp(self):
model_id = TESTABLE_MODELS["image-to-image"]
self.old_model_id = os.getenv("MODEL_ID")
self.old_task = os.getenv("TASK")
os.environ["MODEL_ID"] = model_id
os.environ["TASK"] = "image-to-image"
from app.main import app

self.app = app

@classmethod
def setUpClass(cls):
from app.main import get_pipeline

get_pipeline.cache_clear()

def tearDown(self):
if self.old_model_id is not None:
os.environ["MODEL_ID"] = self.old_model_id
else:
del os.environ["MODEL_ID"]
if self.old_task is not None:
os.environ["TASK"] = self.old_task
else:
del os.environ["TASK"]

def test_simple(self):
text = "soap bubble"
image = PIL.Image.new("RGB", (64, 64))

inputs = (image, text)

with TestClient(self.app) as client:
response = client.post("/", json={"inputs": inputs})

self.assertEqual(
response.status_code,
200,
)

image = PIL.Image.open(BytesIO(response.content))
self.assertTrue(isinstance(image, PIL.Image.Image))

def test_malformed_input(self):
with TestClient(self.app) as client:
response = client.post("/", data=b"\xc3\x28")

self.assertEqual(
response.status_code,
400,
)
self.assertEqual(
response.content,
b'{"error":"\'utf-8\' codec can\'t decode byte 0xc3 in position 0: invalid continuation byte"}',
)
7 changes: 5 additions & 2 deletions docker_images/diffusers/app/main.py
@@ -4,7 +4,7 @@
from typing import Dict, Type

from api_inference_community.routes import pipeline_route, status_ok
from app.pipelines import Pipeline, TextToImagePipeline
from app.pipelines import ImageToImagePipeline, Pipeline, TextToImagePipeline
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.gzip import GZipMiddleware
@@ -32,7 +32,10 @@
# ALLOWED_TASKS = {"automatic-speech-recognition": AutomaticSpeechRecognitionPipeline}
# You can check the requirements and expectations of each pipelines in their respective
# directories. Implement directly within the directories.
ALLOWED_TASKS: Dict[str, Type[Pipeline]] = {"text-to-image": TextToImagePipeline}
ALLOWED_TASKS: Dict[str, Type[Pipeline]] = {
"text-to-image": TextToImagePipeline,
"image-to-image": ImageToImagePipeline,
}


@functools.lru_cache()
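For context, the cached get_pipeline factory (its body is not part of this hunk) resolves the TASK and MODEL_ID environment variables against ALLOWED_TASKS. A minimal sketch under that assumption, not the exact implementation in the repo:

import functools
import os


@functools.lru_cache()
def get_pipeline():
    # Sketch only: look up the pipeline class registered for the configured
    # task and instantiate it for the configured model.
    task = os.environ["TASK"]
    model_id = os.environ["MODEL_ID"]
    if task not in ALLOWED_TASKS:
        raise EnvironmentError(f"{task} is not a valid pipeline for model: {model_id}")
    return ALLOWED_TASKS[task](model_id)

Here ALLOWED_TASKS refers to the module-level dict shown in the hunk above.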
1 change: 1 addition & 0 deletions docker_images/diffusers/app/pipelines/__init__.py
@@ -1,3 +1,4 @@
from app.pipelines.base import Pipeline, PipelineException # isort:skip

from app.pipelines.image_to_image import ImageToImagePipeline
from app.pipelines.text_to_image import TextToImagePipeline
103 changes: 103 additions & 0 deletions docker_images/diffusers/app/pipelines/image_to_image.py
@@ -0,0 +1,103 @@
import json
import os

import torch
from app.pipelines import Pipeline
from diffusers import (
AltDiffusionImg2ImgPipeline,
ControlNetModel,
DiffusionPipeline,
DPMSolverMultistepScheduler,
StableDiffusionControlNetPipeline,
StableDiffusionImg2ImgPipeline,
)
from huggingface_hub import hf_hub_download, model_info
from PIL import Image


class ImageToImagePipeline(Pipeline):
def __init__(self, model_id: str):
model_data = model_info(model_id, token=os.getenv("HF_API_TOKEN"))

kwargs = (
{"safety_checker": None}
if model_id.startswith("hf-internal-testing/")
else {}
)
if torch.cuda.is_available():
kwargs["torch_dtype"] = torch.float16

has_config = any(
file.rfilename == "config.json" for file in model_data.siblings
)
if has_config:
config_file = hf_hub_download(
model_id, "config.json", token=os.getenv("HF_API_TOKEN")
)
with open(config_file, "r") as f:
config_dict = json.load(f)

is_controlnet = config_dict.get("_class_name", None) == "ControlNetModel"
Member:

_class_name could be added to the ModelInfo object so you don't need to load the whole config here. Internal PR for that https://github.com/huggingface/moon-landing/pull/6067

Contributor Author:

Ok waiting for #6067


if is_controlnet:
model_to_load = model_data.cardData["base_model"]
controlnet = ControlNetModel.from_pretrained(
model_id, use_auth_token=os.getenv("HF_API_TOKEN"), **kwargs
)

self.ldm = StableDiffusionControlNetPipeline.from_pretrained(
model_to_load,
controlnet=controlnet,
use_auth_token=os.getenv("HF_API_TOKEN"),
**kwargs,
)
else:
self.ldm = DiffusionPipeline.from_pretrained(
Contributor Author (@patrickvonplaten, Apr 19, 2023):

This can load StableDiffusionInstructPix2PixPipeline, StableDiffusionImg2ImgPipeline, or AltDiffusionImg2ImgPipeline, depending on how the model_index.json is defined.

model_id, use_auth_token=os.getenv("HF_API_TOKEN"), **kwargs
)

if torch.cuda.is_available():
self.ldm.to("cuda")
self.ldm.enable_xformers_memory_efficient_attention()

if isinstance(
self.ldm,
(
StableDiffusionImg2ImgPipeline,
AltDiffusionImg2ImgPipeline,
StableDiffusionControlNetPipeline,
),
):
self.ldm.scheduler = DPMSolverMultistepScheduler.from_config(
self.ldm.scheduler.config
)

def __call__(self, image: Image.Image, inputs: str = "", **kwargs) -> "Image.Image":
"""
Args:
inputs (:obj:`str`):
a string containing some text
image (:obj:`PIL.Image.Image`):
a condition image
Return:
A :obj:`PIL.Image.Image` with the raw image representation as PIL.
"""

if isinstance(
self.ldm,
(
StableDiffusionImg2ImgPipeline,
AltDiffusionImg2ImgPipeline,
StableDiffusionControlNetPipeline,
),
):
if "num_inference_steps" not in kwargs:
kwargs["num_inference_steps"] = 25
images = self.ldm(
inputs,
image,
**kwargs,
)["images"]
else:
images = self.ldm(inputs, image, **kwargs)["images"]
return images[0]
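As a usage sketch of the pipeline above (the model id is taken from the test file further down, and "soap bubble" mirrors the test prompt), the call path is roughly:

from PIL import Image

from app.pipelines import ImageToImagePipeline

# Load one of the small test checkpoints. Which diffusers class ends up
# behind self.ldm (Img2Img, InstructPix2Pix, AltDiffusion, ControlNet, ...)
# depends on the repo's model_index.json / config.json, as noted above.
pipe = ImageToImagePipeline("hf-internal-testing/tiny-stable-diffusion-pix2pix")

condition = Image.new("RGB", (64, 64))
result = pipe(condition, "soap bubble", num_inference_steps=10)
result.save("out.jpg")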
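Regarding the review suggestion above about exposing _class_name on ModelInfo: once that lands, the ControlNet detection would no longer need to download config.json. The following is a hypothetical sketch only, since the exact field layout on ModelInfo is an assumption here, not the current huggingface_hub API:

import os

from huggingface_hub import model_info

model_id = "hf-internal-testing/tiny-controlnet"  # example repo from the tests
model_data = model_info(model_id, token=os.getenv("HF_API_TOKEN"))

# Hypothetical: read the diffusers class name from the model metadata
# instead of downloading and parsing config.json (field names assumed).
config = getattr(model_data, "config", None) or {}
is_controlnet = config.get("diffusers", {}).get("_class_name") == "ControlNetModel"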
2 changes: 1 addition & 1 deletion docker_images/diffusers/app/pipelines/text_to_image.py
@@ -58,7 +58,7 @@ def __call__(self, inputs: str, **kwargs) -> "Image.Image":
inputs (:obj:`str`):
a string containing some text
Return:
A :obj:`PIL.Image` with the raw image representation as PIL.
A :obj:`PIL.Image.Image` with the raw image representation as PIL.
"""

if isinstance(self.ldm, (StableDiffusionPipeline, AltDiffusionPipeline)):
11 changes: 6 additions & 5 deletions docker_images/diffusers/requirements.txt
@@ -1,9 +1,10 @@
starlette==0.25.0
api-inference-community==0.0.27
huggingface_hub==0.11.0
diffusers==0.12.0
-e git+https://github.com/huggingface/transformers@main#egg=transformers
accelerate==0.13.2
api-inference-community==0.0.29
huggingface_hub==0.13.3
safetensors==0.3.0
diffusers==0.14.0
transformers==4.27.4
accelerate==0.18.0
pydantic==1.8.2
ftfy==6.1.1
sentencepiece==0.1.97
10 changes: 8 additions & 2 deletions docker_images/diffusers/tests/test_api.py
@@ -1,5 +1,5 @@
import os
from typing import Dict
from typing import Dict, List
from unittest import TestCase, skipIf

from app.main import ALLOWED_TASKS, get_pipeline
@@ -8,7 +8,13 @@
# Must contain at least one example of each implemented pipeline
# Tests do not check the actual values of the model output, so small dummy
# models are recommended for faster tests.
TESTABLE_MODELS: Dict[str, str] = {"text-to-image": "CompVis/ldm-text2im-large-256"}
TESTABLE_MODELS: Dict[str, List[str]] = {
"text-to-image": ["hf-internal-testing/tiny-stable-diffusion-pipe-no-safety"],
"image-to-image": [
"hf-internal-testing/tiny-controlnet",
"hf-internal-testing/tiny-stable-diffusion-pix2pix",
],
}


ALL_TASKS = {
73 changes: 73 additions & 0 deletions docker_images/diffusers/tests/test_api_image_to_image.py
@@ -0,0 +1,73 @@
import os
from io import BytesIO
from unittest import TestCase, skipIf

import PIL
from app.main import ALLOWED_TASKS
from parameterized import parameterized_class
from starlette.testclient import TestClient
from tests.test_api import TESTABLE_MODELS


@skipIf(
"image-to-image" not in ALLOWED_TASKS,
"image-to-image not implemented",
)
@parameterized_class(
[{"model_id": model_id} for model_id in TESTABLE_MODELS["image-to-image"]]
)
class ImageToImageTestCase(TestCase):
def setUp(self):
self.old_model_id = os.getenv("MODEL_ID")
self.old_task = os.getenv("TASK")
os.environ["MODEL_ID"] = self.model_id
os.environ["TASK"] = "image-to-image"
from app.main import app

self.app = app

@classmethod
def setUpClass(cls):
from app.main import get_pipeline

get_pipeline.cache_clear()

def tearDown(self):
if self.old_model_id is not None:
os.environ["MODEL_ID"] = self.old_model_id
else:
del os.environ["MODEL_ID"]
if self.old_task is not None:
os.environ["TASK"] = self.old_task
else:
del os.environ["TASK"]

def test_simple(self):
text = "soap bubble"
image = PIL.Image.new("RGB", (64, 64))

inputs = (image, text)

with TestClient(self.app) as client:
response = client.post("/", json={"inputs": inputs})

self.assertEqual(
response.status_code,
200,
)

image = PIL.Image.open(BytesIO(response.content))
self.assertTrue(isinstance(image, PIL.Image.Image))

def test_malformed_input(self):
with TestClient(self.app) as client:
response = client.post("/", data=b"\xc3\x28")

self.assertEqual(
response.status_code,
400,
)
self.assertEqual(
response.content,
b'{"error":"\'utf-8\' codec can\'t decode byte 0xc3 in position 0: invalid continuation byte"}',
)