
Commit bf65397

RetNet
1 parent 89d8c6d commit bf65397

File tree

4 files changed: +1050 -0 lines changed


examples/fairseq/models/retnet.py

Lines changed: 377 additions & 0 deletions
@@ -0,0 +1,377 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass, field
from typing import Optional

import torch
from fairseq import distributed_utils, utils
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import (
    FairseqIncrementalDecoder,
    FairseqLanguageModel,
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
from omegaconf import II

from torchscale.architecture.config import RetNetConfig
from torchscale.architecture.retnet import RetNetDecoder

DEFAULT_MAX_TARGET_POSITIONS = 1024
logger = logging.getLogger(__name__)


@dataclass
class LanguageConfig(FairseqDataclass):
    activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
        default="relu", metadata={"help": "activation function to use"}
    )
    dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
    activation_dropout: float = field(
        default=0.0, metadata={"help": "dropout probability after activation in FFN."}
    )
    relu_dropout: float = field(
        default=0.0, metadata={"help": "dropout probability after activation in FFN."}
    )
    decoder_embed_dim: int = field(
        default=512, metadata={"help": "decoder embedding dimension"}
    )
    decoder_output_dim: int = field(
        default=512, metadata={"help": "decoder output dimension"}
    )
    decoder_input_dim: int = field(
        default=512, metadata={"help": "decoder input dimension"}
    )
    decoder_ffn_embed_dim: int = field(
        default=2048, metadata={"help": "decoder embedding dimension for FFN"}
    )
    decoder_layers: int = field(default=6, metadata={"help": "num decoder layers"})
    decoder_retention_heads: int = field(
        default=2, metadata={"help": "num decoder retention heads"}
    )
    decoder_normalize_before: bool = field(
        default=False, metadata={"help": "apply layernorm before each decoder block"}
    )
    share_decoder_input_output_embed: bool = field(
        default=False, metadata={"help": "share decoder input and output embeddings"}
    )
    decoder_learned_pos: bool = field(
        default=False,
        metadata={"help": "use learned positional embeddings in the decoder"},
    )
    layernorm_embedding: bool = field(
        default=False, metadata={"help": "add layernorm to embedding"}
    )
    no_scale_embedding: bool = field(
        default=False, metadata={"help": "if True, dont scale embeddings"}
    )
    checkpoint_activations: bool = field(
        default=False, metadata={"help": "checkpoint activations at each layer"}
    )
    offload_activations: bool = field(
        default=False,
        metadata={"help": "move checkpointed activations to CPU after they are used."},
    )
    # config for Fully Sharded Data Parallel (FSDP) training
    min_params_to_wrap: int = field(
        default=DEFAULT_MIN_PARAMS_TO_WRAP,
        metadata={
            "help": (
                "minimum number of params for a layer to be wrapped with FSDP() when "
                "training with --ddp-backend=fully_sharded. Smaller values will "
                "improve memory efficiency, but may make torch.distributed "
                "communication less efficient due to smaller input sizes. This option "
                "is set to 0 (i.e., always wrap) when --checkpoint-activations or "
                "--offload-activations are passed."
            )
        },
    )
    moe_freq: int = field(
        default=0,
        metadata={"help": "Frequency at which we insert MoE Transformer layers"},
    )
    moe_expert_count: int = field(
        default=0, metadata={"help": "Number of experts in each MoE Layer"}
    )
    moe_gating_use_fp32: bool = field(
        default=False,
        metadata={"help": "Use FP32 computations in MoE top2 gating function"},
    )
    moe_second_expert_policy: str = field(
        default="sampling",
        metadata={"help": "policy for second expert, options: all/sampling/random"},
    )
    moe_normalize_gate_prob_before_dropping: bool = field(
        default=False,
        metadata={
            "help": "whether to normalize gate probs before or after dropping experts for capacity and randomization"
        },
    )
    moe_expert_ffn_dim: Optional[int] = field(
        default=None, metadata={"help": "MoE expert FFN dimension"}
    )
    moe_top1_expert: Optional[bool] = field(
        default=False, metadata={"help": "Use top1 gate instead of top2"}
    )
    moe_eval_capacity_token_fraction: Optional[float] = field(
        default=0.25,
        metadata={
            "help": (
                "Default: 0.25, Fraction of tokens as capacity during validation, "
                "if set to negative, use same as training. range: (0.0, 1.0]."
            )
        },
    )
    moe_normalize_expert_grad: Optional[str] = field(
        default="world_size",
        metadata={
            "help": "Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'"
        },
    )
    record_a2a_perf_stats: Optional[bool] = field(
        default=False,
        metadata={"help": "records all to all perf stats during distributed training"},
    )
    dummy_a2a: Optional[bool] = field(
        default=False,
        metadata={
            "help": "By passes all to all during distributed training by returning the input buffer as output"
        },
    )
    moe_batch_prioritized_routing: Optional[bool] = field(
        default=False,
        metadata={
            "help": "if true orders token by the gate prob before capacity dropping."
        },
    )
    use_xmoe: Optional[bool] = field(
        default=False,
    )
    chunkwise_recurrent: Optional[bool] = field(
        default=False,
    )
    recurrent_chunk_size: Optional[int] = field(
        default=512,
    )

    # options from other parts of the config
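    # The fields below are OmegaConf II() interpolations: their values are
    # resolved at runtime from the task/common/distributed_training config
    # groups instead of introducing new command-line options of their own.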
    add_bos_token: bool = II("task.add_bos_token")
    tokens_per_sample: int = II("task.tokens_per_sample")
    max_target_positions: Optional[int] = II("task.max_target_positions")
    tpu: bool = II("common.tpu")
    memory_efficient_fp16: bool = II("common.memory_efficient_fp16")
    fp16: bool = II("common.fp16")
    fp16_no_flatten_grads: bool = II("common.fp16_no_flatten_grads")
    ddp_backend: str = II("distributed_training.ddp_backend")
    world_size: int = II("distributed_training.distributed_world_size")
    distributed_rank: int = II("distributed_training.distributed_rank")
    ddp_rank: int = II("distributed_training.distributed_rank")
    deepnorm: Optional[bool] = field(
        default=False,
    )
    subln: Optional[bool] = field(
        default=False,
    )

@register_model("retnet", dataclass=LanguageConfig)
class RetNetLanguageModel(FairseqLanguageModel):
    def __init__(self, args, decoder):
        self.args = args
        super().__init__(decoder)

    @classmethod
    def build_model(cls, args, task):

        if getattr(args, "max_target_positions", None) is None:
            args.max_target_positions = getattr(
                args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS
            )

        embed_tokens = cls.build_embedding(
            args, task.source_dictionary, args.decoder_embed_dim
        )
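        # Either tie the output projection to the token embedding matrix, or give
        # it its own weight drawn from N(0, decoder_embed_dim ** -0.5).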
        if args.share_decoder_input_output_embed:
            output_projection = torch.nn.Linear(
                embed_tokens.weight.shape[1],
                embed_tokens.weight.shape[0],
                bias=False,
            )
            output_projection.weight = embed_tokens.weight
        else:
            output_projection = torch.nn.Linear(
                args.decoder_embed_dim, len(task.dictionary), bias=False
            )
            torch.nn.init.normal_(
                output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5
            )

        if getattr(args, "moe_freq", 0) > 0 and (
            getattr(args, "fp16", False)
            and not getattr(args, "memory_efficient_fp16", False)
            and getattr(args, "ddp_backend", None) != "fully_sharded"
        ):
            assert (
                args.fp16_no_flatten_grads
            ), "If training moe models, set --fp16-no-flatten-grads to calculate correct gradnorm"

        args.ddp_rank = distributed_utils.get_data_parallel_rank()

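        # Build the torchscale config, then let override() copy any matching
        # hyperparameters from the fairseq args namespace onto it.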
        config = RetNetConfig()
        config.override(args)

        decoder = LMDecoder(
            config,
            embed_tokens,
            output_projection,
            dictionary=task.dictionary,
        )

        return cls(args, decoder)

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim, path=None):
        return Embedding(len(dictionary), embed_dim, dictionary.pad())

class LMDecoder(RetNetDecoder, FairseqIncrementalDecoder):
    def forward(self, src_tokens, **kwargs):
        return super().forward(src_tokens, **kwargs)

    def max_positions(self):
        return self.args.max_target_positions

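    # During incremental decoding, beam search periodically reorders the batch of
    # hypotheses; the cached recurrent state is reindexed along dim 0 so that it
    # stays aligned with the reordered batch.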
    def reorder_incremental_state_scripting(
        self,
        incremental_state,
        new_order,
    ):
        for module in incremental_state:
            for key in incremental_state[module]:
                result = incremental_state[module][key].index_select(0, new_order)
                incremental_state[module][key] = result

@register_model_architecture("retnet", "retnet_base")
def retnet_base_architecture(args):
    # backward compatibility for older model checkpoints
    if hasattr(args, "no_tie_adaptive_proj"):
        # previous models defined --no-tie-adaptive-proj, so use the existence of
        # that option to determine if this is an "old" model checkpoint
        args.no_decoder_final_norm = True  # old models always set this to True
        if args.no_tie_adaptive_proj is False:
            args.tie_adaptive_proj = True
    if hasattr(args, "decoder_final_norm"):
        args.no_decoder_final_norm = not args.decoder_final_norm

    args.dropout = getattr(args, "dropout", 0.0)

    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024)
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 2)
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
    args.activation_fn = getattr(args, "activation_fn", "gelu")

    args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
    args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)

    args.base_layers = getattr(args, "base_layers", 0)
    args.base_sublayers = getattr(args, "base_sublayers", 1)
    args.base_shuffle = getattr(args, "base_shuffle", False)

    args.add_bos_token = getattr(args, "add_bos_token", False)
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", False
    )
    args.character_embeddings = getattr(args, "character_embeddings", False)

    args.decoder_output_dim = getattr(
        args, "decoder_output_dim", args.decoder_embed_dim
    )
    args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)

    args.chunkwise_recurrent = getattr(args, "chunkwise_recurrent", False)
    args.recurrent_chunk_size = getattr(args, "recurrent_chunk_size", 512)

    # Model training is not stable without this
    args.decoder_normalize_before = True
    args.no_decoder_final_norm = getattr(args, "no_decoder_final_norm", False)

    args.adaptive_input = getattr(args, "adaptive_input", False)
    args.adaptive_input_factor = getattr(args, "adaptive_input_factor", 4)
    args.adaptive_input_cutoff = getattr(args, "adaptive_input_cutoff", None)

    args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
    args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False)

    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
    args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
    args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
    args.offload_activations = getattr(args, "offload_activations", False)
    if args.offload_activations:
        args.checkpoint_activations = True

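# The presets below only set model size (width, depth, heads) and then delegate
# to retnet_base_architecture() for every remaining default; because values are
# assigned with getattr(args, ..., default), anything set explicitly on the
# command line takes precedence over the preset.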
@register_model_architecture("retnet", "retnet_medium")
def retnet_medium(args):
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048)
    args.decoder_layers = getattr(args, "decoder_layers", 16)
    args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 4)
    retnet_base_architecture(args)

@register_model_architecture("retnet", "retnet_xl")
def retnet_xl(args):
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2048)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    args.decoder_layers = getattr(args, "decoder_layers", 24)
    retnet_base_architecture(args)

@register_model_architecture("retnet", "retnet_3b")
def retnet_3b(args):
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2560)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 5120)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 10)
    args.decoder_layers = getattr(args, "decoder_layers", 32)
    retnet_base_architecture(args)

@register_model_architecture("retnet", "retnet_7b")
def retnet_7b(args):
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 4096)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 8192)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
    args.decoder_layers = getattr(args, "decoder_layers", 32)
    retnet_base_architecture(args)

@register_model_architecture("retnet", "retnet_13b")
def retnet_13b(args):
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 5120)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 10240)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 20)
    args.decoder_layers = getattr(args, "decoder_layers", 40)
    retnet_base_architecture(args)

@register_model_architecture("retnet", "retnet_65b")
def retnet_65b(args):
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 8192)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 16384)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32)
    args.decoder_layers = getattr(args, "decoder_layers", 64)
    retnet_base_architecture(args)
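Usage note (not part of the commit): once this module is importable, the registered architectures are selected in the usual fairseq way, e.g. fairseq-train with --task language_modeling and --arch retnet_base (or retnet_medium, retnet_xl, ...). The snippet below is a minimal sketch of the same construction path that build_model follows, outside of fairseq's task/dictionary machinery; the vocabulary size, padding index, import path for the preset function, and the (output, extra) return shape of the torchscale decoder are illustrative assumptions, not something this commit defines.

import torch
from argparse import Namespace

from torchscale.architecture.config import RetNetConfig
from torchscale.architecture.retnet import RetNetDecoder

# Preset defined in the file above; the import path is assumed and depends on
# how the examples package is laid out in your checkout.
from examples.fairseq.models.retnet import retnet_base_architecture

# Hypothetical hyperparameters; retnet_base_architecture() fills in the rest.
args = Namespace(
    decoder_embed_dim=512,
    decoder_ffn_embed_dim=1024,
    decoder_layers=6,
    decoder_retention_heads=2,
)
retnet_base_architecture(args)

vocab_size, pad_idx = 32000, 1  # assumed vocabulary, not from this commit
embed_tokens = torch.nn.Embedding(vocab_size, args.decoder_embed_dim, pad_idx)
output_projection = torch.nn.Linear(args.decoder_embed_dim, vocab_size, bias=False)

config = RetNetConfig()
config.override(args)  # copy matching hyperparameters from args onto the config

decoder = RetNetDecoder(config, embed_tokens, output_projection)

tokens = torch.randint(0, vocab_size, (2, 128))  # (batch, sequence) of token ids
out, extra = decoder(tokens)  # assuming the decoder returns (output, extra_state)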
