# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
"""Core ALiBi classes and functions."""
from __future__ import annotations
import importlib
import logging
import math
from operator import attrgetter
from types import MethodType, ModuleType
from typing import Callable, Optional, Sequence, Tuple, Type, Union, cast
import torch
from torch.optim import Optimizer
from composer.core import Algorithm, Event, State
from composer.loggers import Logger
from composer.utils import module_surgery
log = logging.getLogger(__name__)
__all__ = ['Alibi', 'apply_alibi']
def apply_alibi(
model: torch.nn.Module,
heads_per_layer: int,
max_sequence_length: int,
position_embedding_attribute: str,
attention_module: Type[torch.nn.Module],
attr_to_replace: str,
alibi_attention: Callable,
mask_replacement_function: Optional[Callable[[torch.nn.Module, int], torch.nn.Module]] = None,
optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None,
) -> None:
"""Removes position embeddings and replaces the attention function and attention mask
as per :class:`.Alibi`. Note that the majority of the training speed-up from using ALiBi
comes from being able to train on shorter sequence lengths; this function does not scale
the training sequence length as :class:`.Alibi` does, so little speedup will be
observed from using it alone. See the :doc:`Method Card </method_cards/alibi>` for
more details. This function should be called after the model is instantiated and
before training begins.
Example:
.. code-block:: python
import composer.functional as cf
        from composer.algorithms.alibi._gpt2_alibi import _attn
        from composer.algorithms.alibi._gpt2_alibi import enlarge_mask
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
cf.apply_alibi(
model=model,
heads_per_layer=12,
max_sequence_length=8192,
position_embedding_attribute="module.transformer.wpe",
attention_module=GPT2Attention,
attr_to_replace="_attn",
alibi_attention=_attn,
mask_replacement_function=enlarge_mask
)
Args:
model (torch.nn.Module): Model to transform.
heads_per_layer (int): Number of attention heads per layer.
max_sequence_length (int): See :class:`.Alibi`.
position_embedding_attribute (str): See :class:`.Alibi`.
        attention_module (Type[torch.nn.Module]): Module/class that will have its
self-attention function replaced. For example, in
HuggingFace's GPT, the self-attention module is
``transformers.models.gpt2.modeling_gpt2.GPT2Attention``.
attr_to_replace (str): See :class:`.Alibi`.
        alibi_attention (Callable): New self-attention function in which
            ALiBi is implemented. Used to replace
            ``{attention_module}.{attr_to_replace}``. Example:
            ``composer.algorithms.alibi._gpt2_alibi._attn``.
        mask_replacement_function (Callable[[torch.nn.Module, int], torch.nn.Module], optional):
Function to replace model's attention mask. This can be
necessary for evaluating on sequence lengths longer than the model was
initialized to accommodate. Takes positional arguments ``module`` and
``max_sequence_length``. For example,
``composer.algorithms.alibi._gpt2_alibi.enlarge_mask``. Default: ``None``,
which means no modification of the model's default attention mask.
optimizers (torch.optim.Optimizer | Sequence[torch.optim.Optimizer], optional):
Existing optimizers bound to ``model.parameters()``. All optimizers that have already been
constructed with ``model.parameters()`` must be specified here so
they will optimize the correct parameters.
If the optimizer(s) are constructed *after* calling this function,
then it is safe to omit this parameter. These optimizers will see the correct
model parameters.
Returns:
None
"""
old_embed, new_embed = _zero_and_freeze_expand_position_embeddings(
model=model,
attribute=position_embedding_attribute,
new_embedding_length=max_sequence_length,
)
if optimizers and old_embed is not None and new_embed is not None:
module_surgery.update_params_in_optimizer([old_embed], [new_embed], optimizers=optimizers)
log.info(f' Position embedding expanded to sequence length {max_sequence_length}, zeroed, and frozen')
def convert_attention(module: torch.nn.Module, module_index: Optional[int] = None):
del module_index # unused
module = _register_alibi(module=module, n_heads=heads_per_layer, max_token_length=max_sequence_length)
setattr(module, attr_to_replace, MethodType(alibi_attention, module))
if mask_replacement_function:
module = mask_replacement_function(module, max_sequence_length)
return module
replaced_pairs = module_surgery.replace_module_classes(model,
optimizers=optimizers,
policies={attention_module: convert_attention})
count = len(replaced_pairs)
log.info(f' {count} instances of ALiBi added')
class Alibi(Algorithm):
"""ALiBi (Attention with Linear Biases; `Press et al, 2021 <https://arxiv.org/abs/2108.12409>`_) dispenses with
position embeddings and instead directly biases attention matrices such that nearby tokens attend to one another
more strongly.
ALiBi yields excellent extrapolation to unseen sequence lengths
compared to other position embedding schemes. We leverage this
extrapolation capability by training with shorter sequence lengths,
which reduces the memory and computation load.
This algorithm runs on :attr:`~composer.core.event.Event.INIT` to modify the model
before the model has been moved to accelerators. It also runs on
:attr:`~composer.core.event.Event.AFTER_DATALOADER` to modify the shape of a batch of
data after the model and data have been moved to accelerators.
See the :doc:`Method Card </method_cards/alibi>` for more details.
Example:
    .. code-block:: python

        from composer.algorithms import Alibi
        from composer.trainer import Trainer

        alibi = Alibi(
            position_embedding_attribute="module.transformer.wpe",
            attention_module_name="transformers.models.gpt2.modeling_gpt2.GPT2Attention",
            attr_to_replace="_attn",
            alibi_attention="composer.algorithms.alibi._gpt2_alibi._attn",
            mask_replacement_function="composer.algorithms.alibi._gpt2_alibi.enlarge_mask",
            max_sequence_length=8192,
        )
trainer = Trainer(
model=model,
train_dataloader=train_dataloader,
max_duration="1ep",
algorithms=[alibi]
)
Args:
position_embedding_attribute (str): Attribute for position
            embeddings. For example, in HuggingFace's GPT2, the
position embeddings are ``'transformer.wpe'``.
attention_module_name (str): Module/class that will have its
self-attention function replaced. For example, in
HuggingFace's GPT, the self-attention module is
``'transformers.models.gpt2.modeling_gpt2.GPT2Attention'``.
        attr_to_replace (str): Attribute that will be replaced with the new
            self-attention function. For example, in HuggingFace's GPT2, the
self-attention function is ``'_attn'``.
alibi_attention (str): Path to new self-attention function in which
ALiBi is implemented. Used to replace
``{attention_module}.{attr_to_replace}``. Example:
``'composer.algorithms.alibi._gpt2_alibi._attn'``.
mask_replacement_function (str, optional): Path to function to replace model's
attention mask. This can be necessary if evaluating
on sequence lengths longer than the model was initialized to
accommodate. Takes positional arguments ``module`` and
``max_sequence_length``. For example,
            ``'composer.algorithms.alibi._gpt2_alibi.enlarge_mask'``. Default: ``None``,
which means no modification of the model's default attention mask.
        heads_per_layer (int, optional): Number of attention heads per layer. If
            ``None``, the number of heads is inferred from ``model.config.n_head``
            at initialization. Default: ``None``.
        max_sequence_length (int, optional): Maximum sequence length that the
            model will be able to accept. This is sometimes necessary for evaluating
            on sequence lengths longer than the model was initialized to
            accommodate. Default: ``8192``.
        train_sequence_length_scaling (float, optional): Amount by which to scale
            training sequence length. One batch of training data will be
            reshaped from shape :math:`(batch, sequence\\_length)` to
            :math:`(\\frac{batch}{train\\_sequence\\_length\\_scaling},
            sequence\\_length \\times train\\_sequence\\_length\\_scaling)`. Default: ``0.25``.
"""
def __init__(self,
position_embedding_attribute: str,
attention_module_name: str,
attr_to_replace: str,
alibi_attention: str,
mask_replacement_function: Optional[str] = None,
heads_per_layer: Optional[int] = None,
max_sequence_length: int = 8192,
train_sequence_length_scaling: float = 0.25) -> None:
self.position_embedding_attribute = position_embedding_attribute
self.attention_module_name = attention_module_name
self.attr_to_replace = attr_to_replace
self.alibi_attention = alibi_attention
self.mask_replacement_function = mask_replacement_function
self.heads_per_layer = heads_per_layer
self.max_sequence_length = max_sequence_length
self.train_sequence_length_scaling = train_sequence_length_scaling
self._applied = False
def match(self, event: Event, state: State) -> bool:
return (event == Event.INIT and not self._applied) or event == Event.AFTER_DATALOADER
def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
if event == Event.INIT:
if self.heads_per_layer is None:
try:
self.heads_per_layer = state.model.config.n_head # type: ignore
except AttributeError:
                    log.exception('alibi.heads_per_layer not provided, and unable to '
                                  'determine number of heads from model.config.n_head. '
                                  'Please provide alibi.heads_per_layer.')
apply_alibi(
state.model,
optimizers=state.optimizers,
heads_per_layer=cast(int, self.heads_per_layer),
max_sequence_length=self.max_sequence_length,
position_embedding_attribute=self.position_embedding_attribute,
attr_to_replace=self.attr_to_replace,
                # Access classes/functions from their string paths
                attention_module=_lazy_import(self.attention_module_name),
                alibi_attention=_lazy_import(self.alibi_attention),
                mask_replacement_function=_lazy_import(self.mask_replacement_function))
self._applied = True
elif event == Event.AFTER_DATALOADER:
# Change sequence length by reshaping data
            if self.train_sequence_length_scaling != 1 and \
hasattr(state, 'batch') and isinstance(state.batch, dict):
sequence_scaling = self.train_sequence_length_scaling
for k, v in state.batch.items():
batch_len, sequence_len = v.shape[0], v.shape[1]
state.batch[k] = v.reshape(int(batch_len / sequence_scaling), int(sequence_len * sequence_scaling))
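
# A minimal sketch of the AFTER_DATALOADER reshape above (illustrative only;
# the batch size, sequence length, scaling factor, and key name are
# assumptions, not values used by the algorithm itself):
def _example_sequence_length_reshape() -> None:
    scaling = 0.25
    batch = {'input_ids': torch.arange(8 * 1024).reshape(8, 1024)}  # (batch, sequence_length)
    for k, v in batch.items():
        batch_len, sequence_len = v.shape[0], v.shape[1]
        batch[k] = v.reshape(int(batch_len / scaling), int(sequence_len * scaling))
    # Same tokens, rearranged: 4x as many sequences at 1/4 the length each.
    assert batch['input_ids'].shape == (32, 256)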
def _zero_and_freeze_expand_position_embeddings(
model: torch.nn.Module,
new_embedding_length: int,
attribute: str,
) -> Union[Tuple[torch.nn.Parameter, torch.nn.Parameter], Tuple[None, None]]:
try:
pos_embedding_module = attrgetter(attribute)(model)
old_weight = getattr(pos_embedding_module, 'weight')
if not isinstance(old_weight, torch.nn.Parameter):
            raise TypeError(f"Model {model._get_name()}, position embedding {attribute}, 'weight' attribute "
                            'must be of type torch.nn.Parameter')
new_weight = torch.nn.Parameter(
torch.zeros((new_embedding_length, old_weight.shape[1]),
dtype=old_weight.dtype,
layout=old_weight.layout,
device=old_weight.device))
new_weight.requires_grad = False
setattr(pos_embedding_module, 'weight', new_weight)
return old_weight, new_weight
except AttributeError:
log.error(f'Unable to zero and freeze position embeddings. Model '
f'{model} may lack attribute {attribute}, or position '
f"embeddings may lack attribute 'weight'.")
return None, None
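
# An illustrative sketch of the position-embedding surgery above on a toy
# module (the Sequential wrapper, attribute path, and sizes are assumptions
# chosen for demonstration):
def _example_zero_and_freeze() -> None:
    toy_model = torch.nn.Sequential(torch.nn.Embedding(num_embeddings=128, embedding_dim=16))
    _, new_weight = _zero_and_freeze_expand_position_embeddings(
        model=toy_model,
        new_embedding_length=512,
        attribute='0',  # attrgetter path to the embedding inside the Sequential
    )
    assert new_weight is not None
    assert new_weight.shape == (512, 16)  # expanded to the new maximum length
    assert not new_weight.requires_grad and new_weight.sum() == 0  # frozen and zeroed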
def _register_alibi(module: torch.nn.Module, n_heads: int, max_token_length: int):
# Modified from https://github.com/ofirpress/attention_with_linear_biases/blob/5b327adc6d131e28b40ba58906b30bb469483519/fairseq/models/transformer.py#L742
    slopes = torch.tensor(_get_alibi_head_slopes(n_heads))
# In the next line, the part after the * is what constructs the diagonal matrix
# (right matrix in Figure 3 in the paper).
# If you run it you'll see that it doesn't exactly print out the same matrix as we
# have in Figure 3, but one where all rows are identical.
# This works because the softmax operation is invariant to translation, and our bias
# functions are always linear.
alibi = slopes.unsqueeze(1).unsqueeze(1) * \
torch.arange(max_token_length). \
unsqueeze(0).unsqueeze(0).expand(n_heads, -1, -1)
module.register_buffer('alibi', alibi)
return module
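
# A quick illustrative check of the bias registered above (the head count and
# token length are arbitrary assumptions): each head gets a linear ramp
# 0, slope, 2*slope, ... over key positions, stored with shape
# (n_heads, 1, max_token_length) so it broadcasts across query positions.
def _example_register_alibi() -> None:
    module = _register_alibi(torch.nn.Identity(), n_heads=8, max_token_length=4)
    assert module.alibi.shape == (8, 1, 4)
    # For 8 heads the first slope is 1/2, so head 0's biases are 0, 0.5, 1, 1.5.
    assert torch.allclose(module.alibi[0, 0], 0.5 * torch.arange(4, dtype=torch.float))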
def _get_alibi_head_slopes(n_heads: int):
def get_slopes_power_of_2(n_heads):
start = (2**(-2**-(math.log2(n_heads) - 3)))
ratio = start
return [start * ratio**i for i in range(n_heads)]
# In the paper, they only train models that have 2^a heads for some a. This function
# has some good properties that only occur when the input is a power of 2. To
# maintain that even when the number of heads is not a power of 2, we use a
# workaround.
if math.log2(n_heads).is_integer():
return get_slopes_power_of_2(n_heads)
else:
closest_power_of_2 = 2**math.floor(math.log2(n_heads))
return get_slopes_power_of_2(closest_power_of_2) + _get_alibi_head_slopes(
2 * closest_power_of_2)[0::2][:n_heads - closest_power_of_2]
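
# A quick numerical check of the slope formula above: for a power-of-2 head
# count the slopes form the geometric sequence used in the paper, and for
# other head counts the interleaving workaround still yields one slope per head.
def _example_head_slopes() -> None:
    assert _get_alibi_head_slopes(8) == [2**-i for i in range(1, 9)]  # 1/2, 1/4, ..., 1/256
    assert len(_get_alibi_head_slopes(12)) == 12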
def _lazy_import(name: Optional[str]) -> Union[Callable, ModuleType, None]:
if not name:
return None
components = name.split('.')
try:
mod = importlib.import_module(components[0])
except (ValueError, ModuleNotFoundError):
log.exception(f'Module {components[0]} not found when attempting '
f'to import {name}. Please confirm the name and '
f"module path you're attempting to import.")
raise
    if len(components) > 1:
        try:
            mod = attrgetter('.'.join(components[1:]))(mod)
        except (ValueError, AttributeError):
            log.exception(f'Unable to import {name}. Please confirm the name '
                          "and module path you're attempting to import.")
            raise
    return mod
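
# A minimal usage sketch of ``_lazy_import`` (``math.sqrt`` is just a handy
# importable dotted path for demonstration; any valid path resolves the same way):
def _example_lazy_import() -> None:
    sqrt = _lazy_import('math.sqrt')  # imports ``math``, then resolves ``sqrt``
    assert callable(sqrt) and sqrt(4.0) == 2.0
    assert _lazy_import(None) is None  # falsy names short-circuit to None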