# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass, field
from typing import Optional

import torch
from fairseq import distributed_utils, utils
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import (
    FairseqIncrementalDecoder,
    FairseqLanguageModel,
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
from omegaconf import II

from torchscale.architecture.config import RetNetConfig
from torchscale.architecture.retnet import RetNetDecoder

DEFAULT_MAX_TARGET_POSITIONS = 1024
logger = logging.getLogger(__name__)


@dataclass
class LanguageConfig(FairseqDataclass):
    activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
        default="relu", metadata={"help": "activation function to use"}
    )
    dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
    activation_dropout: float = field(
        default=0.0, metadata={"help": "dropout probability after activation in FFN."}
    )
    relu_dropout: float = field(
        default=0.0, metadata={"help": "dropout probability after activation in FFN."}
    )
    decoder_embed_dim: int = field(
        default=512, metadata={"help": "decoder embedding dimension"}
    )
    decoder_output_dim: int = field(
        default=512, metadata={"help": "decoder output dimension"}
    )
    decoder_input_dim: int = field(
        default=512, metadata={"help": "decoder input dimension"}
    )
    decoder_ffn_embed_dim: int = field(
        default=2048, metadata={"help": "decoder embedding dimension for FFN"}
    )
    decoder_layers: int = field(default=6, metadata={"help": "num decoder layers"})
    decoder_retention_heads: int = field(
        default=2, metadata={"help": "num decoder retention heads"}
    )
    decoder_normalize_before: bool = field(
        default=False, metadata={"help": "apply layernorm before each decoder block"}
    )
    share_decoder_input_output_embed: bool = field(
        default=False, metadata={"help": "share decoder input and output embeddings"}
    )
    decoder_learned_pos: bool = field(
        default=False,
        metadata={"help": "use learned positional embeddings in the decoder"},
    )
    layernorm_embedding: bool = field(
        default=False, metadata={"help": "add layernorm to embedding"}
    )
    no_scale_embedding: bool = field(
        default=False, metadata={"help": "if True, don't scale embeddings"}
    )
    checkpoint_activations: bool = field(
        default=False, metadata={"help": "checkpoint activations at each layer"}
    )
    offload_activations: bool = field(
        default=False,
        metadata={"help": "move checkpointed activations to CPU after they are used."},
    )
    # config for Fully Sharded Data Parallel (FSDP) training
    min_params_to_wrap: int = field(
        default=DEFAULT_MIN_PARAMS_TO_WRAP,
        metadata={
            "help": (
                "minimum number of params for a layer to be wrapped with FSDP() when "
                "training with --ddp-backend=fully_sharded. Smaller values will "
                "improve memory efficiency, but may make torch.distributed "
                "communication less efficient due to smaller input sizes. This option "
                "is set to 0 (i.e., always wrap) when --checkpoint-activations or "
                "--offload-activations are passed."
            )
        },
    )
    moe_freq: int = field(
        default=0,
        metadata={"help": "frequency at which to insert MoE Transformer layers"},
    )
    moe_expert_count: int = field(
        default=0, metadata={"help": "number of experts in each MoE layer"}
    )
    moe_gating_use_fp32: bool = field(
        default=False,
        metadata={"help": "use FP32 computations in MoE top2 gating function"},
    )
    moe_second_expert_policy: str = field(
        default="sampling",
        metadata={"help": "policy for second expert, options: all/sampling/random"},
    )
    moe_normalize_gate_prob_before_dropping: bool = field(
        default=False,
        metadata={
            "help": "whether to normalize gate probs before or after dropping experts for capacity and randomization"
        },
    )
    moe_expert_ffn_dim: Optional[int] = field(
        default=None, metadata={"help": "MoE expert FFN dimension"}
    )
    moe_top1_expert: Optional[bool] = field(
        default=False, metadata={"help": "use top1 gate instead of top2"}
    )
    moe_eval_capacity_token_fraction: Optional[float] = field(
        default=0.25,
        metadata={
            "help": (
                "fraction of tokens used as capacity during validation; "
                "if negative, use the same value as training. Range: (0.0, 1.0]. Default: 0.25."
            )
        },
    )
    moe_normalize_expert_grad: Optional[str] = field(
        default="world_size",
        metadata={
            "help": "divide expert gradients by (1) 'world_size' or (2) 'sqrt_world_size'"
        },
    )
    record_a2a_perf_stats: Optional[bool] = field(
        default=False,
        metadata={"help": "record all-to-all perf stats during distributed training"},
    )
    dummy_a2a: Optional[bool] = field(
        default=False,
        metadata={
            "help": "bypass all-to-all during distributed training by returning the input buffer as output"
        },
    )
    moe_batch_prioritized_routing: Optional[bool] = field(
        default=False,
        metadata={
            "help": "if true, order tokens by gate probability before capacity dropping"
        },
    )
    use_xmoe: Optional[bool] = field(
        default=False,
    )
    chunkwise_recurrent: Optional[bool] = field(
        default=False,
        metadata={
            "help": "use the chunkwise recurrent representation: parallel within chunks, recurrent across chunks"
        },
    )
    recurrent_chunk_size: Optional[int] = field(
        default=512,
        metadata={"help": "chunk size for the chunkwise recurrent representation"},
    )

    # options mirrored from other parts of the fairseq config via omegaconf interpolation (II)
    add_bos_token: bool = II("task.add_bos_token")
    tokens_per_sample: int = II("task.tokens_per_sample")
    max_target_positions: Optional[int] = II("task.max_target_positions")
    tpu: bool = II("common.tpu")
    memory_efficient_fp16: bool = II("common.memory_efficient_fp16")
    fp16: bool = II("common.fp16")
    fp16_no_flatten_grads: bool = II("common.fp16_no_flatten_grads")
    ddp_backend: str = II("distributed_training.ddp_backend")
    world_size: int = II("distributed_training.distributed_world_size")
    distributed_rank: int = II("distributed_training.distributed_rank")
    ddp_rank: int = II("distributed_training.distributed_rank")
    deepnorm: Optional[bool] = field(
        default=False,
    )
    subln: Optional[bool] = field(
        default=False,
    )


@register_model("retnet", dataclass=LanguageConfig)
class RetNetLanguageModel(FairseqLanguageModel):
    def __init__(self, args, decoder):
        self.args = args
        super().__init__(decoder)

    @classmethod
    def build_model(cls, args, task):
        if getattr(args, "max_target_positions", None) is None:
            args.max_target_positions = getattr(
                args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS
            )

        embed_tokens = cls.build_embedding(
            args, task.source_dictionary, args.decoder_embed_dim
        )
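        # Explanatory note (added): with --share-decoder-input-output-embed the
        # output projection reuses the embedding matrix (weight tying); the
        # untied branch below gets a fresh projection initialized with
        # std = embed_dim ** -0.5, as in fairseq's Transformer LM.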
        if args.share_decoder_input_output_embed:
            output_projection = torch.nn.Linear(
                embed_tokens.weight.shape[1],
                embed_tokens.weight.shape[0],
                bias=False,
            )
            output_projection.weight = embed_tokens.weight
        else:
            output_projection = torch.nn.Linear(
                args.decoder_embed_dim, len(task.dictionary), bias=False
            )
            torch.nn.init.normal_(
                output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5
            )

        if getattr(args, "moe_freq", 0) > 0 and (
            getattr(args, "fp16", False)
            and not getattr(args, "memory_efficient_fp16", False)
            and getattr(args, "ddp_backend", None) != "fully_sharded"
        ):
            assert (
                args.fp16_no_flatten_grads
            ), "If training moe models, set --fp16-no-flatten-grads to calculate correct gradnorm"

        args.ddp_rank = distributed_utils.get_data_parallel_rank()

        config = RetNetConfig()
        config.override(args)
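        # Note (my reading of torchscale, worth verifying): RetNetConfig.override
        # copies from `args` every attribute whose name matches one of the
        # config's own fields, so the dataclass options above (decoder_*,
        # chunkwise_recurrent, recurrent_chunk_size, ...) land on the
        # torchscale config here.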

        decoder = LMDecoder(
            config,
            embed_tokens,
            output_projection,
            dictionary=task.dictionary,
        )

        return cls(args, decoder)

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim, path=None):
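        # `path` is unused here and kept only for signature parity with other
        # fairseq models. fairseq's Embedding helper normal-initializes the
        # weights (std = embed_dim ** -0.5) and zeroes the padding row.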
        return Embedding(len(dictionary), embed_dim, dictionary.pad())


class LMDecoder(RetNetDecoder, FairseqIncrementalDecoder):
    def forward(self, src_tokens, **kwargs):
        return super().forward(src_tokens, **kwargs)

    def max_positions(self):
        return self.args.max_target_positions

    def reorder_incremental_state_scripting(
        self,
        incremental_state,
        new_order,
    ):
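        # During beam search fairseq reorders and expands hypotheses along the
        # batch dimension; the cached retention states are stored as tensors
        # with batch at dim 0, so one index_select per tensor realigns them.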
        for module in incremental_state:
            for key in incremental_state[module]:
                result = incremental_state[module][key].index_select(0, new_order)
                incremental_state[module][key] = result


@register_model_architecture("retnet", "retnet_base")
| 264 | +@register_model_architecture("retnet", "retnet_base") |
| 265 | +def retnet_base_architecture(args): |
| 266 | + # backward compatibility for older model checkpoints |
| 267 | + if hasattr(args, "no_tie_adaptive_proj"): |
| 268 | + # previous models defined --no-tie-adaptive-proj, so use the existence of |
| 269 | + # that option to determine if this is an "old" model checkpoint |
| 270 | + args.no_decoder_final_norm = True # old models always set this to True |
| 271 | + if args.no_tie_adaptive_proj is False: |
| 272 | + args.tie_adaptive_proj = True |
| 273 | + if hasattr(args, "decoder_final_norm"): |
| 274 | + args.no_decoder_final_norm = not args.decoder_final_norm |
| 275 | + |
| 276 | + args.dropout = getattr(args, "dropout", 0.0) |
| 277 | + |
| 278 | + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) |
| 279 | + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024) |
| 280 | + args.decoder_layers = getattr(args, "decoder_layers", 6) |
| 281 | + args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 2) |
| 282 | + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) |
| 283 | + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) |
| 284 | + args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4) |
| 285 | + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) |
| 286 | + args.activation_fn = getattr(args, "activation_fn", "gelu") |
| 287 | + |
| 288 | + args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0) |
| 289 | + args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None) |
| 290 | + |
| 291 | + args.base_layers = getattr(args, "base_layers", 0) |
| 292 | + args.base_sublayers = getattr(args, "base_sublayers", 1) |
| 293 | + args.base_shuffle = getattr(args, "base_shuffle", False) |
| 294 | + |
| 295 | + args.add_bos_token = getattr(args, "add_bos_token", False) |
| 296 | + args.no_token_positional_embeddings = getattr( |
| 297 | + args, "no_token_positional_embeddings", False |
| 298 | + ) |
| 299 | + args.share_decoder_input_output_embed = getattr( |
| 300 | + args, "share_decoder_input_output_embed", False |
| 301 | + ) |
| 302 | + args.character_embeddings = getattr(args, "character_embeddings", False) |
| 303 | + |
| 304 | + args.decoder_output_dim = getattr( |
| 305 | + args, "decoder_output_dim", args.decoder_embed_dim |
| 306 | + ) |
| 307 | + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) |
| 308 | + |
| 309 | + args.chunkwise_recurrent = getattr(args, "chunkwise_recurrent", False) |
| 310 | + args.recurrent_chunk_size = getattr(args, "recurrent_chunk_size", 512) |
| 311 | + |
| 312 | + # Model training is not stable without this |
| 313 | + args.decoder_normalize_before = True |
| 314 | + args.no_decoder_final_norm = getattr(args, "no_decoder_final_norm", False) |
| 315 | + |
| 316 | + args.adaptive_input = getattr(args, "adaptive_input", False) |
| 317 | + args.adaptive_input_factor = getattr(args, "adaptive_input_factor", 4) |
| 318 | + args.adaptive_input_cutoff = getattr(args, "adaptive_input_cutoff", None) |
| 319 | + |
| 320 | + args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False) |
| 321 | + args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False) |
| 322 | + |
| 323 | + args.no_scale_embedding = getattr(args, "no_scale_embedding", False) |
| 324 | + args.layernorm_embedding = getattr(args, "layernorm_embedding", False) |
| 325 | + args.checkpoint_activations = getattr(args, "checkpoint_activations", False) |
| 326 | + args.offload_activations = getattr(args, "offload_activations", False) |
| 327 | + if args.offload_activations: |
| 328 | + args.checkpoint_activations = True |
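

# Illustrative sketch (not verified against this commit): a typical fairseq
# language-modeling run with this architecture might look like
#   fairseq-train data-bin/my-corpus --task language_modeling \
#       --arch retnet_base --tokens-per-sample 1024 \
#       --share-decoder-input-output-embed
# where data-bin/my-corpus is a hypothetical preprocessed dataset.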


@register_model_architecture("retnet", "retnet_medium")
def retnet_medium(args):
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048)
    args.decoder_layers = getattr(args, "decoder_layers", 16)
    args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 4)
    retnet_base_architecture(args)


@register_model_architecture("retnet", "retnet_xl")
def retnet_xl(args):
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2048)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096)
    # retention heads, not attention heads: RetNet has no attention module and
    # RetNetConfig only reads decoder_retention_heads (same fix below)
    args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 8)
    args.decoder_layers = getattr(args, "decoder_layers", 24)
    retnet_base_architecture(args)


@register_model_architecture("retnet", "retnet_3b")
def retnet_3b(args):
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2560)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 5120)
    args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 10)
    args.decoder_layers = getattr(args, "decoder_layers", 32)
    retnet_base_architecture(args)


@register_model_architecture("retnet", "retnet_7b")
def retnet_7b(args):
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 4096)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 8192)
    args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 16)
    args.decoder_layers = getattr(args, "decoder_layers", 32)
    retnet_base_architecture(args)


@register_model_architecture("retnet", "retnet_13b")
def retnet_13b(args):
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 5120)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 10240)
    args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 20)
    args.decoder_layers = getattr(args, "decoder_layers", 40)
    retnet_base_architecture(args)


@register_model_architecture("retnet", "retnet_65b")
def retnet_65b(args):
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 8192)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 16384)
    args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 32)
    args.decoder_layers = getattr(args, "decoder_layers", 64)
    retnet_base_architecture(args)
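

if __name__ == "__main__":
    # Minimal smoke test (a sketch; assumes torchscale is installed). It only
    # exercises the RetNetConfig.override path used in build_model above; it
    # does not construct a fairseq task or a full model.
    from argparse import Namespace

    _args = Namespace(
        decoder_embed_dim=256,
        decoder_retention_heads=2,
        decoder_layers=2,
        chunkwise_recurrent=True,
        recurrent_chunk_size=64,
    )
    _config = RetNetConfig()
    _config.override(_args)
    print("embed_dim:", _config.decoder_embed_dim, "layers:", _config.decoder_layers)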