-
Notifications
You must be signed in to change notification settings - Fork 400
/
alibi.py
187 lines (147 loc) · 8.15 KB
/
alibi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
"""Core ALiBi classes and functions."""
from __future__ import annotations
import logging
from typing import Optional, Sequence, Union
import torch
from torch.optim import Optimizer
from composer.core import Algorithm, Event, State
from composer.loggers import Logger
from composer.utils import MissingConditionalImportError, module_surgery
log = logging.getLogger(__name__)
__all__ = ['Alibi', 'apply_alibi']
def apply_alibi(
    model: torch.nn.Module,
    max_sequence_length: int,
    optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None,
) -> None:
    """Strip position embeddings from ``model`` and swap in the ALiBi attention function
    and attention mask, as described in :class:`.Alibi`.

    Most of the training speed-up attributed to ALiBi comes from training on shorter
    sequences. Unlike :class:`.Alibi`, this function does not rescale the training
    sequence length, so little speedup should be expected from using it on its own.
    See the :doc:`Method Card </method_cards/alibi>` for more details. Call this
    function after the model is instantiated and before training begins.

    Example:
        .. code-block:: python

            import composer.functional as cf

            cf.apply_alibi(
                model=model,
                max_sequence_length=512
            )

    Args:
        model (torch.nn.Module): Model to transform.
        max_sequence_length (int): Maximum sequence length that the model will be able
            to accept. Internally, the transformations applied by alibi change
            sequence-shaped tensors to handle sequences up to ``max_sequence_length``.
            Depending on ``max_sequence_length`` and ``model`` these changes could
            increase or decrease the model's maximum sequence length.

            At minimum, ``max_sequence_length`` should be set to the sequence length
            used during training. However, if evaluating on sequence lengths longer
            than those used in training, ``max_sequence_length`` should be set
            accordingly.

            Note that a larger ``max_sequence_length`` means a larger memory footprint
            of the model, so it is best to set this parameter equal to the longest
            sequence length that will be seen during training and/or evaluation.
        optimizers (torch.optim.Optimizer | Sequence[torch.optim.Optimizer], optional):
            Existing optimizers bound to ``model.parameters()``. All optimizers that
            have already been constructed with ``model.parameters()`` must be specified
            here so they will optimize the correct parameters.

            If the optimizer(s) are constructed *after* calling this function, then it
            is safe to omit this parameter. These optimizers will see the correct
            model parameters.

    Raises:
        MissingConditionalImportError: If the attention surgery functions cannot be
            imported.
    """
    try:
        from composer.algorithms.alibi.attention_surgery_functions import policy_registry
    except ImportError as e:
        raise MissingConditionalImportError(extra_deps_group='nlp', conda_package='transformers') from e

    # The module surgery utilities expect a policy of type
    # Mapping[Type[torch.nn.Module], ReplacementFunction], where ReplacementFunction is
    # Callable[[torch.nn.Module, Optional[int]], Optional[torch.nn.Module]].
    #
    # The code in `./attention_surgery_functions/` builds that mapping, but each entry
    # still takes an alibi-specific argument that has to be "frozen" here.
    #
    # For additional details, see `./attention_surgery_functions/utils.py`.
    def bind_max_sequence_length(surgery_fn):
        """Adapt ``surgery_fn`` to the two-argument ReplacementFunction signature."""

        def bound_surgery_fn(module: torch.nn.Module, module_index: int):
            return surgery_fn(module, module_index, max_sequence_length=max_sequence_length)

        return bound_surgery_fn

    # One ReplacementFunction per registered module class, each with
    # `max_sequence_length` frozen in.
    policies = {}
    for registered_class, surgery_fn in policy_registry.items():
        policies[registered_class] = bind_max_sequence_length(surgery_fn)

    # Note: `policies` covers _every_ module class in `policy_registry`, so some entries
    # may be irrelevant for `model`. Conversely, attention modules inside `model` are
    # ignored when nothing in `./attention_surgery_functions/` registers them.
    replaced_pairs = module_surgery.replace_module_classes(model, optimizers=optimizers, policies=policies)
    num_replaced = len(replaced_pairs)

    if num_replaced > 0:
        log.info(f' {num_replaced} instances of ALiBi added')
        return

    # Nothing was replaced: list every supported class so the user can see why.
    qualified_names = sorted('\n\t' + c.__module__ + '.' + c.__name__ for c in policy_registry.keys())
    supported_modules = ''.join(qualified_names)
    log.warning(f'ALiBi had no effect on the model! Support for ALiBi surgery '
                f'is currently limited to the following classes: {supported_modules}')
class Alibi(Algorithm):
    """ALiBi (Attention with Linear Biases; `Press et al, 2021 <https://arxiv.org/abs/2108.12409>`_) dispenses with
    position embeddings and instead directly biases attention matrices such that nearby tokens attend to one another
    more strongly.

    ALiBi yields excellent extrapolation to unseen sequence lengths
    compared to other position embedding schemes. We leverage this
    extrapolation capability by training with shorter sequence lengths,
    which reduces the memory and computation load.

    This algorithm runs on :attr:`.Event.INIT` to modify the model
    before the model has been moved to accelerators. It also runs on
    :attr:`.Event.AFTER_DATALOADER` to modify the shape of a batch of
    data after the model and data have been moved to accelerators.

    See the :doc:`Method Card </method_cards/alibi>` for more details.

    Example:
        .. code-block::

            from composer.algorithms import Alibi
            from composer.trainer import Trainer

            alibi = Alibi(
                max_sequence_length=512,
                train_sequence_length_scaling=0.25,
            )

            trainer = Trainer(
                model=model,
                train_dataloader=train_dataloader,
                max_duration="1ep",
                algorithms=[alibi]
            )

    Args:
        max_sequence_length (int): Maximum sequence length that the
            model will be able to accept. This is sometimes necessary for evaluating
            on sequence lengths longer than the model was initialized to
            accommodate.
        train_sequence_length_scaling (float, optional): Amount by which to scale
            training sequence length. One batch of training data will be
            reshaped from shape :math:`(batch, sequence\\_length)` to
            :math:`(\\frac{batch}{train\\_sequence\\_length\\_scaling},
            sequence\\_length \\times train\\_sequence\\_length\\_scaling)`. Default: ``0.25``.
    """

    def __init__(self, max_sequence_length: int, train_sequence_length_scaling: float = 0.25) -> None:
        self.max_sequence_length = max_sequence_length
        self.train_sequence_length_scaling = train_sequence_length_scaling
        # Set once model surgery has run so that INIT stops matching afterwards.
        self._applied = False

    @staticmethod
    def _scaled_dim(value: float) -> int:
        """Convert a scaled dimension to ``int``, absorbing floating-point round-off.

        Plain ``int()`` truncates, so a product such as ``100 * 0.57`` (which evaluates
        to ``56.99999999999999``) would become ``56`` instead of ``57``. Values within
        a small tolerance of an integer are snapped to that integer; any other value
        is truncated, which leaves the result unchanged for inputs that are not close
        to a whole number.
        """
        nearest = round(value)
        if abs(value - nearest) < 1e-6:
            return nearest
        return int(value)

    def match(self, event: Event, state: State) -> bool:
        """Match :attr:`.Event.INIT` until surgery has been applied, and every
        :attr:`.Event.AFTER_DATALOADER`."""
        return (event == Event.INIT and not self._applied) or event == Event.AFTER_DATALOADER

    def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
        """Apply model surgery on ``INIT``; reshape the batch on ``AFTER_DATALOADER``.

        Args:
            event (Event): The event being handled.
            state (State): Trainer state; ``state.model`` and ``state.optimizers`` are
                read on ``INIT``, and ``state.batch`` is rewritten in place on
                ``AFTER_DATALOADER``.
            logger (Logger): Unused.

        Returns:
            None: This method has no ``return`` statement.
        """
        if event == Event.INIT:
            apply_alibi(
                state.model,
                optimizers=state.optimizers,
                max_sequence_length=self.max_sequence_length,
            )
            self._applied = True

        elif event == Event.AFTER_DATALOADER:
            sequence_scaling = self.train_sequence_length_scaling
            # Change sequence length by reshaping data. Only dict batches are handled;
            # any other batch type is left untouched.
            if sequence_scaling != 1 and hasattr(state, 'batch') and isinstance(state.batch, dict):
                for k, v in state.batch.items():
                    # Dimension 0 is treated as batch and dimension 1 as sequence.
                    batch_len, sequence_len = v.shape[0], v.shape[1]
                    # Reassigning existing keys does not change the dict's size, so
                    # it is safe while iterating over ``items()``.
                    state.batch[k] = v.reshape(
                        self._scaled_dim(batch_len / sequence_scaling),
                        self._scaled_dim(sequence_len * sequence_scaling),
                    )