Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing 4 changed files with 1,050 additions and 0 deletions.
@@ -0,0 +1,377 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass, field
from typing import Optional

import torch
from fairseq import distributed_utils, utils
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import (
    FairseqIncrementalDecoder,
    FairseqLanguageModel,
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
from omegaconf import II

from torchscale.architecture.config import RetNetConfig
from torchscale.architecture.retnet import RetNetDecoder

DEFAULT_MAX_TARGET_POSITIONS = 1024
logger = logging.getLogger(__name__)

@dataclass
class LanguageConfig(FairseqDataclass):
    activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
        default="relu", metadata={"help": "activation function to use"}
    )
    dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
    activation_dropout: float = field(
        default=0.0, metadata={"help": "dropout probability after activation in FFN."}
    )
    relu_dropout: float = field(
        default=0.0, metadata={"help": "dropout probability after activation in FFN."}
    )
    decoder_embed_dim: int = field(
        default=512, metadata={"help": "decoder embedding dimension"}
    )
    decoder_output_dim: int = field(
        default=512, metadata={"help": "decoder output dimension"}
    )
    decoder_input_dim: int = field(
        default=512, metadata={"help": "decoder input dimension"}
    )
    decoder_ffn_embed_dim: int = field(
        default=2048, metadata={"help": "decoder embedding dimension for FFN"}
    )
    decoder_layers: int = field(default=6, metadata={"help": "num decoder layers"})
    decoder_retention_heads: int = field(
        default=2, metadata={"help": "num decoder retention heads"}
    )
    decoder_normalize_before: bool = field(
        default=False, metadata={"help": "apply layernorm before each decoder block"}
    )
    share_decoder_input_output_embed: bool = field(
        default=False, metadata={"help": "share decoder input and output embeddings"}
    )
    decoder_learned_pos: bool = field(
        default=False,
        metadata={"help": "use learned positional embeddings in the decoder"},
    )
    layernorm_embedding: bool = field(
        default=False, metadata={"help": "add layernorm to embedding"}
    )
    no_scale_embedding: bool = field(
        default=False, metadata={"help": "if True, don't scale embeddings"}
    )
    checkpoint_activations: bool = field(
        default=False, metadata={"help": "checkpoint activations at each layer"}
    )
    offload_activations: bool = field(
        default=False,
        metadata={"help": "move checkpointed activations to CPU after they are used."},
    )
    # config for Fully Sharded Data Parallel (FSDP) training
    min_params_to_wrap: int = field(
        default=DEFAULT_MIN_PARAMS_TO_WRAP,
        metadata={
            "help": (
                "minimum number of params for a layer to be wrapped with FSDP() when "
                "training with --ddp-backend=fully_sharded. Smaller values will "
                "improve memory efficiency, but may make torch.distributed "
                "communication less efficient due to smaller input sizes. This option "
                "is set to 0 (i.e., always wrap) when --checkpoint-activations or "
                "--offload-activations are passed."
            )
        },
    )
    moe_freq: int = field(
        default=0,
        metadata={"help": "Frequency at which we insert MoE Transformer layers"},
    )
    moe_expert_count: int = field(
        default=0, metadata={"help": "Number of experts in each MoE Layer"}
    )
    moe_gating_use_fp32: bool = field(
        default=False,
        metadata={"help": "Use FP32 computations in MoE top2 gating function"},
    )
    moe_second_expert_policy: str = field(
        default="sampling",
        metadata={"help": "policy for second expert, options: all/sampling/random"},
    )
    moe_normalize_gate_prob_before_dropping: bool = field(
        default=False,
        metadata={
            "help": "whether to normalize gate probs before or after dropping experts for capacity and randomization"
        },
    )
    moe_expert_ffn_dim: Optional[int] = field(
        default=None, metadata={"help": "MoE expert FFN dimension"}
    )
    moe_top1_expert: Optional[bool] = field(
        default=False, metadata={"help": "Use top1 gate instead of top2"}
    )
    moe_eval_capacity_token_fraction: Optional[float] = field(
        default=0.25,
        metadata={
            "help": (
                "Default: 0.25, Fraction of tokens as capacity during validation, "
                "if set to negative, use same as training. range: (0.0, 1.0]."
            )
        },
    )
    moe_normalize_expert_grad: Optional[str] = field(
        default="world_size",
        metadata={
            "help": "Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'"
        },
    )
    record_a2a_perf_stats: Optional[bool] = field(
        default=False,
        metadata={"help": "records all-to-all perf stats during distributed training"},
    )
    dummy_a2a: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Bypasses all-to-all during distributed training by returning the input buffer as output"
        },
    )
    moe_batch_prioritized_routing: Optional[bool] = field(
        default=False,
        metadata={
            "help": "if true, orders tokens by the gate prob before capacity dropping."
        },
    )
    use_xmoe: Optional[bool] = field(
        default=False,
    )
    chunkwise_recurrent: Optional[bool] = field(
        default=False,
    )
    recurrent_chunk_size: Optional[int] = field(
        default=512,
    )

    # options from other parts of the config
    add_bos_token: bool = II("task.add_bos_token")
    tokens_per_sample: int = II("task.tokens_per_sample")
    max_target_positions: Optional[int] = II("task.max_target_positions")
    tpu: bool = II("common.tpu")
    memory_efficient_fp16: bool = II("common.memory_efficient_fp16")
    fp16: bool = II("common.fp16")
    fp16_no_flatten_grads: bool = II("common.fp16_no_flatten_grads")
    ddp_backend: str = II("distributed_training.ddp_backend")
    world_size: int = II("distributed_training.distributed_world_size")
    distributed_rank: int = II("distributed_training.distributed_rank")
    ddp_rank: int = II("distributed_training.distributed_rank")
    deepnorm: Optional[bool] = field(
        default=False,
    )
    subln: Optional[bool] = field(
        default=False,
    )

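# ---------------------------------------------------------------------------
# Editor's illustrative sketch (not part of the original commit). The II("...")
# fields above are OmegaConf interpolations: they resolve against other parts
# of the fairseq config tree at access time instead of holding literal
# defaults. A minimal standalone demonstration, assuming only omegaconf:
def _demo_ii_interpolation():
    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {
            "task": {"tokens_per_sample": 2048},
            "model": {"tokens_per_sample": II("task.tokens_per_sample")},
        }
    )
    # The model-level value is resolved from task.tokens_per_sample on access.
    assert cfg.model.tokens_per_sample == 2048
# ---------------------------------------------------------------------------
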
@register_model("retnet", dataclass=LanguageConfig)
class RetNetLanguageModel(FairseqLanguageModel):
    def __init__(self, args, decoder):
        self.args = args
        super().__init__(decoder)

    @classmethod
    def build_model(cls, args, task):
        if getattr(args, "max_target_positions", None) is None:
            args.max_target_positions = getattr(
                args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS
            )

        embed_tokens = cls.build_embedding(
            args, task.source_dictionary, args.decoder_embed_dim
        )
        if args.share_decoder_input_output_embed:
            # Tie the output projection to the token embedding: both modules
            # point at the same weight tensor, so no extra parameters are added.
            output_projection = torch.nn.Linear(
                embed_tokens.weight.shape[1],
                embed_tokens.weight.shape[0],
                bias=False,
            )
            output_projection.weight = embed_tokens.weight
        else:
            output_projection = torch.nn.Linear(
                args.decoder_embed_dim, len(task.dictionary), bias=False
            )
            torch.nn.init.normal_(
                output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5
            )

        if getattr(args, "moe_freq", 0) > 0 and (
            getattr(args, "fp16", False)
            and not getattr(args, "memory_efficient_fp16", False)
            and getattr(args, "ddp_backend", None) != "fully_sharded"
        ):
            assert (
                args.fp16_no_flatten_grads
            ), "If training MoE models, set --fp16-no-flatten-grads to calculate correct gradnorm"

        args.ddp_rank = distributed_utils.get_data_parallel_rank()

        config = RetNetConfig()
        config.override(args)

        decoder = LMDecoder(
            config,
            embed_tokens,
            output_projection,
            dictionary=task.dictionary,
        )

        return cls(args, decoder)

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim, path=None):
        return Embedding(len(dictionary), embed_dim, dictionary.pad())

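# Editor's illustrative sketch (not part of the original commit): with
# share_decoder_input_output_embed, build_model above assigns the embedding
# matrix to the output projection, so both modules share a single Parameter.
def _demo_weight_tying():
    embed = torch.nn.Embedding(100, 16)
    projection = torch.nn.Linear(16, 100, bias=False)
    projection.weight = embed.weight  # shared storage, not a copy
    assert projection.weight.data_ptr() == embed.weight.data_ptr()
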
class LMDecoder(RetNetDecoder, FairseqIncrementalDecoder):
    def forward(self, src_tokens, **kwargs):
        return super().forward(src_tokens, **kwargs)

    def max_positions(self):
        return self.args.max_target_positions

    def reorder_incremental_state_scripting(
        self,
        incremental_state,
        new_order,
    ):
        # Permute every cached tensor along the batch/beam dimension so the
        # incremental state follows the beam reordering chosen by the searcher.
        for module in incremental_state:
            for key in incremental_state[module]:
                result = incremental_state[module][key].index_select(0, new_order)
                incremental_state[module][key] = result

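# Editor's illustrative sketch (not part of the original commit): during beam
# search, fairseq asks the decoder to permute its cached state along the
# batch/beam dimension; reorder_incremental_state_scripting above does exactly
# this with index_select on every cached tensor, e.g.:
def _demo_state_reorder():
    cache = torch.arange(6.0).view(3, 2)  # 3 beams, 2 features each
    new_order = torch.tensor([2, 0, 1])  # permutation chosen by the searcher
    reordered = cache.index_select(0, new_order)
    assert reordered[0].tolist() == cache[2].tolist()
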
@register_model_architecture("retnet", "retnet_base")
def retnet_base_architecture(args):
    # backward compatibility for older model checkpoints
    if hasattr(args, "no_tie_adaptive_proj"):
        # previous models defined --no-tie-adaptive-proj, so use the existence of
        # that option to determine if this is an "old" model checkpoint
        args.no_decoder_final_norm = True  # old models always set this to True
        if args.no_tie_adaptive_proj is False:
            args.tie_adaptive_proj = True
    if hasattr(args, "decoder_final_norm"):
        args.no_decoder_final_norm = not args.decoder_final_norm

    args.dropout = getattr(args, "dropout", 0.0)

    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024)
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 2)
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
    args.activation_fn = getattr(args, "activation_fn", "gelu")

    args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
    args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)

    args.base_layers = getattr(args, "base_layers", 0)
    args.base_sublayers = getattr(args, "base_sublayers", 1)
    args.base_shuffle = getattr(args, "base_shuffle", False)

    args.add_bos_token = getattr(args, "add_bos_token", False)
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", False
    )
    args.character_embeddings = getattr(args, "character_embeddings", False)

    args.decoder_output_dim = getattr(
        args, "decoder_output_dim", args.decoder_embed_dim
    )
    args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)

    args.chunkwise_recurrent = getattr(args, "chunkwise_recurrent", False)
    args.recurrent_chunk_size = getattr(args, "recurrent_chunk_size", 512)

    # Model training is not stable without this
    args.decoder_normalize_before = True
    args.no_decoder_final_norm = getattr(args, "no_decoder_final_norm", False)

    args.adaptive_input = getattr(args, "adaptive_input", False)
    args.adaptive_input_factor = getattr(args, "adaptive_input_factor", 4)
    args.adaptive_input_cutoff = getattr(args, "adaptive_input_cutoff", None)

    args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
    args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False)

    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
    args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
    args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
    args.offload_activations = getattr(args, "offload_activations", False)
    if args.offload_activations:
        args.checkpoint_activations = True

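# Editor's illustrative sketch (not part of the original commit): the
# architecture functions in this file fill in values with
# getattr(args, name, default), so anything already set on args (for example
# from the command line) takes precedence over these architecture defaults:
def _demo_arch_defaulting():
    from argparse import Namespace

    args = Namespace(decoder_layers=12)  # pretend this came from the CLI
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    assert args.decoder_layers == 12  # the CLI value wins over the default
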
@register_model_architecture("retnet", "retnet_medium")
def retnet_medium(args):
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048)
    args.decoder_layers = getattr(args, "decoder_layers", 16)
    args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 4)
    retnet_base_architecture(args)


@register_model_architecture("retnet", "retnet_xl")
def retnet_xl(args):
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2048)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096)
    # RetNet has retention heads, not attention heads; set the field that
    # RetNetConfig actually reads (as retnet_medium above does).
    args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 8)
    args.decoder_layers = getattr(args, "decoder_layers", 24)
    retnet_base_architecture(args)


@register_model_architecture("retnet", "retnet_3b")
def retnet_3b(args):
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2560)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 5120)
    args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 10)
    args.decoder_layers = getattr(args, "decoder_layers", 32)
    retnet_base_architecture(args)


@register_model_architecture("retnet", "retnet_7b")
def retnet_7b(args):
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 4096)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 8192)
    args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 16)
    args.decoder_layers = getattr(args, "decoder_layers", 32)
    retnet_base_architecture(args)


@register_model_architecture("retnet", "retnet_13b")
def retnet_13b(args):
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 5120)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 10240)
    args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 20)
    args.decoder_layers = getattr(args, "decoder_layers", 40)
    retnet_base_architecture(args)


@register_model_architecture("retnet", "retnet_65b")
def retnet_65b(args):
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 8192)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 16384)
    args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 32)
    args.decoder_layers = getattr(args, "decoder_layers", 64)
    retnet_base_architecture(args)
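
# ---------------------------------------------------------------------------
# Editor's illustrative usage note (not part of the original commit). Because
# this module registers the "retnet" model and the retnet_* architectures with
# fairseq, a training run might look like the following once the module is
# importable (the data path and hyperparameters below are placeholders, not
# values from this commit):
#
#   fairseq-train /path/to/data-bin \
#       --task language_modeling \
#       --arch retnet_base \
#       --tokens-per-sample 1024 \
#       --max-tokens 8192 \
#       --optimizer adam
# ---------------------------------------------------------------------------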