Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions src/diffusers/modular_pipelines/modular_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,15 @@
from ..utils.hub_utils import load_or_create_model_card, populate_model_card
from .components_manager import ComponentsManager
from .modular_pipeline_utils import (
MODULAR_MODEL_CARD_TEMPLATE,
ComponentSpec,
ConfigSpec,
InputParam,
InsertableDict,
OutputParam,
format_components,
format_configs,
generate_modular_model_card_content,
make_doc_string,
)

Expand Down Expand Up @@ -1753,9 +1755,19 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub:
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id

# Generate modular pipeline card content
card_content = generate_modular_model_card_content(self.blocks)

# Create a new empty model card and eventually tag it
model_card = load_or_create_model_card(repo_id, token=token, is_pipeline=True)
model_card = populate_model_card(model_card)
model_card = load_or_create_model_card(
repo_id,
token=token,
is_pipeline=True,
model_description=MODULAR_MODEL_CARD_TEMPLATE.format(**card_content),
is_modular=True,
)
model_card = populate_model_card(model_card, tags=card_content["tags"])

model_card.save(os.path.join(save_directory, "README.md"))

# YiYi TODO: maybe order the json file to make it more readable: configs first, then components
Expand Down
199 changes: 199 additions & 0 deletions src/diffusers/modular_pipelines/modular_pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,30 @@

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

# Template for modular pipeline model card description with placeholders
MODULAR_MODEL_CARD_TEMPLATE = """{model_description}

## Example Usage

[TODO]

## Pipeline Architecture

This modular pipeline is composed of the following blocks:

{blocks_description} {trigger_inputs_section}

## Model Components

{components_description} {configs_section}

## Input/Output Specification

### Inputs {inputs_description}

### Outputs {outputs_description}
"""


class InsertableDict(OrderedDict):
def insert(self, key, value, index):
Expand Down Expand Up @@ -916,3 +940,178 @@ def make_doc_string(
output += format_output_params(outputs, indent_level=2)

return output


def generate_modular_model_card_content(blocks) -> Dict[str, Any]:
"""
Generate model card content for a modular pipeline.

This function creates a comprehensive model card with descriptions of the pipeline's architecture, components,
configurations, inputs, and outputs.

Args:
blocks: The pipeline's blocks object containing all pipeline specifications

Returns:
Dict[str, Any]: A dictionary containing formatted content sections:
- pipeline_name: Name of the pipeline
- model_description: Overall description with pipeline type
- blocks_description: Detailed architecture of blocks
- components_description: List of required components
- configs_section: Configuration parameters section
- inputs_description: Input parameters specification
- outputs_description: Output parameters specification
- trigger_inputs_section: Conditional execution information
- tags: List of relevant tags for the model card
"""
blocks_class_name = blocks.__class__.__name__
pipeline_name = blocks_class_name.replace("Blocks", " Pipeline")
description = getattr(blocks, "description", "A modular diffusion pipeline.")

# generate blocks architecture description
blocks_desc_parts = []
sub_blocks = getattr(blocks, "sub_blocks", None) or {}
if sub_blocks:
for i, (name, block) in enumerate(sub_blocks.items()):
block_class = block.__class__.__name__
block_desc = block.description.split("\n")[0] if getattr(block, "description", "") else ""
blocks_desc_parts.append(f"{i + 1}. **{name}** (`{block_class}`)")
if block_desc:
blocks_desc_parts.append(f" - {block_desc}")

# add sub-blocks if any
if hasattr(block, "sub_blocks") and block.sub_blocks:
for sub_name, sub_block in block.sub_blocks.items():
sub_class = sub_block.__class__.__name__
sub_desc = sub_block.description.split("\n")[0] if getattr(sub_block, "description", "") else ""
blocks_desc_parts.append(f" - *{sub_name}*: `{sub_class}`")
if sub_desc:
blocks_desc_parts.append(f" - {sub_desc}")

blocks_description = "\n".join(blocks_desc_parts) if blocks_desc_parts else "No blocks defined."

components = getattr(blocks, "expected_components", [])
if components:
components_str = format_components(components, indent_level=0, add_empty_lines=False)
# remove the "Components:" header since template has its own
components_description = components_str.replace("Components:\n", "").strip()
if components_description:
# Convert to enumerated list
lines = [line.strip() for line in components_description.split("\n") if line.strip()]
enumerated_lines = [f"{i + 1}. {line}" for i, line in enumerate(lines)]
components_description = "\n".join(enumerated_lines)
else:
components_description = "No specific components required."
else:
components_description = "No specific components required. Components can be loaded dynamically."

configs = getattr(blocks, "expected_configs", [])
configs_section = ""
if configs:
configs_str = format_configs(configs, indent_level=0, add_empty_lines=False)
configs_description = configs_str.replace("Configs:\n", "").strip()
if configs_description:
configs_section = f"\n\n## Configuration Parameters\n\n{configs_description}"

inputs = blocks.inputs
outputs = blocks.outputs

# format inputs as markdown list
inputs_parts = []
required_inputs = [inp for inp in inputs if inp.required]
optional_inputs = [inp for inp in inputs if not inp.required]

if required_inputs:
inputs_parts.append("**Required:**\n")
for inp in required_inputs:
if hasattr(inp.type_hint, "__name__"):
type_str = inp.type_hint.__name__
elif inp.type_hint is not None:
type_str = str(inp.type_hint).replace("typing.", "")
else:
type_str = "Any"
desc = inp.description or "No description provided"
inputs_parts.append(f"- `{inp.name}` (`{type_str}`): {desc}")

if optional_inputs:
if required_inputs:
inputs_parts.append("")
inputs_parts.append("**Optional:**\n")
for inp in optional_inputs:
if hasattr(inp.type_hint, "__name__"):
type_str = inp.type_hint.__name__
elif inp.type_hint is not None:
type_str = str(inp.type_hint).replace("typing.", "")
else:
type_str = "Any"
desc = inp.description or "No description provided"
default_str = f", default: `{inp.default}`" if inp.default is not None else ""
inputs_parts.append(f"- `{inp.name}` (`{type_str}`){default_str}: {desc}")

inputs_description = "\n".join(inputs_parts) if inputs_parts else "No specific inputs defined."

# format outputs as markdown list
outputs_parts = []
for out in outputs:
if hasattr(out.type_hint, "__name__"):
type_str = out.type_hint.__name__
elif out.type_hint is not None:
type_str = str(out.type_hint).replace("typing.", "")
else:
type_str = "Any"
desc = out.description or "No description provided"
outputs_parts.append(f"- `{out.name}` (`{type_str}`): {desc}")

outputs_description = "\n".join(outputs_parts) if outputs_parts else "Standard pipeline outputs."

trigger_inputs_section = ""
if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
trigger_inputs_list = sorted([t for t in blocks.trigger_inputs if t is not None])
if trigger_inputs_list:
trigger_inputs_str = ", ".join(f"`{t}`" for t in trigger_inputs_list)
trigger_inputs_section = f"""
### Conditional Execution

This pipeline contains blocks that are selected at runtime based on inputs:
- **Trigger Inputs**: {trigger_inputs_str}
"""

# generate tags based on pipeline characteristics
tags = ["modular-diffusers", "diffusers"]

if hasattr(blocks, "model_name") and blocks.model_name:
tags.append(blocks.model_name)

if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
triggers = blocks.trigger_inputs
if any(t in triggers for t in ["mask", "mask_image"]):
tags.append("inpainting")
if any(t in triggers for t in ["image", "image_latents"]):
tags.append("image-to-image")
if any(t in triggers for t in ["control_image", "controlnet_cond"]):
tags.append("controlnet")
if not any(t in triggers for t in ["image", "mask", "image_latents", "mask_image"]):
tags.append("text-to-image")
else:
tags.append("text-to-image")

block_count = len(blocks.sub_blocks)
model_description = f"""This is a modular diffusion pipeline built with 🧨 Diffusers' modular pipeline framework.

**Pipeline Type**: {blocks_class_name}

**Description**: {description}

This pipeline uses a {block_count}-block architecture that can be customized and extended."""

return {
"pipeline_name": pipeline_name,
"model_description": model_description,
"blocks_description": blocks_description,
"components_description": components_description,
"configs_section": configs_section,
"inputs_description": inputs_description,
"outputs_description": outputs_description,
"trigger_inputs_section": trigger_inputs_section,
"tags": tags,
}
15 changes: 11 additions & 4 deletions src/diffusers/utils/hub_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def load_or_create_model_card(
license: Optional[str] = None,
widget: Optional[List[dict]] = None,
inference: Optional[bool] = None,
is_modular: bool = False,
) -> ModelCard:
"""
Loads or creates a model card.
Expand All @@ -131,6 +132,8 @@ def load_or_create_model_card(
widget (`List[dict]`, *optional*): Widget to accompany a gallery template.
inference: (`bool`, optional): Whether to turn on inference widget. Helpful when using
`load_or_create_model_card` from a training script.
is_modular: (`bool`, optional): Boolean flag to denote if the model card is for a modular pipeline.
When True, uses model_description as-is without additional template formatting.
"""
if not is_jinja_available():
raise ValueError(
Expand Down Expand Up @@ -159,10 +162,14 @@ def load_or_create_model_card(
)
else:
card_data = ModelCardData()
component = "pipeline" if is_pipeline else "model"
if model_description is None:
model_description = f"This is the model card of a 🧨 diffusers {component} that has been pushed on the Hub. This model card has been automatically generated."
model_card = ModelCard.from_template(card_data, model_description=model_description)
if is_modular and model_description is not None:
model_card = ModelCard(model_description)
model_card.data = card_data
else:
component = "pipeline" if is_pipeline else "model"
if model_description is None:
model_description = f"This is the model card of a 🧨 diffusers {component} that has been pushed on the Hub. This model card has been automatically generated."
model_card = ModelCard.from_template(card_data, model_description=model_description)

return model_card

Expand Down
Loading
Loading