attention refactor: the trilogy (huggingface#3387)
* Replace `AttentionBlock` with `Attention`

* use _from_deprecated_attn_block check re: @patrickvonplaten
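
In practice, call sites that previously instantiated `AttentionBlock` are rewired to the `Attention` class from `attention_processor.py`, configured to reproduce the old block's group norm, biased projections, residual add, and output rescale. Below is a minimal sketch of such a replacement, assuming the keyword arguments this refactor adds to `Attention` (`residual_connection`, `rescale_output_factor`, `_from_deprecated_attn_block`); the exact call sites are in the changed files below, not in this sketch.

```python
import torch
from diffusers.models.attention_processor import Attention

# Hypothetical values standing in for a UNet/VAE mid-block call site.
channels = 512
attention_head_dim = 64

# Configure `Attention` to behave like the old `AttentionBlock`:
# group norm on the input, biased projections, a residual connection,
# and no output rescaling (factor 1.0). `_from_deprecated_attn_block=True`
# flags the module so checkpoints saved with the old parameter names can
# still be loaded.
attn = Attention(
    channels,
    heads=channels // attention_head_dim,
    dim_head=attention_head_dim,
    norm_num_groups=32,
    eps=1e-5,
    rescale_output_factor=1.0,
    residual_connection=True,
    bias=True,
    upcast_softmax=True,
    _from_deprecated_attn_block=True,
)

# The attention processor flattens 4-D feature maps to (batch, h*w, channels)
# internally, mirroring what `AttentionBlock.forward` used to do.
sample = torch.randn(1, channels, 16, 16)
out = attn(sample)  # same (1, channels, 16, 16) shape out
```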
williamberman authored and dg845 committed May 21, 2023
1 parent 80c2e55 · commit d749d57
Showing 6 changed files with 235 additions and 248 deletions.
174 changes: 1 addition & 173 deletions src/diffusers/models/attention.py
@@ -11,189 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Callable, Optional
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from ..utils import maybe_allow_in_graph
from ..utils.import_utils import is_xformers_available
from .attention_processor import Attention
from .embeddings import CombinedTimestepLabelEmbeddings


if is_xformers_available():
import xformers
import xformers.ops
else:
xformers = None


class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
Uses three q, k, v linear layers to compute attention.
Parameters:
channels (`int`): The number of channels in the input and output.
num_head_channels (`int`, *optional*):
The number of channels in each head. If None, then `num_heads` = 1.
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
"""

# IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore

def __init__(
self,
channels: int,
num_head_channels: Optional[int] = None,
norm_num_groups: int = 32,
rescale_output_factor: float = 1.0,
eps: float = 1e-5,
):
super().__init__()
self.channels = channels

self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)

# define q,k,v as linear layers
self.query = nn.Linear(channels, channels)
self.key = nn.Linear(channels, channels)
self.value = nn.Linear(channels, channels)

self.rescale_output_factor = rescale_output_factor
self.proj_attn = nn.Linear(channels, channels, bias=True)

self._use_memory_efficient_attention_xformers = False
self._use_2_0_attn = True
self._attention_op = None

def reshape_heads_to_batch_dim(self, tensor, merge_head_and_batch=True):
batch_size, seq_len, dim = tensor.shape
head_size = self.num_heads
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3)
if merge_head_and_batch:
tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
return tensor

def reshape_batch_dim_to_heads(self, tensor, unmerge_head_and_batch=True):
head_size = self.num_heads

if unmerge_head_and_batch:
batch_head_size, seq_len, dim = tensor.shape
batch_size = batch_head_size // head_size

tensor = tensor.reshape(batch_size, head_size, seq_len, dim)
else:
batch_size, _, seq_len, dim = tensor.shape

tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size, seq_len, dim * head_size)
return tensor

def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
):
if use_memory_efficient_attention_xformers:
if not is_xformers_available():
raise ModuleNotFoundError(
(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
" xformers"
),
name="xformers",
)
elif not torch.cuda.is_available():
raise ValueError(
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
" only available for GPU "
)
else:
try:
# Make sure we can run the memory efficient attention
_ = xformers.ops.memory_efficient_attention(
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
)
except Exception as e:
raise e
self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
self._attention_op = attention_op

def forward(self, hidden_states):
residual = hidden_states
batch, channel, height, width = hidden_states.shape

# norm
hidden_states = self.group_norm(hidden_states)

hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)

# proj to q, k, v
query_proj = self.query(hidden_states)
key_proj = self.key(hidden_states)
value_proj = self.value(hidden_states)

scale = 1 / math.sqrt(self.channels / self.num_heads)

_use_2_0_attn = self._use_2_0_attn and not self._use_memory_efficient_attention_xformers
use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention") and _use_2_0_attn

query_proj = self.reshape_heads_to_batch_dim(query_proj, merge_head_and_batch=not use_torch_2_0_attn)
key_proj = self.reshape_heads_to_batch_dim(key_proj, merge_head_and_batch=not use_torch_2_0_attn)
value_proj = self.reshape_heads_to_batch_dim(value_proj, merge_head_and_batch=not use_torch_2_0_attn)

if self._use_memory_efficient_attention_xformers:
# Memory efficient attention
hidden_states = xformers.ops.memory_efficient_attention(
query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op, scale=scale
)
hidden_states = hidden_states.to(query_proj.dtype)
elif use_torch_2_0_attn:
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query_proj, key_proj, value_proj, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.to(query_proj.dtype)
else:
attention_scores = torch.baddbmm(
torch.empty(
query_proj.shape[0],
query_proj.shape[1],
key_proj.shape[1],
dtype=query_proj.dtype,
device=query_proj.device,
),
query_proj,
key_proj.transpose(-1, -2),
beta=0,
alpha=scale,
)
attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
hidden_states = torch.bmm(attention_probs, value_proj)

# reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states, unmerge_head_and_batch=not use_torch_2_0_attn)

# compute next hidden_states
hidden_states = self.proj_attn(hidden_states)

hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)

# res connect and rescale
hidden_states = (hidden_states + residual) / self.rescale_output_factor
return hidden_states


@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
r"""
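
For reference, the deleted `AttentionBlock.forward` above is plain multi-head self-attention over flattened spatial positions, with scale 1/sqrt(channels/num_heads), an output projection, and a residual add. A self-contained sketch of that computation, using only the PyTorch 2.x `scaled_dot_product_attention` path and hypothetical shapes (the stand-in layers correspond to the block's `group_norm`, `query`, `key`, `value`, and `proj_attn`):

```python
import torch
import torch.nn.functional as F
from torch import nn

# Hypothetical feature-map shapes.
batch, channels, height, width = 2, 512, 16, 16
num_heads = 8
head_dim = channels // num_heads
rescale_output_factor = 1.0

x = torch.randn(batch, channels, height, width)
residual = x

# Stand-ins for the block's layers.
group_norm = nn.GroupNorm(num_channels=channels, num_groups=32, eps=1e-5)
query, key, value = (nn.Linear(channels, channels) for _ in range(3))
proj_attn = nn.Linear(channels, channels)

# Flatten spatial positions into a sequence: (batch, h*w, channels).
seq = group_norm(x).view(batch, channels, height * width).transpose(1, 2)

def split_heads(t):
    # (batch, seq_len, channels) -> (batch, heads, seq_len, head_dim)
    b, s, _ = t.shape
    return t.view(b, s, num_heads, head_dim).permute(0, 2, 1, 3)

q, k, v = split_heads(query(seq)), split_heads(key(seq)), split_heads(value(seq))

# SDPA's default 1/sqrt(head_dim) scale equals the block's
# 1/sqrt(channels/num_heads), so no explicit scale is needed here.
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)

# Merge heads, project, restore the spatial layout, then residual + rescale.
out = out.permute(0, 2, 1, 3).reshape(batch, height * width, channels)
out = proj_attn(out).transpose(-1, -2).reshape(batch, channels, height, width)
out = (out + residual) / rescale_output_factor
```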
