Skip to content

Commit

Permalink
Add support for StableDiffusionPipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
shubhamofbce committed Feb 3, 2024
1 parent 2e6672c commit 5528274
Showing 1 changed file with 28 additions and 4 deletions.
32 changes: 28 additions & 4 deletions gradio/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,17 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict:
(dict): a dictionary of kwargs that can be used to construct an Interface object
"""
try:
import transformers
import transformers, diffusers
from transformers import pipelines
from diffusers import DiffusionPipeline
from diffusers import pipelines as diffuser_pipelines

except ImportError as ie:
raise ImportError(
"transformers not installed. Please try `pip install transformers`"
) from ie
if not isinstance(pipeline, pipelines.base.Pipeline):
raise ValueError("pipeline must be a transformers.Pipeline")
if not ((isinstance(pipeline, pipelines.base.Pipeline)) or (isinstance(pipeline, DiffusionPipeline))):
raise ValueError("pipeline must be a transformers.Pipeline or DiffusionPipeline")

# Handle the different pipelines. The has_attr() checks to make sure the pipeline exists in the
# version of the transformers library that the user has installed.
Expand Down Expand Up @@ -230,6 +233,27 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict:
],
),
}
# Diffuser pipelines
elif hasattr(diffusers, "StableDiffusionPipeline") and isinstance(
pipeline, diffuser_pipelines.StableDiffusionPipeline
):
# TODO: complete this
pipeline_info = {
"inputs": [
components.Textbox(label="Prompt", render=False),
components.Textbox(label="Negative prompt", render=False),
components.Slider(label="Number of inference steps", minimum=1, maximum=500, value=50, step=1),
components.Slider(label="Guidance scale", minimum=1, maximum=20, value=7.5, step=0.5)
],
"outputs": components.Image(label="Generated Image", render=False, type="pil"),
"preprocess": lambda prompt, n_prompt, num_inf_steps, g_scale: {
"prompt": prompt,
"negative_prompt": n_prompt,
"num_inference_steps": num_inf_steps,
"guidance_scale": g_scale
},
"postprocess": lambda r: r["images"][0],
}
else:
raise ValueError(f"Unsupported pipeline type: {type(pipeline)}")

Expand Down Expand Up @@ -265,6 +289,6 @@ def fn(*params):
del interface_info["postprocess"]

# define the title/description of the Interface
interface_info["title"] = pipeline.model.__class__.__name__
interface_info["title"] = pipeline.model.__class__.__name__ if not isinstance(pipeline, DiffusionPipeline) else pipeline.__class__.__name__

return interface_info

0 comments on commit 5528274

Please sign in to comment.