Skip to content

Commit a1eb9ee

Browse files
committed
make component spec loadable: add load/create method
1 parent d456a97 commit a1eb9ee

File tree

2 files changed

+51
-35
lines changed

2 files changed

+51
-35
lines changed

src/diffusers/pipelines/components_manager.py

Lines changed: 15 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -241,19 +241,6 @@ def __init__(self):
241241
self.model_hooks = None
242242
self._auto_offload_enabled = False
243243

244-
245-
def load_component(self, spec: Union[ComponentSpec, ComponentLoadSpec], **kwargs):
246-
module_class = spec.type_hint
247-
248-
249-
if spec.revision is not None:
250-
kwargs["revision"] = spec.revision
251-
if spec.variant is not None:
252-
kwargs["variant"] = spec.variant
253-
254-
component = module_class.from_pretrained(spec.repo, subfolder=spec.subfolder, **kwargs)
255-
return component
256-
257244
def add(self, name, component, collection: Optional[str] = None, load_spec: Optional[ComponentLoadSpec] = None):
258245
if name in self.components:
259246
logger.warning(f"Overriding existing component '{name}' in ComponentsManager")
@@ -284,21 +271,23 @@ def add_with_spec(self, name, spec:Union[ComponentSpec, ComponentLoadSpec], coll
284271
**kwargs: Additional arguments to pass to the component loader
285272
"""
286273

287-
if isinstance(spec, ComponentSpec):
288-
if spec.config is not None:
289-
component = spec.type_hint(**spec.config)
290-
self.add(name, component, collection=collection, load_spec=ComponentLoadSpec.from_component_spec(spec))
291-
return
292-
293-
spec = ComponentLoadSpec.from_component_spec(spec)
294-
274+
if isinstance(spec, ComponentSpec) and spec.repo is None:
275+
component = spec.create(**kwargs)
276+
self.add(name, component, collection=collection)
277+
elif isinstance(spec, ComponentSpec):
278+
load_spec = spec.to_load_spec()
279+
elif isinstance(spec, ComponentLoadSpec):
280+
load_spec = spec
281+
else:
282+
raise ValueError(f"Invalid spec type: {type(spec)}")
283+
295284
for k, v in self.components_specs.items():
296-
if v == spec and not force_add:
297-
logger.warning(f"will not add {name} to ComponentsManager, as {k} already exists with same spec.Please use force_add=True to add it.")
285+
if v == load_spec and not force_add:
286+
logger.warning(f"{name} is not added to ComponentsManager, because `{k}` already exists with same spec. Please use `force_add=True` to add it.")
298287
return
299-
300-
component = self.load_component(spec, **kwargs)
301-
self.add(name, component, collection=collection, load_spec=spec)
288+
289+
component = load_spec.load(**kwargs)
290+
self.add(name, component, collection=collection, load_spec=load_spec)
302291

303292
def remove(self, name):
304293

src/diffusers/pipelines/modular_pipeline_utils.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import re
16-
from dataclasses import dataclass
16+
from dataclasses import dataclass, asdict
1717
from typing import Any, Dict, List, Optional, Tuple, Type, Union
1818

1919
from ..utils.import_utils import is_torch_available
@@ -27,34 +27,61 @@
2727
class ComponentSpec:
2828
"""Specification for a pipeline component."""
2929
name: str
30-
# YiYi NOTE: is type_hint a good fild name? it is the actual class, will be used to create the default instance
31-
type_hint: Type
30+
type_hint: Type # YiYi Notes: change to component_type?
3231
description: Optional[str] = None
3332
config: Optional[FrozenDict[str, Any]] = None # you can specific default config to create a default component if it is a stateless class like scheduler, guider or image processor
3433
repo: Optional[Union[str, List[str]]] = None
3534
subfolder: Optional[str] = None
36-
revision: Optional[str] = None
37-
variant: Optional[str] = None
35+
36+
def create(self, **kwargs) -> Any:
37+
"""
38+
Create the component based on the config and additional kwargs.
39+
40+
Args:
41+
**kwargs: Additional arguments to pass to the component's __init__ method
42+
43+
Returns:
44+
The created component
45+
"""
46+
if self.config is not None:
47+
init_kwargs = self.config
48+
else:
49+
init_kwargs = {}
50+
return self.type_hint(**init_kwargs, **kwargs)
51+
52+
def load(self, **kwargs) -> Any:
53+
return self.to_load_spec().load(**kwargs)
54+
55+
def to_load_spec(self) -> "ComponentLoadSpec":
56+
"""Convert to a ComponentLoadSpec for storage in ComponentsManager."""
57+
return ComponentLoadSpec.from_component_spec(self)
3858

3959
@dataclass
4060
class ComponentLoadSpec:
4161
type_hint: type
4262
repo: Optional[str] = None
4363
subfolder: Optional[str] = None
44-
revision: Optional[str] = None
45-
variant: Optional[str] = None
4664

65+
def load(self, **kwargs) -> Any:
66+
"""Load the component from the repository."""
67+
repo = kwargs.pop("repo", self.repo)
68+
subfolder = kwargs.pop("subfolder", self.subfolder)
69+
70+
return self.type_hint.from_pretrained(repo, subfolder=subfolder, **kwargs)
71+
72+
4773
@classmethod
4874
def from_component_spec(cls, component_spec: ComponentSpec):
49-
return cls(type_hint=component_spec.type_hint, repo=component_spec.repo, subfolder=component_spec.subfolder, revision=component_spec.revision, variant=component_spec.variant)
75+
return cls(type_hint=component_spec.type_hint, repo=component_spec.repo, subfolder=component_spec.subfolder)
76+
5077

5178
@dataclass
5279
class ConfigSpec:
5380
"""Specification for a pipeline configuration parameter."""
5481
name: str
5582
value: Any
5683
description: Optional[str] = None
57-
repo: Optional[Union[str, List[str]]] = None
84+
repo: Optional[Union[str, List[str]]] = None #YiYi Notes: not sure if this field is needed
5885

5986
@dataclass
6087
class InputParam:

0 commit comments

Comments
 (0)