-
Notifications
You must be signed in to change notification settings - Fork 555
/
__init__.py
70 lines (54 loc) · 2.79 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import fields
from pathlib import Path
from typing import Any, Dict, Union
from xformers.utils import import_all_modules
from .activations import Activation, build_activation # noqa
from .attention import Attention, build_attention # noqa
from .in_proj_container import InProjContainer, InProjParams # noqa
from .multi_head_dispatch import MultiHeadDispatch # noqa
from .multi_head_dispatch import MultiHeadDispatchConfig
from .residual import LayerNormStyle, PostNorm, PreNorm, Residual # noqa
# Automatically import any Python files in this package's directory when the
# package itself is imported. The modules are loaded under the
# "xformers.components" namespace.
# NOTE(review): presumably this is what lets components register themselves
# as an import side effect — confirm against `import_all_modules`.
import_all_modules(str(Path(__file__).parent), "xformers.components")
def build_multi_head_attention(
    multi_head_config: Union[MultiHeadDispatchConfig, Dict[str, Any]],
):
    """Build a multi-head attention module from a config.

    The config may be given either as a ready-made
    ``MultiHeadDispatchConfig`` or as a plain dictionary.

    When a dictionary is given:

    * every ``MultiHeadDispatchConfig`` field missing from it is added with
      the value ``None``;
    * if its ``"attention"`` entry is not already an ``Attention`` instance,
      that entry is treated as a nested config. ``"num_heads"`` and
      ``"dim_model"`` are copied into it from the outer config when absent,
      and ``"dim_features"`` defaults to ``dim_model // num_heads`` (taken
      from the outer config) when absent or ``None``. The nested config is
      then replaced by the result of ``build_attention``;
    * the dictionary is finally turned into a ``MultiHeadDispatchConfig``.

    Note that the dictionary (and its nested attention config) is modified
    in place, so the caller sees the filled-in values afterwards.

    NOTE(review): the nested attention config presumably needs a ``"name"``
    key that ``build_attention`` uses to pick the registered attention
    class — confirm against ``build_attention``.

    Returns:
        Whatever ``MultiHeadDispatch.from_config`` returns for the resulting
        config.
    """
    # A typed config needs no normalisation at all.
    if isinstance(multi_head_config, MultiHeadDispatchConfig):
        return MultiHeadDispatch.from_config(multi_head_config)

    # Any dataclass field the caller left out is explicitly set to None.
    absent = [
        spec.name
        for spec in fields(MultiHeadDispatchConfig)
        if spec.name not in multi_head_config.keys()
    ]
    for field_name in absent:
        multi_head_config[field_name] = None

    attention = multi_head_config["attention"]
    if not isinstance(attention, Attention):
        # The attention is still a config: inherit the shared settings from
        # the outer config where the nested one does not provide them.
        for shared_key in ("num_heads", "dim_model"):
            if shared_key not in attention:
                attention[shared_key] = multi_head_config[shared_key]

        needs_dim_features = (
            "dim_features" not in attention or attention["dim_features"] is None
        )
        if needs_dim_features:
            # Per-head feature size, derived from the outer config's values.
            model_dim = multi_head_config["dim_model"]
            head_count = multi_head_config["num_heads"]
            attention["dim_features"] = model_dim // head_count

        # Swap the nested config for the instantiated attention.
        multi_head_config["attention"] = build_attention(attention)

    typed_config = MultiHeadDispatchConfig(**multi_head_config)
    return MultiHeadDispatch.from_config(typed_config)