diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py
index 04c978403f41..46da899096c2 100644
--- a/src/diffusers/models/activations.py
+++ b/src/diffusers/models/activations.py
@@ -1,7 +1,15 @@
 from torch import nn


-def get_activation(act_fn):
+def get_activation(act_fn: str) -> nn.Module:
+    """Helper function to get activation function from string.
+
+    Args:
+        act_fn (str): Name of activation function.
+
+    Returns:
+        nn.Module: Activation function.
+    """
     if act_fn in ["swish", "silu"]:
         return nn.SiLU()
     elif act_fn == "mish":
diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index ac66e2271c61..3972b438b076 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -14,7 +14,7 @@
 # limitations under the License.

 from functools import partial
-from typing import Optional
+from typing import Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -38,9 +38,18 @@ class Upsample1D(nn.Module):
             option to use a convolution transpose.
         out_channels (`int`, optional):
             number of output channels. Defaults to `channels`.
+        name (`str`, default `conv`):
+            name of the upsampling 1D layer.
     """

-    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
+    def __init__(
+        self,
+        channels: int,
+        use_conv: bool = False,
+        use_conv_transpose: bool = False,
+        out_channels: Optional[int] = None,
+        name: str = "conv",
+    ):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
@@ -54,7 +63,7 @@ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_chann
         elif use_conv:
             self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)

-    def forward(self, inputs):
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         assert inputs.shape[1] == self.channels
         if self.use_conv_transpose:
             return self.conv(inputs)
@@ -79,9 +88,18 @@ class Downsample1D(nn.Module):
             number of output channels. Defaults to `channels`.
         padding (`int`, default `1`):
             padding for the convolution.
+        name (`str`, default `conv`):
+            name of the downsampling 1D layer.
     """

-    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
+    def __init__(
+        self,
+        channels: int,
+        use_conv: bool = False,
+        out_channels: Optional[int] = None,
+        padding: int = 1,
+        name: str = "conv",
+    ):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
@@ -96,7 +114,7 @@ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name=
             assert self.channels == self.out_channels
             self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)

-    def forward(self, inputs):
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         assert inputs.shape[1] == self.channels
         return self.conv(inputs)

@@ -113,9 +131,18 @@ class Upsample2D(nn.Module):
             option to use a convolution transpose.
         out_channels (`int`, optional):
             number of output channels. Defaults to `channels`.
+        name (`str`, default `conv`):
+            name of the upsampling 2D layer.
     """

-    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
+    def __init__(
+        self,
+        channels: int,
+        use_conv: bool = False,
+        use_conv_transpose: bool = False,
+        out_channels: Optional[int] = None,
+        name: str = "conv",
+    ):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
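As a quick smoke test of the newly annotated helpers, something like the following should hold (a minimal sketch, assuming a diffusers checkout at the revision this patch targets; tensor sizes are illustrative):

import torch
from diffusers.models.activations import get_activation
from diffusers.models.resnet import Upsample1D

act = get_activation("silu")  # "swish" and "silu" both map to nn.SiLU()
x = torch.randn(2, 32, 64)    # [batch, channels, length]
up = Upsample1D(channels=32, use_conv=True)
print(up(act(x)).shape)       # torch.Size([2, 32, 128]): length doubled by nearest-neighbor interpolation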
""" - def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): + def __init__( + self, + channels: int, + use_conv: bool = False, + use_conv_transpose: bool = False, + out_channels: Optional[int] = None, + name: str = "conv", + ): super().__init__() self.channels = channels self.out_channels = out_channels or channels @@ -135,7 +162,7 @@ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_chann else: self.Conv2d_0 = conv - def forward(self, hidden_states, output_size=None, scale: float = 1.0): + def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None, scale: float = 1.0): assert hidden_states.shape[1] == self.channels if self.use_conv_transpose: @@ -191,9 +218,18 @@ class Downsample2D(nn.Module): number of output channels. Defaults to `channels`. padding (`int`, default `1`): padding for the convolution. + name (`str`, default `conv`): + name of the downsampling 2D layer. """ - def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): + def __init__( + self, + channels: int, + use_conv: bool = False, + out_channels: Optional[int] = None, + padding: int = 1, + name: str = "conv", + ): super().__init__() self.channels = channels self.out_channels = out_channels or channels @@ -246,7 +282,13 @@ class FirUpsample2D(nn.Module): kernel for the FIR filter. """ - def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)): + def __init__( + self, + channels: int = None, + out_channels: Optional[int] = None, + use_conv: bool = False, + fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1), + ): super().__init__() out_channels = out_channels if out_channels else channels if use_conv: @@ -255,7 +297,14 @@ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel= self.fir_kernel = fir_kernel self.out_channels = out_channels - def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1): + def _upsample_2d( + self, + hidden_states: torch.Tensor, + weight: Optional[torch.Tensor] = None, + kernel: Optional[torch.FloatTensor] = None, + factor: int = 2, + gain: float = 1, + ) -> torch.Tensor: """Fused `upsample_2d()` followed by `Conv2d()`. Padding is performed only once at the beginning, not between the operations. The fused op is considerably more @@ -335,7 +384,7 @@ def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1 return output - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.use_conv: height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel) height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1) @@ -359,7 +408,13 @@ class FirDownsample2D(nn.Module): kernel for the FIR filter. 
""" - def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)): + def __init__( + self, + channels: int = None, + out_channels: Optional[int] = None, + use_conv: bool = False, + fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1), + ): super().__init__() out_channels = out_channels if out_channels else channels if use_conv: @@ -368,7 +423,14 @@ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel= self.use_conv = use_conv self.out_channels = out_channels - def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1): + def _downsample_2d( + self, + hidden_states: torch.Tensor, + weight: Optional[torch.Tensor] = None, + kernel: Optional[torch.FloatTensor] = None, + factor: int = 2, + gain: float = 1, + ) -> torch.Tensor: """Fused `Conv2d()` followed by `downsample_2d()`. Padding is performed only once at the beginning, not between the operations. The fused op is considerably more efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of @@ -422,7 +484,7 @@ def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain return output - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.use_conv: downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel) hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1) @@ -434,14 +496,20 @@ def forward(self, hidden_states): # downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/DirUpsample2D instead class KDownsample2D(nn.Module): - def __init__(self, pad_mode="reflect"): + r"""A 2D K-downsampling layer. + + Parameters: + pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use. + """ + + def __init__(self, pad_mode: str = "reflect"): super().__init__() self.pad_mode = pad_mode kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) self.pad = kernel_1d.shape[1] // 2 - 1 self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False) - def forward(self, inputs): + def forward(self, inputs: torch.Tensor) -> torch.Tensor: inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode) weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) indices = torch.arange(inputs.shape[1], device=inputs.device) @@ -451,14 +519,20 @@ def forward(self, inputs): class KUpsample2D(nn.Module): - def __init__(self, pad_mode="reflect"): + r"""A 2D K-upsampling layer. + + Parameters: + pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use. 
+ """ + + def __init__(self, pad_mode: str = "reflect"): super().__init__() self.pad_mode = pad_mode kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2 self.pad = kernel_1d.shape[1] // 2 - 1 self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False) - def forward(self, inputs): + def forward(self, inputs: torch.Tensor) -> torch.Tensor: inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode) weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) indices = torch.arange(inputs.shape[1], device=inputs.device) @@ -501,23 +575,23 @@ class ResnetBlock2D(nn.Module): def __init__( self, *, - in_channels, - out_channels=None, - conv_shortcut=False, - dropout=0.0, - temb_channels=512, - groups=32, - groups_out=None, - pre_norm=True, - eps=1e-6, - non_linearity="swish", - skip_time_act=False, - time_embedding_norm="default", # default, scale_shift, ada_group, spatial - kernel=None, - output_scale_factor=1.0, - use_in_shortcut=None, - up=False, - down=False, + in_channels: int, + out_channels: Optional[int] = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + temb_channels: int = 512, + groups: int = 32, + groups_out: Optional[int] = None, + pre_norm: bool = True, + eps: float = 1e-6, + non_linearity: str = "swish", + skip_time_act: bool = False, + time_embedding_norm: str = "default", # default, scale_shift, ada_group, spatial + kernel: Optional[torch.FloatTensor] = None, + output_scale_factor: float = 1.0, + use_in_shortcut: Optional[bool] = None, + up: bool = False, + down: bool = False, conv_shortcut_bias: bool = True, conv_2d_out_channels: Optional[int] = None, ): @@ -667,7 +741,7 @@ def forward(self, input_tensor, temb, scale: float = 1.0): # unet_rl.py -def rearrange_dims(tensor): +def rearrange_dims(tensor: torch.Tensor) -> torch.Tensor: if len(tensor.shape) == 2: return tensor[:, :, None] if len(tensor.shape) == 3: @@ -681,16 +755,24 @@ def rearrange_dims(tensor): class Conv1dBlock(nn.Module): """ Conv1d --> GroupNorm --> Mish + + Parameters: + inp_channels (`int`): Number of input channels. + out_channels (`int`): Number of output channels. + kernel_size (`int` or `tuple`): Size of the convolving kernel. + n_groups (`int`, default `8`): Number of groups to separate the channels into. """ - def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): + def __init__( + self, inp_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int]], n_groups: int = 8 + ): super().__init__() self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2) self.group_norm = nn.GroupNorm(n_groups, out_channels) self.mish = nn.Mish() - def forward(self, inputs): + def forward(self, inputs: torch.Tensor) -> torch.Tensor: intermediate_repr = self.conv1d(inputs) intermediate_repr = rearrange_dims(intermediate_repr) intermediate_repr = self.group_norm(intermediate_repr) @@ -701,7 +783,19 @@ def forward(self, inputs): # unet_rl.py class ResidualTemporalBlock1D(nn.Module): - def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5): + """ + Residual 1D block with temporal convolutions. + + Parameters: + inp_channels (`int`): Number of input channels. + out_channels (`int`): Number of output channels. + embed_dim (`int`): Embedding dimension. + kernel_size (`int` or `tuple`): Size of the convolving kernel. 
+ """ + + def __init__( + self, inp_channels: int, out_channels: int, embed_dim: int, kernel_size: Union[int, Tuple[int, int]] = 5 + ): super().__init__() self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size) self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size) @@ -713,7 +807,7 @@ def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5): nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity() ) - def forward(self, inputs, t): + def forward(self, inputs: torch.Tensor, t: torch.Tensor) -> torch.Tensor: """ Args: inputs : [ batch_size x inp_channels x horizon ] @@ -729,7 +823,9 @@ def forward(self, inputs, t): return out + self.residual_conv(inputs) -def upsample_2d(hidden_states, kernel=None, factor=2, gain=1): +def upsample_2d( + hidden_states: torch.Tensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1 +) -> torch.Tensor: r"""Upsample2D a batch of 2D images with the given filter. Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified @@ -766,7 +862,9 @@ def upsample_2d(hidden_states, kernel=None, factor=2, gain=1): return output -def downsample_2d(hidden_states, kernel=None, factor=2, gain=1): +def downsample_2d( + hidden_states: torch.Tensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1 +) -> torch.Tensor: r"""Downsample2D a batch of 2D images with the given filter. Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the @@ -801,7 +899,9 @@ def downsample_2d(hidden_states, kernel=None, factor=2, gain=1): return output -def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)): +def upfirdn2d_native( + tensor: torch.Tensor, kernel: torch.Tensor, up: int = 1, down: int = 1, pad: Tuple[int, int] = (0, 0) +) -> torch.Tensor: up_x = up_y = up down_x = down_y = down pad_x0 = pad_y0 = pad[0] @@ -849,9 +949,14 @@ class TemporalConvLayer(nn.Module): """ Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from: https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016 + + Parameters: + in_dim (`int`): Number of input channels. + out_dim (`int`): Number of output channels. + dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. """ - def __init__(self, in_dim, out_dim=None, dropout=0.0): + def __init__(self, in_dim: int, out_dim: Optional[int] = None, dropout: float = 0.0): super().__init__() out_dim = out_dim or in_dim self.in_dim = in_dim @@ -884,7 +989,7 @@ def __init__(self, in_dim, out_dim=None, dropout=0.0): nn.init.zeros_(self.conv4[-1].weight) nn.init.zeros_(self.conv4[-1].bias) - def forward(self, hidden_states, num_frames=1): + def forward(self, hidden_states: torch.Tensor, num_frames: int = 1) -> torch.Tensor: hidden_states = ( hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4) )