|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | import re |
16 | | -from dataclasses import dataclass |
| 16 | +from dataclasses import dataclass, asdict |
17 | 17 | from typing import Any, Dict, List, Optional, Tuple, Type, Union |
18 | 18 |
|
19 | 19 | from ..utils.import_utils import is_torch_available |
|
27 | 27 | class ComponentSpec: |
28 | 28 | """Specification for a pipeline component.""" |
29 | 29 | name: str |
30 | | - # YiYi NOTE: is type_hint a good fild name? it is the actual class, will be used to create the default instance |
31 | | - type_hint: Type |
| 30 | + type_hint: Type # YiYi Notes: change to component_type? |
32 | 31 | description: Optional[str] = None |
33 | 32 | config: Optional[FrozenDict[str, Any]] = None # you can specific default config to create a default component if it is a stateless class like scheduler, guider or image processor |
34 | 33 | repo: Optional[Union[str, List[str]]] = None |
35 | 34 | subfolder: Optional[str] = None |
36 | | - revision: Optional[str] = None |
37 | | - variant: Optional[str] = None |
| 35 | + |
| 36 | + def create(self, **kwargs) -> Any: |
| 37 | + """ |
| 38 | + Create the component based on the config and additional kwargs. |
| 39 | + |
| 40 | + Args: |
| 41 | + **kwargs: Additional arguments to pass to the component's __init__ method |
| 42 | + |
| 43 | + Returns: |
| 44 | + The created component |
| 45 | + """ |
| 46 | + if self.config is not None: |
| 47 | + init_kwargs = self.config |
| 48 | + else: |
| 49 | + init_kwargs = {} |
| 50 | + return self.type_hint(**init_kwargs, **kwargs) |
| 51 | + |
| 52 | + def load(self, **kwargs) -> Any: |
| 53 | + return self.to_load_spec().load(**kwargs) |
| 54 | + |
| 55 | + def to_load_spec(self) -> "ComponentLoadSpec": |
| 56 | + """Convert to a ComponentLoadSpec for storage in ComponentsManager.""" |
| 57 | + return ComponentLoadSpec.from_component_spec(self) |
38 | 58 |
|
39 | 59 | @dataclass |
40 | 60 | class ComponentLoadSpec: |
41 | 61 | type_hint: type |
42 | 62 | repo: Optional[str] = None |
43 | 63 | subfolder: Optional[str] = None |
44 | | - revision: Optional[str] = None |
45 | | - variant: Optional[str] = None |
46 | 64 |
|
| 65 | + def load(self, **kwargs) -> Any: |
| 66 | + """Load the component from the repository.""" |
| 67 | + repo = kwargs.pop("repo", self.repo) |
| 68 | + subfolder = kwargs.pop("subfolder", self.subfolder) |
| 69 | + |
| 70 | + return self.type_hint.from_pretrained(repo, subfolder=subfolder, **kwargs) |
| 71 | + |
| 72 | + |
47 | 73 | @classmethod |
48 | 74 | def from_component_spec(cls, component_spec: ComponentSpec): |
49 | | - return cls(type_hint=component_spec.type_hint, repo=component_spec.repo, subfolder=component_spec.subfolder, revision=component_spec.revision, variant=component_spec.variant) |
| 75 | + return cls(type_hint=component_spec.type_hint, repo=component_spec.repo, subfolder=component_spec.subfolder) |
| 76 | + |
50 | 77 |
|
51 | 78 | @dataclass |
52 | 79 | class ConfigSpec: |
53 | 80 | """Specification for a pipeline configuration parameter.""" |
54 | 81 | name: str |
55 | 82 | value: Any |
56 | 83 | description: Optional[str] = None |
57 | | - repo: Optional[Union[str, List[str]]] = None |
| 84 | + repo: Optional[Union[str, List[str]]] = None #YiYi Notes: not sure if this field is needed |
58 | 85 |
|
59 | 86 | @dataclass |
60 | 87 | class InputParam: |
|
0 commit comments