Skip to content

Commit

Permalink
More documentation for python code (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
RunDevelopment committed Nov 18, 2023
1 parent f9b46d4 commit f7bfd9a
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/spandrel/__helpers/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def load_state_dict_from_file(self, path: str | Path) -> StateDict:
"""
Load the state dict of a model from the given file path.
State dicts are typically only useful to pass them into the `load` function of a specific architecture.
State dicts are typically only useful to pass them into the `load`
function of a specific architecture.
Throws a `ValueError` if the file extension is not supported.
"""
Expand Down
46 changes: 46 additions & 0 deletions src/spandrel/__helpers/model_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,56 @@ def __init__(
size: SizeRequirements | None = None,
):
self.model: T = model
"""
The model itself: a `torch.nn.Module` with weights loaded in.
The specific subclass of `torch.nn.Module` depends on the model architecture.
"""
self.state_dict: StateDict = state_dict
"""
The state dict of the model (weights and biases).
"""
self.architecture: str = architecture
"""
The name of the model architecture. E.g. "ESRGAN".
"""
self.tags: list[str] = tags
"""
A list of tags for the model, usually describing the size or model
parameters. E.g. "64nf" or "large".
Tags are specific to the architecture of the model. Some architectures
may not have any tags.
"""
self.supports_half: bool = supports_half
"""
Whether the model supports half precision (fp16).
"""
self.supports_bfloat16: bool = supports_bfloat16
"""
Whether the model supports bfloat16 precision.
"""

self.scale: int = scale
"""
The output scale of super resolution models. E.g. 4x, 2x, 1x.
Models that are not super resolution models (e.g. denoisers) have a
scale of 1.
"""
self.input_channels: int = input_channels
"""
The number of input image channels of the model. E.g. 3 for RGB, 1 for grayscale.
"""
self.output_channels: int = output_channels
"""
The number of output image channels of the model. E.g. 3 for RGB, 1 for grayscale.
"""

self.size: SizeRequirements = size or SizeRequirements()
"""
Size requirements for the input image. E.g. minimum size.
"""

self.model.load_state_dict(state_dict) # type: ignore

Expand Down Expand Up @@ -136,3 +175,10 @@ def __init__(
InpaintModelDescriptor,
RestorationModelDescriptor,
]
"""
A model descriptor is a loaded model with metadata. Metadata includes the
architecture, purpose, tags, and other information about the model.
The purpose of a model is described by the type of the model descriptor. E.g.
a super resolution model has a descriptor of type `SRModelDescriptor`.
"""
2 changes: 2 additions & 0 deletions src/spandrel/__helpers/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ def visit(arch: ArchSupport):
def load(self, state_dict: StateDict) -> ModelDescriptor:
"""
Detects the architecture of the given state dict and loads it.
Throws an `UnsupportedModelError` if the model architecture is not supported.
"""

if "params_ema" in state_dict:
Expand Down

0 comments on commit f7bfd9a

Please sign in to comment.