From 34209af40452095838c69484ad42ad7ae70a5486 Mon Sep 17 00:00:00 2001 From: shubhamofbce Date: Sat, 3 Feb 2024 13:35:24 +0530 Subject: [PATCH] Add support for Stable Diffusion pipelines in interface.from_pipeline(), import transformers and diffusers conditionally --- gradio/pipelines.py | 808 ++++++++++++++++++++++++++++++-------------- 1 file changed, 555 insertions(+), 253 deletions(-) diff --git a/gradio/pipelines.py b/gradio/pipelines.py index 13f7bba8ee4e5..b2fc6055dbac0 100644 --- a/gradio/pipelines.py +++ b/gradio/pipelines.py @@ -18,270 +18,568 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict: Returns: (dict): a dictionary of kwargs that can be used to construct an Interface object """ - try: - import transformers, diffusers - from transformers import pipelines - from diffusers import DiffusionPipeline - from diffusers import pipelines as diffuser_pipelines - except ImportError as ie: - raise ImportError( - "transformers not installed. Please try `pip install transformers`" - ) from ie - if not ((isinstance(pipeline, pipelines.base.Pipeline)) or (isinstance(pipeline, DiffusionPipeline))): - raise ValueError("pipeline must be a transformers.Pipeline or DiffusionPipeline") + if "transformers.pipelines" in str(type(pipeline)): + try: + import transformers + from transformers import pipelines - # Handle the different pipelines. The has_attr() checks to make sure the pipeline exists in the - # version of the transformers library that the user has installed. 
- if hasattr(transformers, "AudioClassificationPipeline") and isinstance( - pipeline, pipelines.audio_classification.AudioClassificationPipeline - ): - pipeline_info = { - "inputs": components.Audio( - sources=["microphone"], - type="filepath", - label="Input", - render=False, - ), - "outputs": components.Label(label="Class", render=False), - "preprocess": lambda i: {"inputs": i}, - "postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r}, - } - elif hasattr(transformers, "AutomaticSpeechRecognitionPipeline") and isinstance( - pipeline, - pipelines.automatic_speech_recognition.AutomaticSpeechRecognitionPipeline, - ): - pipeline_info = { - "inputs": components.Audio( - sources=["microphone"], type="filepath", label="Input", render=False - ), - "outputs": components.Textbox(label="Output", render=False), - "preprocess": lambda i: {"inputs": i}, - "postprocess": lambda r: r["text"], - } - elif hasattr(transformers, "FeatureExtractionPipeline") and isinstance( - pipeline, pipelines.feature_extraction.FeatureExtractionPipeline - ): - pipeline_info = { - "inputs": components.Textbox(label="Input", render=False), - "outputs": components.Dataframe(label="Output", render=False), - "preprocess": lambda x: {"inputs": x}, - "postprocess": lambda r: r[0], - } - elif hasattr(transformers, "FillMaskPipeline") and isinstance( - pipeline, pipelines.fill_mask.FillMaskPipeline - ): - pipeline_info = { - "inputs": components.Textbox(label="Input", render=False), - "outputs": components.Label(label="Classification", render=False), - "preprocess": lambda x: {"inputs": x}, - "postprocess": lambda r: {i["token_str"]: i["score"] for i in r}, - } - elif hasattr(transformers, "ImageClassificationPipeline") and isinstance( - pipeline, pipelines.image_classification.ImageClassificationPipeline - ): - pipeline_info = { - "inputs": components.Image( - type="filepath", label="Input Image", render=False - ), - "outputs": components.Label(label="Classification", render=False), - 
"preprocess": lambda i: {"images": i}, - "postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r}, - } - elif hasattr(transformers, "QuestionAnsweringPipeline") and isinstance( - pipeline, pipelines.question_answering.QuestionAnsweringPipeline - ): - pipeline_info = { - "inputs": [ - components.Textbox(lines=7, label="Context", render=False), - components.Textbox(label="Question", render=False), - ], - "outputs": [ - components.Textbox(label="Answer", render=False), - components.Label(label="Score", render=False), - ], - "preprocess": lambda c, q: {"context": c, "question": q}, - "postprocess": lambda r: (r["answer"], r["score"]), - } - elif hasattr(transformers, "SummarizationPipeline") and isinstance( - pipeline, pipelines.text2text_generation.SummarizationPipeline - ): - pipeline_info = { - "inputs": components.Textbox(lines=7, label="Input", render=False), - "outputs": components.Textbox(label="Summary", render=False), - "preprocess": lambda x: {"inputs": x}, - "postprocess": lambda r: r[0]["summary_text"], - } - elif hasattr(transformers, "TextClassificationPipeline") and isinstance( - pipeline, pipelines.text_classification.TextClassificationPipeline - ): - pipeline_info = { - "inputs": components.Textbox(label="Input", render=False), - "outputs": components.Label(label="Classification", render=False), - "preprocess": lambda x: [x], - "postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r}, - } - elif hasattr(transformers, "TextGenerationPipeline") and isinstance( - pipeline, pipelines.text_generation.TextGenerationPipeline - ): - pipeline_info = { - "inputs": components.Textbox(label="Input", render=False), - "outputs": components.Textbox(label="Output", render=False), - "preprocess": lambda x: {"text_inputs": x}, - "postprocess": lambda r: r[0]["generated_text"], - } - elif hasattr(transformers, "TranslationPipeline") and isinstance( - pipeline, pipelines.text2text_generation.TranslationPipeline - ): - pipeline_info = 
{ - "inputs": components.Textbox(label="Input", render=False), - "outputs": components.Textbox(label="Translation", render=False), - "preprocess": lambda x: [x], - "postprocess": lambda r: r[0]["translation_text"], - } - elif hasattr(transformers, "Text2TextGenerationPipeline") and isinstance( - pipeline, pipelines.text2text_generation.Text2TextGenerationPipeline - ): - pipeline_info = { - "inputs": components.Textbox(label="Input", render=False), - "outputs": components.Textbox(label="Generated Text", render=False), - "preprocess": lambda x: [x], - "postprocess": lambda r: r[0]["generated_text"], - } - elif hasattr(transformers, "ZeroShotClassificationPipeline") and isinstance( - pipeline, pipelines.zero_shot_classification.ZeroShotClassificationPipeline - ): - pipeline_info = { - "inputs": [ - components.Textbox(label="Input", render=False), - components.Textbox( - label="Possible class names (" "comma-separated)", render=False + except ImportError as ie: + raise ImportError( + "transformers not installed. Please try `pip install transformers`" + ) from ie + + # Handle the different pipelines. The has_attr() checks to make sure the pipeline exists in the + # version of the transformers library that the user has installed. 
+ if hasattr(transformers, "AudioClassificationPipeline") and isinstance( + pipeline, pipelines.audio_classification.AudioClassificationPipeline + ): + pipeline_info = { + "inputs": components.Audio( + sources=["microphone"], + type="filepath", + label="Input", + render=False, + ), + "outputs": components.Label(label="Class", render=False), + "preprocess": lambda i: {"inputs": i}, + "postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r}, + } + elif hasattr(transformers, "AutomaticSpeechRecognitionPipeline") and isinstance( + pipeline, + pipelines.automatic_speech_recognition.AutomaticSpeechRecognitionPipeline, + ): + pipeline_info = { + "inputs": components.Audio( + sources=["microphone"], type="filepath", label="Input", render=False + ), + "outputs": components.Textbox(label="Output", render=False), + "preprocess": lambda i: {"inputs": i}, + "postprocess": lambda r: r["text"], + } + elif hasattr(transformers, "FeatureExtractionPipeline") and isinstance( + pipeline, pipelines.feature_extraction.FeatureExtractionPipeline + ): + pipeline_info = { + "inputs": components.Textbox(label="Input", render=False), + "outputs": components.Dataframe(label="Output", render=False), + "preprocess": lambda x: {"inputs": x}, + "postprocess": lambda r: r[0], + } + elif hasattr(transformers, "FillMaskPipeline") and isinstance( + pipeline, pipelines.fill_mask.FillMaskPipeline + ): + pipeline_info = { + "inputs": components.Textbox(label="Input", render=False), + "outputs": components.Label(label="Classification", render=False), + "preprocess": lambda x: {"inputs": x}, + "postprocess": lambda r: {i["token_str"]: i["score"] for i in r}, + } + elif hasattr(transformers, "ImageClassificationPipeline") and isinstance( + pipeline, pipelines.image_classification.ImageClassificationPipeline + ): + pipeline_info = { + "inputs": components.Image( + type="filepath", label="Input Image", render=False ), - components.Checkbox(label="Allow multiple true classes", render=False), 
- ], - "outputs": components.Label(label="Classification", render=False), - "preprocess": lambda i, c, m: { - "sequences": i, - "candidate_labels": c, - "multi_label": m, - }, - "postprocess": lambda r: { - r["labels"][i]: r["scores"][i] for i in range(len(r["labels"])) - }, - } - elif hasattr(transformers, "DocumentQuestionAnsweringPipeline") and isinstance( - pipeline, - pipelines.document_question_answering.DocumentQuestionAnsweringPipeline, # type: ignore - ): - pipeline_info = { - "inputs": [ - components.Image(type="filepath", label="Input Document", render=False), - components.Textbox(label="Question", render=False), - ], - "outputs": components.Label(label="Label", render=False), - "preprocess": lambda img, q: {"image": img, "question": q}, - "postprocess": lambda r: {i["answer"]: i["score"] for i in r}, - } - elif hasattr(transformers, "VisualQuestionAnsweringPipeline") and isinstance( - pipeline, pipelines.visual_question_answering.VisualQuestionAnsweringPipeline - ): - pipeline_info = { - "inputs": [ - components.Image(type="filepath", label="Input Image", render=False), - components.Textbox(label="Question", render=False), - ], - "outputs": components.Label(label="Score", render=False), - "preprocess": lambda img, q: {"image": img, "question": q}, - "postprocess": lambda r: {i["answer"]: i["score"] for i in r}, - } - elif hasattr(transformers, "ImageToTextPipeline") and isinstance( - pipeline, - pipelines.image_to_text.ImageToTextPipeline, # type: ignore - ): - pipeline_info = { - "inputs": components.Image( - type="filepath", label="Input Image", render=False - ), - "outputs": components.Textbox(label="Text", render=False), - "preprocess": lambda i: {"images": i}, - "postprocess": lambda r: r[0]["generated_text"], - } - elif hasattr(transformers, "ObjectDetectionPipeline") and isinstance( - pipeline, pipelines.object_detection.ObjectDetectionPipeline - ): - pipeline_info = { - "inputs": components.Image( - type="filepath", label="Input Image", 
render=False - ), - "outputs": components.AnnotatedImage( - label="Objects Detected", render=False - ), - "preprocess": lambda i: {"inputs": i}, - "postprocess": lambda r, img: ( - img, - [ - ( + "outputs": components.Label(label="Classification", render=False), + "preprocess": lambda i: {"images": i}, + "postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r}, + } + elif hasattr(transformers, "QuestionAnsweringPipeline") and isinstance( + pipeline, pipelines.question_answering.QuestionAnsweringPipeline + ): + pipeline_info = { + "inputs": [ + components.Textbox(lines=7, label="Context", render=False), + components.Textbox(label="Question", render=False), + ], + "outputs": [ + components.Textbox(label="Answer", render=False), + components.Label(label="Score", render=False), + ], + "preprocess": lambda c, q: {"context": c, "question": q}, + "postprocess": lambda r: (r["answer"], r["score"]), + } + elif hasattr(transformers, "SummarizationPipeline") and isinstance( + pipeline, pipelines.text2text_generation.SummarizationPipeline + ): + pipeline_info = { + "inputs": components.Textbox(lines=7, label="Input", render=False), + "outputs": components.Textbox(label="Summary", render=False), + "preprocess": lambda x: {"inputs": x}, + "postprocess": lambda r: r[0]["summary_text"], + } + elif hasattr(transformers, "TextClassificationPipeline") and isinstance( + pipeline, pipelines.text_classification.TextClassificationPipeline + ): + pipeline_info = { + "inputs": components.Textbox(label="Input", render=False), + "outputs": components.Label(label="Classification", render=False), + "preprocess": lambda x: [x], + "postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r}, + } + elif hasattr(transformers, "TextGenerationPipeline") and isinstance( + pipeline, pipelines.text_generation.TextGenerationPipeline + ): + pipeline_info = { + "inputs": components.Textbox(label="Input", render=False), + "outputs": components.Textbox(label="Output", 
render=False), + "preprocess": lambda x: {"text_inputs": x}, + "postprocess": lambda r: r[0]["generated_text"], + } + elif hasattr(transformers, "TranslationPipeline") and isinstance( + pipeline, pipelines.text2text_generation.TranslationPipeline + ): + pipeline_info = { + "inputs": components.Textbox(label="Input", render=False), + "outputs": components.Textbox(label="Translation", render=False), + "preprocess": lambda x: [x], + "postprocess": lambda r: r[0]["translation_text"], + } + elif hasattr(transformers, "Text2TextGenerationPipeline") and isinstance( + pipeline, pipelines.text2text_generation.Text2TextGenerationPipeline + ): + pipeline_info = { + "inputs": components.Textbox(label="Input", render=False), + "outputs": components.Textbox(label="Generated Text", render=False), + "preprocess": lambda x: [x], + "postprocess": lambda r: r[0]["generated_text"], + } + elif hasattr(transformers, "ZeroShotClassificationPipeline") and isinstance( + pipeline, pipelines.zero_shot_classification.ZeroShotClassificationPipeline + ): + pipeline_info = { + "inputs": [ + components.Textbox(label="Input", render=False), + components.Textbox( + label="Possible class names (" "comma-separated)", render=False + ), + components.Checkbox(label="Allow multiple true classes", render=False), + ], + "outputs": components.Label(label="Classification", render=False), + "preprocess": lambda i, c, m: { + "sequences": i, + "candidate_labels": c, + "multi_label": m, + }, + "postprocess": lambda r: { + r["labels"][i]: r["scores"][i] for i in range(len(r["labels"])) + }, + } + elif hasattr(transformers, "DocumentQuestionAnsweringPipeline") and isinstance( + pipeline, + pipelines.document_question_answering.DocumentQuestionAnsweringPipeline, # type: ignore + ): + pipeline_info = { + "inputs": [ + components.Image(type="filepath", label="Input Document", render=False), + components.Textbox(label="Question", render=False), + ], + "outputs": components.Label(label="Label", render=False), + 
"preprocess": lambda img, q: {"image": img, "question": q}, + "postprocess": lambda r: {i["answer"]: i["score"] for i in r}, + } + elif hasattr(transformers, "VisualQuestionAnsweringPipeline") and isinstance( + pipeline, pipelines.visual_question_answering.VisualQuestionAnsweringPipeline + ): + pipeline_info = { + "inputs": [ + components.Image(type="filepath", label="Input Image", render=False), + components.Textbox(label="Question", render=False), + ], + "outputs": components.Label(label="Score", render=False), + "preprocess": lambda img, q: {"image": img, "question": q}, + "postprocess": lambda r: {i["answer"]: i["score"] for i in r}, + } + elif hasattr(transformers, "ImageToTextPipeline") and isinstance( + pipeline, + pipelines.image_to_text.ImageToTextPipeline, # type: ignore + ): + pipeline_info = { + "inputs": components.Image( + type="filepath", label="Input Image", render=False + ), + "outputs": components.Textbox(label="Text", render=False), + "preprocess": lambda i: {"images": i}, + "postprocess": lambda r: r[0]["generated_text"], + } + elif hasattr(transformers, "ObjectDetectionPipeline") and isinstance( + pipeline, pipelines.object_detection.ObjectDetectionPipeline + ): + pipeline_info = { + "inputs": components.Image( + type="filepath", label="Input Image", render=False + ), + "outputs": components.AnnotatedImage( + label="Objects Detected", render=False + ), + "preprocess": lambda i: {"inputs": i}, + "postprocess": lambda r, img: ( + img, + [ ( - i["box"]["xmin"], - i["box"]["ymin"], - i["box"]["xmax"], - i["box"]["ymax"], - ), - i["label"], - ) - for i in r + ( + i["box"]["xmin"], + i["box"]["ymin"], + i["box"]["xmax"], + i["box"]["ymax"], + ), + i["label"], + ) + for i in r + ], + ), + } + else: + raise ValueError(f"Unsupported pipeline type: {type(pipeline)}") + + elif "diffusers.pipelines" in str(type(pipeline)): + try: + import diffusers + from diffusers import pipelines as diffuser_pipelines + from PIL import Image + + except ImportError as ie: 
+ raise ImportError( + "diffusers not installed. Please try `pip install diffusers`" + ) from ie + + # Handle diffuser pipelines + if hasattr(diffusers, "StableDiffusionPipeline") and isinstance( + pipeline, diffuser_pipelines.StableDiffusionPipeline + ): + pipeline_info = { + "inputs": [ + components.Textbox(label="Prompt", render=False), + components.Textbox(label="Negative prompt", render=False), + components.Slider( + label="Number of inference steps", + minimum=1, + maximum=500, + value=50, + step=1, + ), + components.Slider( + label="Guidance scale", + minimum=1, + maximum=20, + value=7.5, + step=0.5, + ), + ], + "outputs": components.Image( + label="Generated Image", render=False, type="pil" + ), + "preprocess": lambda prompt, n_prompt, num_inf_steps, g_scale: { + "prompt": prompt, + "negative_prompt": n_prompt, + "num_inference_steps": num_inf_steps, + "guidance_scale": g_scale, + }, + "postprocess": lambda r: r["images"][0], + } + elif hasattr(diffusers, "StableDiffusionImg2ImgPipeline") and isinstance( + pipeline, diffuser_pipelines.StableDiffusionImg2ImgPipeline + ): + pipeline_info = { + "inputs": [ + components.Textbox(label="Prompt", render=False), + components.Textbox(label="Negative prompt", render=False), + components.Image(type="filepath", label="Image", render=False), + components.Slider( + label="Strength", minimum=0, maximum=1, value=0.8, step=0.1 + ), + components.Slider( + label="Number of inference steps", + minimum=1, + maximum=500, + value=50, + step=1, + ), + components.Slider( + label="Guidance scale", + minimum=1, + maximum=20, + value=7.5, + step=0.5, + ), + ], + "outputs": components.Image( + label="Generated Image", render=False, type="pil" + ), + "preprocess": lambda prompt, + n_prompt, + image, + strength, + num_inf_steps, + g_scale: { + "prompt": prompt, + "image": Image.open(image).resize((768, 768)), + "negative_prompt": n_prompt, + "num_inference_steps": num_inf_steps, + "guidance_scale": g_scale, + "strength": strength, + }, + 
"postprocess": lambda r: r["images"][0], + } + elif hasattr(diffusers, "StableDiffusionInpaintPipeline") and isinstance( + pipeline, diffuser_pipelines.StableDiffusionInpaintPipeline + ): + pipeline_info = { + "inputs": [ + components.Textbox(label="Prompt", render=False), + components.Textbox(label="Negative prompt", render=False), + components.Image(type="filepath", label="Image", render=False), + components.Image(type="filepath", label="Mask Image", render=False), + components.Slider( + label="Strength", minimum=0, maximum=1, value=0.8, step=0.1 + ), + components.Slider( + label="Number of inference steps", + minimum=1, + maximum=500, + value=50, + step=1, + ), + components.Slider( + label="Guidance scale", + minimum=1, + maximum=20, + value=7.5, + step=0.5, + ), ], - ), - } - # Diffuser pipelines - elif hasattr(diffusers, "StableDiffusionPipeline") and isinstance( - pipeline, diffuser_pipelines.StableDiffusionPipeline - ): - # TODO: complete this - pipeline_info = { - "inputs": [ - components.Textbox(label="Prompt", render=False), - components.Textbox(label="Negative prompt", render=False), - components.Slider(label="Number of inference steps", minimum=1, maximum=500, value=50, step=1), - components.Slider(label="Guidance scale", minimum=1, maximum=20, value=7.5, step=0.5) - ], - "outputs": components.Image(label="Generated Image", render=False, type="pil"), - "preprocess": lambda prompt, n_prompt, num_inf_steps, g_scale: { - "prompt": prompt, - "negative_prompt": n_prompt, - "num_inference_steps": num_inf_steps, - "guidance_scale": g_scale - }, - "postprocess": lambda r: r["images"][0], - } + "outputs": components.Image( + label="Generated Image", render=False, type="pil" + ), + "preprocess": lambda prompt, + n_prompt, + image, + mask_image, + strength, + num_inf_steps, + g_scale: { + "prompt": prompt, + "image": Image.open(image).resize((768, 768)), + "mask_image": Image.open(mask_image).resize((768, 768)), + "negative_prompt": n_prompt, + 
"num_inference_steps": num_inf_steps, + "guidance_scale": g_scale, + "strength": strength, + }, + "postprocess": lambda r: r["images"][0], + } + elif hasattr(diffusers, "StableDiffusionDepth2ImgPipeline") and isinstance( + pipeline, diffuser_pipelines.StableDiffusionDepth2ImgPipeline + ): + pipeline_info = { + "inputs": [ + components.Textbox(label="Prompt", render=False), + components.Textbox(label="Negative prompt", render=False), + components.Image(type="filepath", label="Image", render=False), + components.Slider( + label="Strength", minimum=0, maximum=1, value=0.8, step=0.1 + ), + components.Slider( + label="Number of inference steps", + minimum=1, + maximum=500, + value=50, + step=1, + ), + components.Slider( + label="Guidance scale", + minimum=1, + maximum=20, + value=7.5, + step=0.5, + ), + ], + "outputs": components.Image( + label="Generated Image", render=False, type="pil" + ), + "preprocess": lambda prompt, + n_prompt, + image, + strength, + num_inf_steps, + g_scale: { + "prompt": prompt, + "image": Image.open(image).resize((768, 768)), + "negative_prompt": n_prompt, + "num_inference_steps": num_inf_steps, + "guidance_scale": g_scale, + "strength": strength, + }, + "postprocess": lambda r: r["images"][0], + } + elif hasattr(diffusers, "StableDiffusionImageVariationPipeline") and isinstance( + pipeline, diffuser_pipelines.StableDiffusionImageVariationPipeline + ): + pipeline_info = { + "inputs": [ + components.Image(type="filepath", label="Image", render=False), + components.Slider( + label="Number of inference steps", + minimum=1, + maximum=500, + value=50, + step=1, + ), + components.Slider( + label="Guidance scale", + minimum=1, + maximum=20, + value=7.5, + step=0.5, + ), + ], + "outputs": components.Image( + label="Generated Image", render=False, type="pil" + ), + "preprocess": lambda image, num_inf_steps, g_scale: { + "image": Image.open(image).resize((768, 768)), + "num_inference_steps": num_inf_steps, + "guidance_scale": g_scale, + }, + 
"postprocess": lambda r: r["images"][0], + } + elif hasattr( + diffusers, "StableDiffusionInstructPix2PixPipeline" + ) and isinstance( + pipeline, diffuser_pipelines.StableDiffusionInstructPix2PixPipeline + ): + pipeline_info = { + "inputs": [ + components.Textbox(label="Prompt", render=False), + components.Textbox(label="Negative prompt", render=False), + components.Image(type="filepath", label="Image", render=False), + components.Slider( + label="Number of inference steps", + minimum=1, + maximum=500, + value=50, + step=1, + ), + components.Slider( + label="Guidance scale", + minimum=1, + maximum=20, + value=7.5, + step=0.5, + ), + components.Slider( + label="Image Guidance scale", + minimum=1, + maximum=5, + value=1.5, + step=0.5, + ), + ], + "outputs": components.Image( + label="Generated Image", render=False, type="pil" + ), + "preprocess": lambda prompt, + n_prompt, + image, + num_inf_steps, + g_scale, + img_g_scale: { + "prompt": prompt, + "image": Image.open(image).resize((768, 768)), + "negative_prompt": n_prompt, + "num_inference_steps": num_inf_steps, + "guidance_scale": g_scale, + "image_guidance_scale": img_g_scale, + }, + "postprocess": lambda r: r["images"][0], + } + elif hasattr(diffusers, "StableDiffusionUpscalePipeline") and isinstance( + pipeline, diffuser_pipelines.StableDiffusionUpscalePipeline + ): + pipeline_info = { + "inputs": [ + components.Textbox(label="Prompt", render=False), + components.Textbox(label="Negative prompt", render=False), + components.Image(type="filepath", label="Image", render=False), + components.Slider( + label="Number of inference steps", + minimum=1, + maximum=500, + value=50, + step=1, + ), + components.Slider( + label="Guidance scale", + minimum=1, + maximum=20, + value=7.5, + step=0.5, + ), + components.Slider( + label="Noise level", minimum=1, maximum=100, value=20, step=1 + ), + ], + "outputs": components.Image( + label="Generated Image", render=False, type="pil" + ), + "preprocess": lambda prompt, + n_prompt, + 
 image,
+            num_inf_steps,
+            g_scale,
+            noise_level: {
+                "prompt": prompt,
+                "image": Image.open(image).resize((768, 768)),
+                "negative_prompt": n_prompt,
+                "num_inference_steps": num_inf_steps,
+                "guidance_scale": g_scale,
+                "noise_level": noise_level,
+            },
+            "postprocess": lambda r: r["images"][0],
+        }
+        else:
+            raise ValueError(f"Unsupported pipeline type: {type(pipeline)}")
+
     else:
-        raise ValueError(f"Unsupported pipeline type: {type(pipeline)}")
+        raise ValueError(
+            "pipeline must be a transformers.pipeline or diffusers.pipeline"
+        )
 
     # define the function that will be called by the Interface
     def fn(*params):
         data = pipeline_info["preprocess"](*params)
-        # special cases that needs to be handled differently
-        if isinstance(
-            pipeline,
-            (
-                pipelines.text_classification.TextClassificationPipeline,
-                pipelines.text2text_generation.Text2TextGenerationPipeline,
-                pipelines.text2text_generation.TranslationPipeline,
-            ),
-        ):
-            data = pipeline(*data)
-        else:
+        if "transformers.pipelines" in str(type(pipeline)):
+            from transformers import pipelines
+
+            # special cases that needs to be handled differently
+            if isinstance(
+                pipeline,
+                (
+                    pipelines.text_classification.TextClassificationPipeline,
+                    pipelines.text2text_generation.Text2TextGenerationPipeline,
+                    pipelines.text2text_generation.TranslationPipeline,
+                ),
+            ):
+                data = pipeline(*data)
+            else:
+                data = pipeline(**data)
+            # special case for object-detection
+            # original input image sent to postprocess function
+            if isinstance(
+                pipeline,
+                pipelines.object_detection.ObjectDetectionPipeline,
+            ):
+                output = pipeline_info["postprocess"](data, params[0])
+            else:
+                output = pipeline_info["postprocess"](data)
+            return output
+
+        elif "diffusers.pipelines" in str(type(pipeline)):
             data = pipeline(**data)
-        # special case for object-detection
-        # original input image sent to postprocess function
-        if isinstance(
-            pipeline,
-            pipelines.object_detection.ObjectDetectionPipeline,
-        ):
-            output = 
pipeline_info["postprocess"](data, params[0]) - else: output = pipeline_info["postprocess"](data) - return output + return output interface_info = pipeline_info.copy() interface_info["fn"] = fn @@ -289,6 +587,10 @@ def fn(*params): del interface_info["postprocess"] # define the title/description of the Interface - interface_info["title"] = pipeline.model.__class__.__name__ if not isinstance(pipeline, DiffusionPipeline) else pipeline.__class__.__name__ + interface_info["title"] = ( + pipeline.model.__class__.__name__ + if "diffusers.pipelines" not in str(type(pipeline)) + else pipeline.__class__.__name__ + ) return interface_info