33 changes: 26 additions & 7 deletions docs/source/api/models.mdx
@@ -16,13 +16,32 @@ Diffusers contains pretrained models for popular algorithms and modules for crea
The primary function of these models is to denoise an input sample by modeling the distribution $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$.
The models are built on the base class [`ModelMixin`], which is a `torch.nn.Module` with basic functionality for saving and loading models both locally and from the Hugging Face Hub.
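
As a quick orientation, loading and saving a model is a two-liner. A minimal sketch (the `google/ddpm-cat-256` hub id and the local path are only illustrations; any model id or directory works):

```python
from diffusers import UNet2DModel

# Download pretrained weights and config from the Hub (a local directory works too).
model = UNet2DModel.from_pretrained("google/ddpm-cat-256")

# Persist them locally; the directory can be passed straight back to `from_pretrained`.
model.save_pretrained("./my_model_directory")
```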

## API
## ModelMixin
[[autodoc]] ModelMixin

Models should provide the `def forward` function and initialization of the model.
All saving, loading, and utilities should be in the base ['ModelMixin'] class.
## UNet2DOutput
[[autodoc]] models.unet_2d.UNet2DOutput

## Examples
## UNet2DModel
[[autodoc]] UNet2DModel

- The ['UNetModel'] was proposed in [TODO](https://arxiv.org/) and has been used in paper1, paper2, paper3.
- Extensions of the ['UNetModel'] include the ['UNetGlideModel'] that uses attention and timestep embeddings for the [GLIDE](https://arxiv.org/abs/2112.10741) paper, the ['UNetGradTTS'] model from this [paper](https://arxiv.org/abs/2105.06337) for text-to-speech, ['UNetLDMModel'] for latent-diffusion models in this [paper](https://arxiv.org/abs/2112.10752), and the ['TemporalUNet'] used for time-series prediction in this reinforcement learning [paper](https://arxiv.org/abs/2205.09991).
- TODO: mention VAE / SDE score estimation
## UNet2DConditionOutput
[[autodoc]] models.unet_2d_condition.UNet2DConditionOutput

## UNet2DConditionModel
[[autodoc]] UNet2DConditionModel

## DecoderOutput
[[autodoc]] models.vae.DecoderOutput

## VQEncoderOutput
[[autodoc]] models.vae.VQEncoderOutput

## VQModel
[[autodoc]] VQModel

## AutoencoderKLOutput
[[autodoc]] models.vae.AutoencoderKLOutput

## AutoencoderKL
[[autodoc]] AutoencoderKL
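
As an illustration of how the model and output classes above fit together, a hedged sketch of a denoising forward pass (the shapes and config values are illustrative, not canonical defaults):

```python
import torch
from diffusers import UNet2DModel

model = UNet2DModel(sample_size=64, in_channels=3, out_channels=3)

noisy_sample = torch.randn(1, 3, 64, 64)
timestep = torch.tensor([10])

# The forward pass returns a UNet2DOutput; its `sample` field holds the prediction.
output = model(noisy_sample, timestep)
print(output.sample.shape)  # torch.Size([1, 3, 64, 64])
```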
78 changes: 13 additions & 65 deletions src/diffusers/modeling_utils.py
@@ -117,27 +117,12 @@ class ModelMixin(torch.nn.Module):
Base class for all models.

[`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
and saving models as well as a few methods common to all models to:
and saving models.

- resize the input embeddings,
- prune heads in the self-attention heads.
Class attributes:

Class attributes (overridden by derived classes):

- **config_class** ([`ConfigMixin`]) -- A subclass of [`ConfigMixin`] to use as configuration class for this
model architecture.
- **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a PyTorch model,
taking as arguments:

- **model** ([`ModelMixin`]) -- An instance of the model on which to load the TensorFlow checkpoint.
- **config** ([`PreTrainedConfigMixin`]) -- An instance of the configuration associated to the model.
- **path** (`str`) -- A path to the TensorFlow checkpoint.

- **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
classes of the same architecture adding modules on top of the base model.
- **is_parallelizable** (`bool`) -- A flag indicating whether this model supports model parallelization.
- **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
models, `pixel_values` for vision models and `input_values` for speech models).
- **config_name** ([`str`]) -- A filename under which the model should be stored when calling
[`~modeling_utils.ModelMixin.save_pretrained`].
"""
config_name = CONFIG_NAME
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
@@ -150,11 +135,10 @@ def save_pretrained(
save_directory: Union[str, os.PathLike],
is_main_process: bool = True,
save_function: Callable = torch.save,
**kwargs,
):
"""
Save a model and its configuration file to a directory, so that it can be re-loaded using the
`[`~ModelMixin.from_pretrained`]` class method.
`[`~modeling_utils.ModelMixin.from_pretrained`]` class method.

Arguments:
save_directory (`str` or `os.PathLike`):
@@ -166,9 +150,6 @@ def save_pretrained(
save_function (`Callable`):
The function to use to save the state dictionary. Useful during distributed training (e.g. on TPUs) when
one needs to replace `torch.save` with another method.

kwargs:
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
@@ -224,34 +205,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
- A path to a *directory* containing model weights saved using [`~ModelMixin.save_pretrained`],
e.g., `./my_model_directory/`.

config (`Union[ConfigMixin, str, os.PathLike]`, *optional*):
Can be either:

- an instance of a class derived from [`ConfigMixin`],
- a string or path valid as input to [`~ConfigMixin.from_pretrained`].

Configuration for the model to use instead of an automatically loaded configuration. Configuration
can be automatically loaded when:

- The model is a model provided by the library (loaded with the *model id* string of a pretrained
model).
- The model was saved using [`~ModelMixin.save_pretrained`] and is reloaded by supplying the save
directory.
- The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
configuration JSON file named *config.json* is found in the directory.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
from_tf (`bool`, *optional*, defaults to `False`):
Load the model weights from a TensorFlow checkpoint save file (see docstring of
`pretrained_model_name_or_path` argument).
from_flax (`bool`, *optional*, defaults to `False`):
Load the model weights from a Flax checkpoint save file (see docstring of
`pretrained_model_name_or_path` argument).
ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
checkpoint with 3 labels).
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
will be automatically derived from the model's weights.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
@@ -267,7 +226,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
Whether or not to only look at local files (i.e., do not try to download the model).
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `transformers-cli login` (stored in `~/.huggingface`).
when running `diffusers-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
@@ -278,18 +237,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
Please refer to the mirror site for more information.

kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
`output_attentions=True`). Behaves differently depending on whether a `config` is provided or
automatically loaded:

- If a configuration is provided with `config`, `**kwargs` will be directly passed to the
underlying model's `__init__` method (we assume all relevant updates to the configuration have
already been done)
- If a configuration is not provided, `kwargs` will be first passed to the configuration class
initialization function ([`~ConfigMixin.from_pretrained`]). Each key of `kwargs` that corresponds
to a configuration attribute will be used to override said attribute with the supplied `kwargs`
value. Remaining keys that do not correspond to any configuration attribute will be passed to the
underlying model's `__init__` function.
Can be used to update the [`ConfigMixin`] of the model (after it has been loaded).

<Tip>

@@ -299,8 +247,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

<Tip>

Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
use this method in a firewalled environment.
Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
this method in a firewalled environment.

</Tip>

@@ -404,7 +352,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
f" directory containing a file named {WEIGHTS_NAME} or"
" \nCheckout your internet connection or see how to run the library in"
" offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
" offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
)
except EnvironmentError:
raise EnvironmentError(
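Taken together, the two methods round-trip cleanly. A minimal sketch (the hub id and local path are illustrative; `torch_dtype` and `local_files_only` are the options documented in the docstring above):

```python
import torch
from diffusers import UNet2DModel

model = UNet2DModel.from_pretrained("google/ddpm-cat-256")
model.save_pretrained("./my_model_directory")

# Reload from the local directory, overriding the dtype and skipping any network access.
reloaded = UNet2DModel.from_pretrained(
    "./my_model_directory",
    torch_dtype=torch.float16,
    local_files_only=True,
)
```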
106 changes: 94 additions & 12 deletions src/diffusers/models/attention.py
@@ -1,4 +1,5 @@
import math
from typing import Optional

import torch
import torch.nn.functional as F
@@ -10,16 +11,24 @@ class AttentionBlock(nn.Module):
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
Uses three q, k, v linear layers to compute attention
Uses three q, k, v linear layers to compute attention.

Parameters:
channels (:obj:`int`): The number of channels in the input and output.
num_head_channels (:obj:`int`, *optional*):
The number of channels in each head. If None, then `num_heads` = 1.
num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm.
rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
"""

def __init__(
self,
channels,
num_head_channels=None,
num_groups=32,
rescale_output_factor=1.0,
eps=1e-5,
channels: int,
num_head_channels: Optional[int] = None,
num_groups: int = 32,
rescale_output_factor: float = 1.0,
eps: float = 1e-5,
):
super().__init__()
self.channels = channels
@@ -86,10 +95,26 @@ def forward(self, hidden_states):
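A minimal usage sketch of `AttentionBlock` as documented above (shapes are illustrative; assumes the default `num_groups=32` divides the channel count):

```python
import torch
from diffusers.models.attention import AttentionBlock

# 64 channels split into 4 heads of 16 channels each.
attn = AttentionBlock(channels=64, num_head_channels=16)

# Spatial self-attention: every position in the feature map attends to every other.
hidden_states = torch.randn(1, 64, 32, 32)
print(attn(hidden_states).shape)  # torch.Size([1, 64, 32, 32])
```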
class SpatialTransformer(nn.Module):
"""
Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
standard transformer action. Finally, reshape to image
standard transformer action. Finally, reshape to image.

Parameters:
in_channels (:obj:`int`): The number of channels in the input and output.
n_heads (:obj:`int`): The number of heads to use for multi-head attention.
d_head (:obj:`int`): The number of channels in each head.
depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
context_dim (:obj:`int`, *optional*): The number of context dimensions to use.
"""

def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
def __init__(
self,
in_channels: int,
n_heads: int,
d_head: int,
depth: int = 1,
dropout: float = 0.0,
context_dim: Optional[int] = None,
):
super().__init__()
self.n_heads = n_heads
self.d_head = d_head
@@ -127,7 +152,29 @@ def forward(self, x, context=None):
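A sketch of `SpatialTransformer` usage under the parameters documented above (the 77-token context length is just an example, e.g. text-encoder hidden states):

```python
import torch
from diffusers.models.attention import SpatialTransformer

# Project the image to (batch, tokens, inner_dim), run the blocks, reshape back.
transformer = SpatialTransformer(in_channels=32, n_heads=4, d_head=8, context_dim=64)

x = torch.randn(1, 32, 16, 16)
context = torch.randn(1, 77, 64)
print(transformer(x, context=context).shape)  # torch.Size([1, 32, 16, 16])
```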


class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
r"""
A basic Transformer block.

Parameters:
dim (:obj:`int`): The number of channels in the input and output.
n_heads (:obj:`int`): The number of heads to use for multi-head attention.
d_head (:obj:`int`): The number of channels in each head.
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention.
gated_ff (:obj:`bool`, *optional*, defaults to :obj:`True`): Whether to use a gated feed-forward network.
checkpoint (:obj:`bool`, *optional*, defaults to :obj:`True`): Whether to use checkpointing.
"""

def __init__(
self,
dim: int,
n_heads: int,
d_head: int,
dropout: float = 0.0,
context_dim: Optional[int] = None,
gated_ff: bool = True,
checkpoint: bool = True,
):
super().__init__()
self.attn1 = CrossAttention(
query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
@@ -154,7 +201,21 @@ def forward(self, x, context=None):
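And a corresponding sketch for `BasicTransformerBlock`, which operates on already-flattened `(batch, tokens, dim)` tensors (illustrative shapes):

```python
import torch
from diffusers.models.attention import BasicTransformerBlock

# Self-attention, then cross-attention over `context`, then feed-forward, each residual.
block = BasicTransformerBlock(dim=32, n_heads=4, d_head=8, context_dim=64)

x = torch.randn(1, 256, 32)
context = torch.randn(1, 77, 64)
print(block(x, context=context).shape)  # torch.Size([1, 256, 32])
```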


class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
r"""
A cross attention layer.

Parameters:
query_dim (:obj:`int`): The number of channels in the query.
context_dim (:obj:`int`, *optional*):
The number of channels in the context. If not given, defaults to `query_dim`.
heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head.
Comment on lines +211 to +212

Member: We should consolidate these names across modules in the future. They are called `n_heads` and `d_head` in `BasicTransformerBlock`.

Contributor: Leaving for a future PR for now due to time constraints.

dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
"""

def __init__(
self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: float = 0.0
):
super().__init__()
inner_dim = dim_head * heads
context_dim = context_dim if context_dim is not None else query_dim
@@ -228,7 +289,20 @@ def _attention(self, query, key, value, sequence_length, dim):
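A usage sketch for `CrossAttention` (illustrative shapes; when `context_dim` is omitted at construction time the layer falls back to self-attention, since it defaults to `query_dim`):

```python
import torch
from diffusers.models.attention import CrossAttention

attn = CrossAttention(query_dim=32, context_dim=64, heads=4, dim_head=8)

query = torch.randn(1, 256, 32)   # (batch, query tokens, query_dim)
context = torch.randn(1, 77, 64)  # (batch, context tokens, context_dim)
print(attn(query, context=context).shape)  # torch.Size([1, 256, 32])
```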


class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
r"""
A feed-forward layer.

Parameters:
dim (:obj:`int`): The number of channels in the input.
dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation.
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
"""

def __init__(
self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0
):
super().__init__()
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
@@ -242,7 +316,15 @@ def forward(self, x):
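`FeedForward` widens by `mult` and projects back down; a one-liner sketch (illustrative shapes):

```python
import torch
from diffusers.models.attention import FeedForward

# dim -> mult * dim hidden units -> dim_out (defaults back to dim), optionally GEGLU-gated.
ff = FeedForward(dim=32, mult=4, glu=True)
print(ff(torch.randn(1, 256, 32)).shape)  # torch.Size([1, 256, 32])
```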

# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
r"""
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.

Parameters:
dim_in (:obj:`int`): The number of channels in the input.
dim_out (:obj:`int`): The number of channels in the output.
"""

def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)

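For reference, `GEGLU` computes `out = proj_half * gelu(gate_half)`, where a single linear layer produces both halves; a minimal sketch:

```python
import torch
from diffusers.models.attention import GEGLU

# The projection maps dim_in -> 2 * dim_out; the second half gates the first via GELU.
geglu = GEGLU(dim_in=32, dim_out=128)
print(geglu(torch.randn(1, 256, 32)).shape)  # torch.Size([1, 256, 128])
```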
13 changes: 9 additions & 4 deletions src/diffusers/models/embeddings.py
@@ -19,7 +19,12 @@


def get_timestep_embedding(
timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, scale=1, max_period=10000
timesteps: torch.Tensor,
embedding_dim: int,
flip_sin_to_cos: bool = False,
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
@@ -55,7 +60,7 @@


class TimestepEmbedding(nn.Module):
def __init__(self, channel, time_embed_dim, act_fn="silu"):
def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"):
super().__init__()

self.linear_1 = nn.Linear(channel, time_embed_dim)
@@ -75,7 +80,7 @@ def forward(self, sample):


class Timesteps(nn.Module):
def __init__(self, num_channels, flip_sin_to_cos, downscale_freq_shift):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
@@ -94,7 +99,7 @@ def forward(self, timesteps):
class GaussianFourierProjection(nn.Module):
"""Gaussian Fourier embeddings for noise levels."""

def __init__(self, embedding_size=256, scale=1.0):
def __init__(self, embedding_size: int = 256, scale: float = 1.0):
super().__init__()
self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)

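For reference, under the default arguments the embedding above reduces to the classic DDPM sinusoidal formulation. A standalone sketch (not a verbatim copy of the function body):

```python
import math

import torch


def sinusoidal_timestep_embedding(timesteps: torch.Tensor, embedding_dim: int, max_period: int = 10000):
    # Log-spaced frequencies from 1 down to 1/max_period (i.e. downscale_freq_shift=1).
    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(half_dim, dtype=torch.float32) / (half_dim - 1)
    emb = timesteps[:, None].float() * torch.exp(exponent)[None, :]
    # Default layout is [sin | cos]; flip_sin_to_cos=True swaps the two halves.
    return torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)


print(sinusoidal_timestep_embedding(torch.arange(4), 8).shape)  # torch.Size([4, 8])
```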