diff --git a/llm/peft/lora/lora_seq2seq.ipynb b/llm/peft/lora/lora_seq2seq.ipynb index 2c9f9602a..d86d7cb74 100644 --- a/llm/peft/lora/lora_seq2seq.ipynb +++ b/llm/peft/lora/lora_seq2seq.ipynb @@ -2,37 +2,11 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "5f93b7d1", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[WARNING] ME(4027200:281473120071712,MainProcess):2024-10-21-16:48:33.250.525 [mindspore/run_check/_check_version.py:357] MindSpore version 2.3.1 and Ascend AI software package (Ascend Data Center Solution)version 7.5 does not match, the version of software package expect one of ['7.2', '7.3']. Please refer to the match info on: https://www.mindspore.cn/install\n", - "/home/lvyufeng/miniconda3/envs/mindspore/lib/python3.9/site-packages/numpy/core/getlimits.py:549: UserWarning: The value of the smallest subnormal for type is zero.\n", - " setattr(self, word, getattr(machar, word).flat[0])\n", - "/home/lvyufeng/miniconda3/envs/mindspore/lib/python3.9/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for type is zero.\n", - " return self._float_to_str(self.smallest_subnormal)\n", - "/home/lvyufeng/miniconda3/envs/mindspore/lib/python3.9/site-packages/numpy/core/getlimits.py:549: UserWarning: The value of the smallest subnormal for type is zero.\n", - " setattr(self, word, getattr(machar, word).flat[0])\n", - "/home/lvyufeng/miniconda3/envs/mindspore/lib/python3.9/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for type is zero.\n", - " return self._float_to_str(self.smallest_subnormal)\n", - "[WARNING] ME(4027200:281473120071712,MainProcess):2024-10-21-16:48:35.646.327 [mindspore/run_check/_check_version.py:375] MindSpore version 2.3.1 and \"te\" wheel package version 7.5 does not match. For details, refer to the installation guidelines: https://www.mindspore.cn/install\n", - "[WARNING] ME(4027200:281473120071712,MainProcess):2024-10-21-16:48:35.649.853 [mindspore/run_check/_check_version.py:382] MindSpore version 2.3.1 and \"hccl\" wheel package version 7.5 does not match. For details, refer to the installation guidelines: https://www.mindspore.cn/install\n", - "[WARNING] ME(4027200:281473120071712,MainProcess):2024-10-21-16:48:35.652.151 [mindspore/run_check/_check_version.py:396] Please pay attention to the above warning, countdown: 3\n", - "[WARNING] ME(4027200:281473120071712,MainProcess):2024-10-21-16:48:36.654.522 [mindspore/run_check/_check_version.py:396] Please pay attention to the above warning, countdown: 2\n", - "[WARNING] ME(4027200:281473120071712,MainProcess):2024-10-21-16:48:37.657.277 [mindspore/run_check/_check_version.py:396] Please pay attention to the above warning, countdown: 1\n", - "Building prefix dict from the default dictionary ...\n", - "Loading model from cache /tmp/jieba.cache\n", - "Loading model cost 0.910 seconds.\n", - "Prefix dict has been built successfully.\n" - ] - } - ], + "outputs": [], "source": [ - "import os\n", "import mindspore\n", "from mindnlp.transformers import AutoModelForSeq2SeqLM\n", "from mindnlp.peft import get_peft_config, get_peft_model, get_peft_model_state_dict, LoraConfig, TaskType\n", @@ -40,7 +14,7 @@ "from mindnlp.core import ops\n", "\n", "from mindnlp.transformers import AutoTokenizer\n", - "from mindnlp.common.optimization import get_linear_schedule_with_warmup\n", + "from mindnlp.transformers.optimization import get_linear_schedule_with_warmup\n", "from tqdm import tqdm\n", "\n", "model_name_or_path = \"bigscience/mt0-large\"\n", @@ -55,30 +29,12 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "8d0850ac", "metadata": { "scrolled": true }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "MT5ForConditionalGeneration has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`.`PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.\n", - " - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).\n", - " - If you are not the owner of the model architecture class, please contact the model code owner to update it.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[MS_ALLOC_CONF]Runtime config: enable_vmm:True vmm_align_size:2MB\n", - "trainable params: 2,359,296 || all params: 1,231,940,608 || trainable%: 0.19151053100118282\n" - ] - } - ], + "outputs": [], "source": [ "# creating model\n", "peft_config = LoraConfig(task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)\n", @@ -90,32 +46,10 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "4ee2babf", "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "ccab873e2c5a4eb884364c23445f920f", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Downloading builder script: 0%| | 0.00/6.04k [00:00 use the model max length (it's actually the default) + outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None) + return outputs + + +tokenized_datasets = datasets.map( + tokenize_function, + batched=True, + remove_columns=["idx", "sentence1", "sentence2"], +) + +# We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the +# transformers library +tokenized_datasets = tokenized_datasets.rename_column("label", "labels") + + +def collate_fn(examples): + return tokenizer.pad(examples, padding="longest", return_tensors="pt") + + +# Instantiate dataloaders. +train_dataloader = DataLoader(tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size) +eval_dataloader = DataLoader( + tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size +) + +model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, return_dict=True) +model = get_peft_model(model, peft_config) +model.print_trainable_parameters() +model \ No newline at end of file diff --git a/mindnlp/__init__.py b/mindnlp/__init__.py index f44c5978b..27e734e48 100644 --- a/mindnlp/__init__.py +++ b/mindnlp/__init__.py @@ -21,21 +21,19 @@ import platform from packaging import version +# huggingface env if os.environ.get('HF_ENDPOINT', None) is None: os.environ["HF_ENDPOINT"] = 'https://hf-mirror.com' -os.environ["MS_DEV_FORCE_ACL"] = '1' -os.environ['TOKENIZERS_PARALLELISM'] = 'false' + +# for huawei cloud modelarts if 'RANK_TABLE_FILE' in os.environ: del os.environ['RANK_TABLE_FILE'] -DEVICE_TARGET = os.environ.get('DEVICE_TARGET', None) import mindspore from mindspore import context from mindspore._c_expression import MSContext # pylint: disable=no-name-in-module, import-error -if DEVICE_TARGET is not None and DEVICE_TARGET in ('CPU', 'GPU', 'Ascend'): - context.set_context(device_target=DEVICE_TARGET) - +# for different ascend devices if platform.system().lower() == 'linux': SOC = MSContext.get_instance().get_ascend_soc_version() if ('910b' not in SOC and '310' not in SOC) or version.parse(mindspore.__version__) < version.parse('2.4.0'): @@ -44,17 +42,15 @@ if SOC in ('ascend910', 'ascend310b'): context.set_context(ascend_config={"precision_mode": "allow_mix_precision"}) -if version.parse(mindspore.__version__) < version.parse('2.3.0'): - mindspore.mint = None - -from . import integrations - -import transformers -import evaluate -import mindtorch +# set mindnlp.core to torch +from .utils.torch_proxy import initialize_torch_proxy, setup_metadata_patch +from .utils.safetensors_patch import setup_safetensors_patch +from .core._tensor import enable_mindspore_patch -sys.modules["mindnlp.transformers"] = transformers -sys.modules["mindnlp.evaluate"] = evaluate -sys.modules["mindnlp.core"] = mindtorch +enable_mindspore_patch() +initialize_torch_proxy() +setup_metadata_patch() +setup_safetensors_patch() -__all__ = ['transformers', 'evaluate', 'core'] \ No newline at end of file +from . import core +from . import transformers diff --git a/mindnlp/core/_C/__init__.py b/mindnlp/core/_C/__init__.py new file mode 100644 index 000000000..079ddaece --- /dev/null +++ b/mindnlp/core/_C/__init__.py @@ -0,0 +1,33 @@ +from mindspore import default_generator, Generator as msGenerator + +from . import _nn + +def _jit_set_profiling_executor(mode): + pass + +def _jit_set_profiling_executor(mode): + pass + +def _jit_set_profiling_mode(mode): + pass + +def _jit_override_can_fuse_on_cpu(mode): + pass + +def _jit_override_can_fuse_on_gpu(mode): + pass + +def _jit_set_texpr_fuser_enabled(mode): + pass + +def _debug_set_autodiff_subgraph_inlining(mode): + pass + +Graph = None +Value = None + +DisableTorchFunctionSubclass = None + +class Generator(msGenerator): + def __init__(self, device='cpu'): + super().__init__() diff --git a/mindnlp/core/_C/_nn.py b/mindnlp/core/_C/_nn.py new file mode 100644 index 000000000..a868c046e --- /dev/null +++ b/mindnlp/core/_C/_nn.py @@ -0,0 +1,40 @@ +from mindnlp import core +from ..types import device as device_ + +def _parse_to(*args, **kwargs): + """ + Mimic core._C._nn._parse_to functionality in Python. + + Args: + tensor (core.Tensor): The tensor to parse. + *args: Positional arguments for `to`. + **kwargs: Keyword arguments for `to`. + + Returns: + core.Tensor: The tensor with the desired properties. + """ + if len(args) == 1: + # Handle `device` or `dtype` + if isinstance(args[0], core.dtype): # dtype only + dtype = args[0] + device = None + elif isinstance(args[0], core.device): # device only + device = args[0] + dtype = None + elif isinstance(args[0], str): + device = device_(args[0]) + dtype = None + else: + raise TypeError(f"Expected core.dtype or core.device, but got {type(args[0])}") + elif len(args) == 2: + # Handle `device` and `dtype` + dtype = args[1] + device = args[0] + else: + dtype = kwargs.get("dtype", None) + device = kwargs.get("device", None) + + non_blocking = kwargs.get("non_blocking", False) + memory_format = kwargs.get("memory_format", None) + + return device, dtype, non_blocking, memory_format diff --git a/mindnlp/core/_C/size.py b/mindnlp/core/_C/size.py new file mode 100644 index 000000000..8d04e51d9 --- /dev/null +++ b/mindnlp/core/_C/size.py @@ -0,0 +1,22 @@ +import operator +from functools import reduce + + +def _get_tuple_numel(input): + if input == (): + return 1 + return reduce(operator.mul, list(input)) + + +class Size(tuple): + def __new__(cls, shape=()): + _shape = shape + if not isinstance(_shape, (tuple, list)): + raise TypeError("{} object is not supportted.".format(type(shape))) + return tuple.__new__(Size, _shape) + + def numel(self): + return _get_tuple_numel(self) + + def __repr__(self): + return "core.Size(" + str(list(self)) + ")" diff --git a/mindnlp/core/__future__.py b/mindnlp/core/__future__.py new file mode 100644 index 000000000..174e9b125 --- /dev/null +++ b/mindnlp/core/__future__.py @@ -0,0 +1,75 @@ +_overwrite_module_params_on_conversion: bool = False +_swap_module_params_on_conversion: bool = False + + +def set_overwrite_module_params_on_conversion(value: bool) -> None: + """ + Sets whether to assign new tensors to the parameters instead of changing the + existing parameters in-place when converting an ``nn.Module``. + + When enabled, the following methods will assign new parameters to the module: + + #. ``module.{device}()`` (e.g. :meth:`nn.Module.cuda()`) for moving a module between devices + #. ``module.{dtype}()`` (e.g. :meth:`nn.Module.float()`) for converting a module to a different dtype + #. :meth:`nn.Module.to` + #. :meth:`nn.Module.to_empty` + + Args: + value (bool): Whether to assign new tensors or not. + + """ + global _overwrite_module_params_on_conversion + _overwrite_module_params_on_conversion = value + + +def get_overwrite_module_params_on_conversion() -> bool: + """ + Returns whether to assign new tensors to the parameters instead of changing the + existing parameters in-place when converting an :class:`torch.nn.Module`. Defaults to ``False``. + + See :func:`~torch.__future__.set_overwrite_module_params_on_conversion` for more information. + """ + return _overwrite_module_params_on_conversion + + +def set_swap_module_params_on_conversion(value: bool) -> None: + """ + Sets whether to use :func:`~torch.utils.swap_tensors` instead of setting ``.data`` to + change the existing parameters in-place when converting an ``nn.Module`` and instead + of ``param.copy_(state_dict[key])`` when loading a state dict into an ``nn.Module``. + + .. note:: + This function takes precedence over :func:`~torch.__future__.get_overwrite_module_params_on_conversion` + + When enabled, the following methods will swap the existing parameters in-place: + + #. ``module.{device}()`` (e.g. :meth:`nn.Module.cuda()`) for moving a module between devices + #. ``module.{dtype}()`` (e.g. :meth:`nn.Module.float()`) for converting a module to a different dtype + #. :meth:`nn.Module.to` + #. :meth:`nn.Module.to_empty` + #. :meth:`nn.Module.load_state_dict` + + The semantics for :meth:`~nn.Module.load_state_dict` when this is set are as follows: + + #. For each parameter/buffer, its corresponding ``state_dict['key']`` is transformed via + :meth:`~torch.Tensor.module_load` (i.e. ``res = param.module_load(state_dict['key'])``) + #. If necessary, ``res`` will be wrapped in an :class:`~nn.Parameter` + #. The parameter/buffer in the module will be swapped via :func:`~torch.utils.swap_tensors` + with ``res`` + + Args: + value (bool): Whether to use :func:`~torch.utils.swap_tensors` or not. + + """ + global _swap_module_params_on_conversion + _swap_module_params_on_conversion = value + + +def get_swap_module_params_on_conversion() -> bool: + """ + Returns whether to use :func:`~torch.utils.swap_tensors` instead of setting .data to + change the existing parameters in-place when converting an ``nn.Module``. Defaults to ``False``. + + See :func:`~torch.__future__.set_swap_module_params_on_conversion` for more information. + """ + return _swap_module_params_on_conversion \ No newline at end of file diff --git a/mindnlp/core/__init__.py b/mindnlp/core/__init__.py new file mode 100644 index 000000000..ec74c77c0 --- /dev/null +++ b/mindnlp/core/__init__.py @@ -0,0 +1,50 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""core module""" +import os +import platform +from typing import ( + Any as _Any, + Callable as _Callable, + get_origin as _get_origin, + Optional as _Optional, + overload as _overload, + TYPE_CHECKING, + TypeVar as _TypeVar, + Union as _Union, +) + +strided = None +contiguous_format = None +preserve_format = None + +inf = float("inf") +nan = float("nan") + +from ._C import * +from ._dtype import * +from ._tensor import Tensor, tensor, is_tensor, \ + LongTensor, FloatTensor, BoolTensor, HalfTensor, BFloat16Tensor +from .types import device +from ._C.size import Size +from .types import device +from .autograd import * +from .ops import * +from .serialization import load, save +from ._bind import get_default_dtype, set_default_dtype + +from . import profiler, cuda, optim, amp, compiler, jit, version, __future__, overrides, \ + return_types + diff --git a/mindnlp/core/_bind.py b/mindnlp/core/_bind.py new file mode 100644 index 000000000..fd4dfcfd0 --- /dev/null +++ b/mindnlp/core/_bind.py @@ -0,0 +1,151 @@ +import ctypes +from typing import Any +from ._dtype import * +from .types import device as device_ + +DEFAULT_DTYPE, DEFAULT_DEVICE = float32, device_('cpu') + +AUTO_CAST_DTYE = { + 'cuda': float16, + 'cpu': bfloat16, + 'npu': float16 +} + +def set_autocast_dtype(device_type, dtype): + assert device_type in AUTO_CAST_DTYE.keys(), f'{device_type} is not in {AUTO_CAST_DTYE.keys()}' + AUTO_CAST_DTYE[device_type] = dtype + +def get_autocast_dtype(device_type): + return AUTO_CAST_DTYE[device_type] + +def set_default_dtype(dtype): + """set default dtype""" + global DEFAULT_DTYPE + DEFAULT_DTYPE = dtype + +def get_default_dtype(): + """get default dtype""" + return DEFAULT_DTYPE + +def set_default_device(device): + """set default dtype""" + global DEFAULT_DEVICE + if isinstance(device, str): + device = device_(device) + DEFAULT_DEVICE = device + +def get_default_device(): + """get default dtype""" + return DEFAULT_DEVICE + +bits_map = { + +} + +min_map = { + float32: -3.40282e+38, + float16: -65504, + bfloat16: -3.38953e+38 +} + +max_map = { + float32: 3.40282e+38, + float16: 65504, + bfloat16: 3.38953e+38 +} + +eps_map = { + float32: 1.19209e-07, + float16: 0.000976562, + bfloat16: 0.0078125 +} + +tiny_map = { + float32: 1.17549e-38, + float16: 6.10352e-05, + bfloat16: 1.17549e-38 +} + +smallest_normal_map = { + float32: 1.17549e-38, + float16: 6.10352e-05, + bfloat16: 1.17549e-38 +} + +resolution_map = { + float32: 1e-06, + float16: 0.001, + bfloat16: 0.01 +} +class iinfo: + def __init__(self, dtype): + self._dtype = dtype + + @property + def bits(self): + return bits_map[self._dtype] + + @property + def min(self): + return min_map[self._dtype] + + @property + def max(self): + return max_map[self._dtype] + + @property + def dtype(self): + return str(self._dtype) + + +class finfo: + def __init__(self, dtype): + self._dtype = dtype + + @property + def bits(self): + return bits_map[self._dtype] + + @property + def min(self): + return min_map[self._dtype] + + @property + def max(self): + return max_map[self._dtype] + + @property + def eps(self): + return eps_map[self._dtype] + + @property + def tiny(self): + return tiny_map[self._dtype] + + @property + def smallest_normal(self): + return smallest_normal_map[self._dtype] + + @property + def resolution(self): + return resolution_map[self._dtype] + + @property + def dtype(self): + return str(self._dtype) + +def asarray(obj: Any, *, dtype, device=None, copy = None, requires_grad = False): + data = obj.data.view(core.dtype2np[dtype]) + out = core.Tensor(data) + core._utils.set_device_address(out) + return out + +def view(self, dtype): + data_ptr = self.data_ptr() + nbytes = self.nbytes + data = np.ctypeslib.as_array((ctypes.c_byte * nbytes).from_address(data_ptr), shape=(nbytes,)) + data = data.view(core.dtype2np[dtype]) + assert data_ptr == data.ctypes.data + out = core.Tensor(data) + core._utils.set_device_address(out) + return out \ No newline at end of file diff --git a/mindnlp/core/_custom_ops.py b/mindnlp/core/_custom_ops.py new file mode 100644 index 000000000..e69de29bb diff --git a/mindnlp/core/_dtype.py b/mindnlp/core/_dtype.py new file mode 100644 index 000000000..5bb777f92 --- /dev/null +++ b/mindnlp/core/_dtype.py @@ -0,0 +1,51 @@ +import numpy as np +from mindspore.common.dtype import * +from mindspore._c_expression import typing +from mindspore._c_expression.typing import Type + +dtype = Type + +def is_floating_point(self): + return isinstance(self, (typing.Float, typing.BFloat16)) + +Type.is_floating_point = is_floating_point + +half = float16 +float = float32 +double = float64 + +long = int64 +int = int32 +bool = bool_ + +float8_e4m3fn = None # TODO: not support fp8 for now + +np2dtype = { + np.bool_: bool, + np.int8: int8, + np.int16: int16, + np.int32: int32, + np.int64: int64, + np.uint8: uint8, + np.uint16: uint16, + np.uint32: uint32, + np.uint64: uint64, + np.float16: float16, + np.float32: float32, + np.float64: float64, +} + +dtype2np = { + bool : np.bool_, + int8 : np.int8, + int16 : np.int16, + int32 : np.int32, + int64 : np.int64, + uint8 : np.uint8, + uint16 : np.uint16, + uint32 : np.uint32, + uint64 : np.uint64, + float16 : np.float16, + float32 : np.float32, + float64 : np.float64, +} \ No newline at end of file diff --git a/mindnlp/core/_dynamo/__init__.py b/mindnlp/core/_dynamo/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/mindnlp/core/_dynamo/utils.py b/mindnlp/core/_dynamo/utils.py new file mode 100644 index 000000000..1f7544cae --- /dev/null +++ b/mindnlp/core/_dynamo/utils.py @@ -0,0 +1,3 @@ +def is_compile_supported(device_type): + compile_supported = False + return compile_supported \ No newline at end of file diff --git a/mindnlp/core/_jit_internal.py b/mindnlp/core/_jit_internal.py new file mode 100644 index 000000000..6f2ed46d4 --- /dev/null +++ b/mindnlp/core/_jit_internal.py @@ -0,0 +1,82 @@ +class FunctionModifiers: + """ + Used to denote the behavior of a function in TorchScript. See export() and + ignore() for details. + """ + + UNUSED = "unused (ignored and replaced with raising of an exception)" + IGNORE = "ignore (leave as a call to Python, cannot be torch.jit.save'd)" + EXPORT = "export (compile this function even if nothing calls it)" + DEFAULT = "default (compile if called from a exported function / forward)" + COPY_TO_SCRIPT_WRAPPER = ( + "if this method is not scripted, copy the python method onto the scripted model" + ) + _DROP = "_drop (function is fully ignored, declaration can be unscriptable)" + +def unused(fn): + """ + This decorator indicates to the compiler that a function or method should + be ignored and replaced with the raising of an exception. This allows you + to leave code in your model that is not yet TorchScript compatible and still + export your model. + + Example (using ``@torch.jit.unused`` on a method):: + + import torch + import torch.nn as nn + + + class MyModule(nn.Module): + def __init__(self, use_memory_efficient): + super().__init__() + self.use_memory_efficient = use_memory_efficient + + @torch.jit.unused + def memory_efficient(self, x): + import pdb + + pdb.set_trace() + return x + 10 + + def forward(self, x): + # Use not-yet-scriptable memory efficient mode + if self.use_memory_efficient: + return self.memory_efficient(x) + else: + return x + 10 + + + m = torch.jit.script(MyModule(use_memory_efficient=False)) + m.save("m.pt") + + m = torch.jit.script(MyModule(use_memory_efficient=True)) + # exception raised + m(torch.rand(100)) + """ + if isinstance(fn, property): + prop = fn + setattr( # noqa: B010 + prop.fget, "_torchscript_modifier", FunctionModifiers.UNUSED + ) + + if prop.fset: + setattr( # noqa: B010 + prop.fset, "_torchscript_modifier", FunctionModifiers.UNUSED + ) + + return prop + + fn._torchscript_modifier = FunctionModifiers.UNUSED + return fn + +# allows BroadcastingList instance to be subscriptable +class BroadcastingListCls: + def __getitem__(self, types): + return + + +# mypy doesn't support parameters on types, so we have to explicitly type each +# list size +BroadcastingList1 = BroadcastingListCls() +for i in range(2, 7): + globals()[f"BroadcastingList{i}"] = BroadcastingList1 diff --git a/mindnlp/core/_linalg_utils.py b/mindnlp/core/_linalg_utils.py new file mode 100644 index 000000000..fc80b2de4 --- /dev/null +++ b/mindnlp/core/_linalg_utils.py @@ -0,0 +1,150 @@ +# mypy: allow-untyped-defs +"""Various linear algebra utility methods for internal use.""" + +from typing import Optional + +from mindnlp import core +from mindnlp.core import Tensor + + +def is_sparse(A): + """Check if tensor A is a sparse tensor""" + if isinstance(A, core.Tensor): + return A.layout == core.sparse_coo + + error_str = "expected Tensor" + if not core.jit.is_scripting(): + error_str += f" but got {type(A)}" + raise TypeError(error_str) + + +def get_floating_dtype(A): + """Return the floating point dtype of tensor A. + + Integer types map to float32. + """ + dtype = A.dtype + if dtype in (core.float16, core.float32, core.float64): + return dtype + return core.float32 + + +def matmul(A: Optional[Tensor], B: Tensor) -> Tensor: + """Multiply two matrices. + + If A is None, return B. A can be sparse or dense. B is always + dense. + """ + if A is None: + return B + if is_sparse(A): + return core.sparse.mm(A, B) + return core.matmul(A, B) + + +def bform(X: Tensor, A: Optional[Tensor], Y: Tensor) -> Tensor: + """Return bilinear form of matrices: :math:`X^T A Y`.""" + return matmul(X.mT, matmul(A, Y)) + + +def qform(A: Optional[Tensor], S: Tensor): + """Return quadratic form :math:`S^T A S`.""" + return bform(S, A, S) + + +def basis(A): + """Return orthogonal basis of A columns.""" + return core.linalg.qr(A).Q + + +def symeig(A: Tensor, largest: Optional[bool] = False) -> tuple[Tensor, Tensor]: + """Return eigenpairs of A with specified ordering.""" + if largest is None: + largest = False + E, Z = core.linalg.eigh(A, UPLO="U") + # assuming that E is ordered + if largest: + E = core.flip(E, dims=(-1,)) + Z = core.flip(Z, dims=(-1,)) + return E, Z + + +# These functions were deprecated and removed +# This nice error message can be removed in version 1.13+ +def matrix_rank(input, tol=None, symmetric=False, *, out=None) -> Tensor: + raise RuntimeError( + "This function was deprecated since version 1.9 and is now removed.\n" + "Please use the `core.linalg.matrix_rank` function instead. " + "The parameter 'symmetric' was renamed in `core.linalg.matrix_rank()` to 'hermitian'." + ) + + +def solve(input: Tensor, A: Tensor, *, out=None) -> tuple[Tensor, Tensor]: + raise RuntimeError( + "This function was deprecated since version 1.9 and is now removed. " + "`core.solve` is deprecated in favor of `core.linalg.solve`. " + "`core.linalg.solve` has its arguments reversed and does not return the LU factorization.\n\n" + "To get the LU factorization see `core.lu`, which can be used with `core.lu_solve` or `core.lu_unpack`.\n" + "X = core.solve(B, A).solution " + "should be replaced with:\n" + "X = core.linalg.solve(A, B)" + ) + + +def lstsq(input: Tensor, A: Tensor, *, out=None) -> tuple[Tensor, Tensor]: + raise RuntimeError( + "This function was deprecated since version 1.9 and is now removed. " + "`core.lstsq` is deprecated in favor of `core.linalg.lstsq`.\n" + "`core.linalg.lstsq` has reversed arguments and does not return the QR decomposition in " + "the returned tuple (although it returns other information about the problem).\n\n" + "To get the QR decomposition consider using `core.linalg.qr`.\n\n" + "The returned solution in `core.lstsq` stored the residuals of the solution in the " + "last m - n columns of the returned value whenever m > n. In core.linalg.lstsq, " + "the residuals are in the field 'residuals' of the returned named tuple.\n\n" + "The unpacking of the solution, as in\n" + "X, _ = core.lstsq(B, A).solution[:A.size(1)]\n" + "should be replaced with:\n" + "X = core.linalg.lstsq(A, B).solution" + ) + + +def _symeig( + input, + eigenvectors=False, + upper=True, + *, + out=None, +) -> tuple[Tensor, Tensor]: + raise RuntimeError( + "This function was deprecated since version 1.9 and is now removed. " + "The default behavior has changed from using the upper triangular portion of the matrix by default " + "to using the lower triangular portion.\n\n" + "L, _ = core.symeig(A, upper=upper) " + "should be replaced with:\n" + "L = core.linalg.eigvalsh(A, UPLO='U' if upper else 'L')\n\n" + "and\n\n" + "L, V = core.symeig(A, eigenvectors=True) " + "should be replaced with:\n" + "L, V = core.linalg.eigh(A, UPLO='U' if upper else 'L')" + ) + + +def eig( + self: Tensor, + eigenvectors: bool = False, + *, + e=None, + v=None, +) -> tuple[Tensor, Tensor]: + raise RuntimeError( + "This function was deprecated since version 1.9 and is now removed. " + "`core.linalg.eig` returns complex tensors of dtype `cfloat` or `cdouble` rather than real tensors " + "mimicking complex tensors.\n\n" + "L, _ = core.eig(A) " + "should be replaced with:\n" + "L_complex = core.linalg.eigvals(A)\n\n" + "and\n\n" + "L, V = core.eig(A, eigenvectors=True) " + "should be replaced with:\n" + "L_complex, V_complex = core.linalg.eig(A)" + ) \ No newline at end of file diff --git a/mindnlp/core/_lowrank.py b/mindnlp/core/_lowrank.py new file mode 100644 index 000000000..c03d4f468 --- /dev/null +++ b/mindnlp/core/_lowrank.py @@ -0,0 +1,294 @@ +"""Implement various linear algebra algorithms for low rank matrices.""" + +__all__ = ["svd_lowrank", "pca_lowrank"] + +from typing import Optional + +from mindnlp import core +from mindnlp.core import _linalg_utils as _utils, Tensor +from core.overrides import handle_torch_function, has_torch_function + + +def get_approximate_basis( + A: Tensor, + q: int, + niter: Optional[int] = 2, + M: Optional[Tensor] = None, +) -> Tensor: + """Return tensor :math:`Q` with :math:`q` orthonormal columns such + that :math:`Q Q^H A` approximates :math:`A`. If :math:`M` is + specified, then :math:`Q` is such that :math:`Q Q^H (A - M)` + approximates :math:`A - M`. without instantiating any tensors + of the size of :math:`A` or :math:`M`. + + .. note:: The implementation is based on the Algorithm 4.4 from + Halko et al., 2009. + + .. note:: For an adequate approximation of a k-rank matrix + :math:`A`, where k is not known in advance but could be + estimated, the number of :math:`Q` columns, q, can be + choosen according to the following criteria: in general, + :math:`k <= q <= min(2*k, m, n)`. For large low-rank + matrices, take :math:`q = k + 5..10`. If k is + relatively small compared to :math:`min(m, n)`, choosing + :math:`q = k + 0..2` may be sufficient. + + .. note:: To obtain repeatable results, reset the seed for the + pseudorandom number generator + + Args:: + A (Tensor): the input tensor of size :math:`(*, m, n)` + + q (int): the dimension of subspace spanned by :math:`Q` + columns. + + niter (int, optional): the number of subspace iterations to + conduct; ``niter`` must be a + nonnegative integer. In most cases, the + default value 2 is more than enough. + + M (Tensor, optional): the input tensor's mean of size + :math:`(*, m, n)`. + + References:: + - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding + structure with randomness: probabilistic algorithms for + constructing approximate matrix decompositions, + arXiv:0909.4061 [math.NA; math.PR], 2009 (available at + `arXiv `_). + """ + + niter = 2 if niter is None else niter + dtype = _utils.get_floating_dtype(A) if not A.is_complex() else A.dtype + matmul = _utils.matmul + + R = core.randn(A.shape[-1], q, dtype=dtype, device=A.device) + + # The following code could be made faster using core.geqrf + core.ormqr + # but geqrf is not differentiable + + X = matmul(A, R) + if M is not None: + X = X - matmul(M, R) + Q = core.linalg.qr(X).Q + for _ in range(niter): + X = matmul(A.mH, Q) + if M is not None: + X = X - matmul(M.mH, Q) + Q = core.linalg.qr(X).Q + X = matmul(A, Q) + if M is not None: + X = X - matmul(M, Q) + Q = core.linalg.qr(X).Q + return Q + + +def svd_lowrank( + A: Tensor, + q: Optional[int] = 6, + niter: Optional[int] = 2, + M: Optional[Tensor] = None, +) -> tuple[Tensor, Tensor, Tensor]: + r"""Return the singular value decomposition ``(U, S, V)`` of a matrix, + batches of matrices, or a sparse matrix :math:`A` such that + :math:`A \approx U \operatorname{diag}(S) V^{\text{H}}`. In case :math:`M` is given, then + SVD is computed for the matrix :math:`A - M`. + + .. note:: The implementation is based on the Algorithm 5.1 from + Halko et al., 2009. + + .. note:: For an adequate approximation of a k-rank matrix + :math:`A`, where k is not known in advance but could be + estimated, the number of :math:`Q` columns, q, can be + choosen according to the following criteria: in general, + :math:`k <= q <= min(2*k, m, n)`. For large low-rank + matrices, take :math:`q = k + 5..10`. If k is + relatively small compared to :math:`min(m, n)`, choosing + :math:`q = k + 0..2` may be sufficient. + + .. note:: This is a randomized method. To obtain repeatable results, + set the seed for the pseudorandom number generator + + .. note:: In general, use the full-rank SVD implementation + :func:`core.linalg.svd` for dense matrices due to its 10x + higher performance characteristics. The low-rank SVD + will be useful for huge sparse matrices that + :func:`core.linalg.svd` cannot handle. + + Args:: + A (Tensor): the input tensor of size :math:`(*, m, n)` + + q (int, optional): a slightly overestimated rank of A. + + niter (int, optional): the number of subspace iterations to + conduct; niter must be a nonnegative + integer, and defaults to 2 + + M (Tensor, optional): the input tensor's mean of size + :math:`(*, m, n)`, which will be broadcasted + to the size of A in this function. + + References:: + - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding + structure with randomness: probabilistic algorithms for + constructing approximate matrix decompositions, + arXiv:0909.4061 [math.NA; math.PR], 2009 (available at + `arXiv `_). + + """ + if not core.jit.is_scripting(): + tensor_ops = (A, M) + if not set(map(type, tensor_ops)).issubset( + (core.Tensor, type(None)) + ) and has_torch_function(tensor_ops): + return handle_torch_function( + svd_lowrank, tensor_ops, A, q=q, niter=niter, M=M + ) + return _svd_lowrank(A, q=q, niter=niter, M=M) + + +def _svd_lowrank( + A: Tensor, + q: Optional[int] = 6, + niter: Optional[int] = 2, + M: Optional[Tensor] = None, +) -> tuple[Tensor, Tensor, Tensor]: + # Algorithm 5.1 in Halko et al., 2009 + + q = 6 if q is None else q + m, n = A.shape[-2:] + matmul = _utils.matmul + if M is not None: + M = M.broadcast_to(A.size()) + + # Assume that A is tall + if m < n: + A = A.mH + if M is not None: + M = M.mH + + Q = get_approximate_basis(A, q, niter=niter, M=M) + B = matmul(Q.mH, A) + if M is not None: + B = B - matmul(Q.mH, M) + U, S, Vh = core.linalg.svd(B, full_matrices=False) + V = Vh.mH + U = Q.matmul(U) + + if m < n: + U, V = V, U + + return U, S, V + + +def pca_lowrank( + A: Tensor, + q: Optional[int] = None, + center: bool = True, + niter: int = 2, +) -> tuple[Tensor, Tensor, Tensor]: + r"""Performs linear Principal Component Analysis (PCA) on a low-rank + matrix, batches of such matrices, or sparse matrix. + + This function returns a namedtuple ``(U, S, V)`` which is the + nearly optimal approximation of a singular value decomposition of + a centered matrix :math:`A` such that :math:`A \approx U \operatorname{diag}(S) V^{\text{H}}` + + .. note:: The relation of ``(U, S, V)`` to PCA is as follows: + + - :math:`A` is a data matrix with ``m`` samples and + ``n`` features + + - the :math:`V` columns represent the principal directions + + - :math:`S ** 2 / (m - 1)` contains the eigenvalues of + :math:`A^T A / (m - 1)` which is the covariance of + ``A`` when ``center=True`` is provided. + + - ``matmul(A, V[:, :k])`` projects data to the first k + principal components + + .. note:: Different from the standard SVD, the size of returned + matrices depend on the specified rank and q + values as follows: + + - :math:`U` is m x q matrix + + - :math:`S` is q-vector + + - :math:`V` is n x q matrix + + .. note:: To obtain repeatable results, reset the seed for the + pseudorandom number generator + + Args: + + A (Tensor): the input tensor of size :math:`(*, m, n)` + + q (int, optional): a slightly overestimated rank of + :math:`A`. By default, ``q = min(6, m, + n)``. + + center (bool, optional): if True, center the input tensor, + otherwise, assume that the input is + centered. + + niter (int, optional): the number of subspace iterations to + conduct; niter must be a nonnegative + integer, and defaults to 2. + + References:: + + - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding + structure with randomness: probabilistic algorithms for + constructing approximate matrix decompositions, + arXiv:0909.4061 [math.NA; math.PR], 2009 (available at + `arXiv `_). + + """ + + if not core.jit.is_scripting(): + if type(A) is not core.Tensor and has_torch_function((A,)): + return handle_torch_function( + pca_lowrank, (A,), A, q=q, center=center, niter=niter + ) + + (m, n) = A.shape[-2:] + + if q is None: + q = min(6, m, n) + elif not (q >= 0 and q <= min(m, n)): + raise ValueError( + f"q(={q}) must be non-negative integer and not greater than min(m, n)={min(m, n)}" + ) + if not (niter >= 0): + raise ValueError(f"niter(={niter}) must be non-negative integer") + + dtype = _utils.get_floating_dtype(A) + + if not center: + return _svd_lowrank(A, q, niter=niter, M=None) + + if _utils.is_sparse(A): + if len(A.shape) != 2: + raise ValueError("pca_lowrank input is expected to be 2-dimensional tensor") + c = core.sparse.sum(A, dim=(-2,)) / m + # reshape c + column_indices = c.indices()[0] + indices = core.zeros( + 2, + len(column_indices), + dtype=column_indices.dtype, + device=column_indices.device, + ) + indices[0] = column_indices + C_t = core.sparse_coo_tensor( + indices, c.values(), (n, 1), dtype=dtype, device=A.device + ) + + ones_m1_t = core.ones(A.shape[:-2] + (1, m), dtype=dtype, device=A.device) + M = core.sparse.mm(C_t, ones_m1_t).mT + return _svd_lowrank(A, q, niter=niter, M=M) + else: + C = A.mean(dim=(-2,), keepdim=True) + return _svd_lowrank(A - C, q, niter=niter, M=None) \ No newline at end of file diff --git a/mindnlp/core/_ops.py b/mindnlp/core/_ops.py new file mode 100644 index 000000000..e2ce79d52 --- /dev/null +++ b/mindnlp/core/_ops.py @@ -0,0 +1,2 @@ +class OpOverload: + pass diff --git a/mindnlp/core/_prims/__init__.py b/mindnlp/core/_prims/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/mindnlp/core/_prims/ascend/__init__.py b/mindnlp/core/_prims/ascend/__init__.py new file mode 100644 index 000000000..95e4a4e45 --- /dev/null +++ b/mindnlp/core/_prims/ascend/__init__.py @@ -0,0 +1,7 @@ +from . import aclop, pyboost +from .aclop import * +from .pyboost import * + +__all__ = [] +__all__.extend(aclop.__all__) +__all__.extend(pyboost.__all__) diff --git a/mindnlp/core/_prims/ascend/aclop.py b/mindnlp/core/_prims/ascend/aclop.py new file mode 100644 index 000000000..c59c0beaa --- /dev/null +++ b/mindnlp/core/_prims/ascend/aclop.py @@ -0,0 +1,87 @@ +from mindspore.ops.auto_generate import gen_ops_prim +from mindspore.common.api import _pynative_executor +from mindspore.ops._primitive_cache import _get_cache_prim +from mindspore.ops.auto_generate.gen_ops_prim import Range, Cdist +from mindspore.ops import StopGradient, Primitive, ApplyAdadelta, Adam, ApplyAdamWithAmsgradV2, SGD, Imag + +pyboost_list = list(filter(lambda s: s.startswith("pyboost"), dir(gen_ops_prim))) +pyboost_op_list = [op.replace('pyboost_', '') + '_op' for op in pyboost_list] +aclop_list = list(filter(lambda s: s.endswith("_op") and not s in pyboost_op_list, dir(gen_ops_prim))) + +aclop_func = ''' +def {name}(*args): + return _pynative_executor.run_op_async({obj}, {obj}.name, args) +''' + +__all__ = [] + +for op_name in aclop_list: + func_name = op_name.replace('_op', '_npu') + __all__.append(func_name) + prim_op = func_name + '_prim' + globals()[prim_op] = getattr(gen_ops_prim, op_name).__class__().set_device('Ascend') + exec(aclop_func.format(name=func_name, obj=prim_op), globals()) + +imag_op = Imag().set_device('Ascend') +def imag_npu(*args): + return _pynative_executor.run_op_async(imag_op, range_op.name, args) + +__all__.append('imag_npu') + +range_op = Range().set_device('Ascend') +def range_npu(*args): + return _pynative_executor.run_op_async(range_op, range_op.name, args) + +__all__.append('range_npu') + +cdist_op = Cdist().set_device('Ascend') +def cdist_npu(*args): + return _pynative_executor.run_op_async(cdist_op, cdist_op.name, args) + +__all__.append('cdist_npu') + + +stop_gradient_op = StopGradient().set_device('Ascend') +def stop_gradient_npu(*args): + return _pynative_executor.run_op_async(stop_gradient_op, stop_gradient_op.name, args) + +__all__.append('stop_gradient_npu') + +diagonal_op = Primitive('Diagonal').set_device('Ascend') +def diagonal_npu(*args): + return _pynative_executor.run_op_async(diagonal_op, diagonal_op.name, args) + +__all__.append('diagonal_npu') + +adadelta_op = ApplyAdadelta().set_device('Ascend') +def raw_adadelta_npu(param, square_avg, acc_delta, lr, rho, eps, grad): + args = (param, square_avg, acc_delta, lr, rho, eps, grad) + return _pynative_executor.run_op_async(adadelta_op, adadelta_op.name, args) + +adam_op = Adam().set_device('Ascend') +def raw_adam_npu(param, exp_avg, exp_avg_sq, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad): + # var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad + args = (param, exp_avg, exp_avg_sq, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad) + return _pynative_executor.run_op_async(adam_op, adam_op.name, args) + +adam_amsgrad_op = ApplyAdamWithAmsgradV2().set_device('Ascend') +def raw_adam_amsgrad_npu(param, exp_avg, exp_avg_sq, max_exp_avg_sq, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad): + # var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad + args = (param, exp_avg, exp_avg_sq, max_exp_avg_sq, + beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad) + return _pynative_executor.run_op_async(adam_amsgrad_op, adam_amsgrad_op.name, args) + + +def raw_sgd_npu(param, grad, lr, dampening, weight_decay, nesterov, accum, momentum, stat): + sgd_op = _get_cache_prim(SGD)(dampening, weight_decay, nesterov).set_device('Ascend') + args = (param, grad, lr, accum, momentum, stat) + return _pynative_executor.run_op_async(sgd_op, sgd_op.name, args) + +__all__.extend( + [ + 'raw_adadelta_npu', + 'raw_adam_npu', + 'raw_adam_amsgrad_npu', + 'raw_sgd_npu' + ] +) diff --git a/mindnlp/core/_prims/ascend/pyboost.py b/mindnlp/core/_prims/ascend/pyboost.py new file mode 100644 index 000000000..31bddfc26 --- /dev/null +++ b/mindnlp/core/_prims/ascend/pyboost.py @@ -0,0 +1,202 @@ +from mindspore.ops import Primitive +from mindspore.ops.auto_generate import gen_ops_prim +from mindspore.ops.auto_generate.gen_ops_prim import * +from mindspore._c_expression import pyboost_cast, pyboost_zeros, pyboost_ones, pyboost_empty, \ + pyboost_reduce_max, pyboost_reduce_min, pyboost_reduce_all, pyboost_reduce_all +from mindspore.ops.operations.manually_defined.ops_def import Cast, Zeros, Ones +from mindspore.common.api import _pynative_executor + +pyboost_list = list(filter(lambda s: s.startswith("pyboost"), dir(gen_ops_prim))) + + +pyboost_func = ''' +def {name}(*args): + return {pyboost}({op}, args) +''' + +__all__ = [] + +for op_name in pyboost_list: + op = getattr(gen_ops_prim, op_name) + func_name = op_name.replace('pyboost_', '') + '_npu' + prim_op = func_name.replace('_npu', '_op') + if not hasattr(gen_ops_prim, prim_op): + continue + __all__.append(func_name) + globals()[prim_op] = getattr(gen_ops_prim, prim_op).__class__().set_device('Ascend') + exec(pyboost_func.format(name=func_name, pyboost=op_name, op=prim_op), globals()) + +cast_op = Cast().set_device('Ascend') +def cast_npu(*args): + return pyboost_cast(cast_op, args) + +__all__.append('cast_npu') + +def empty_npu(size, dtype): + return pyboost_empty([size, dtype, 'Ascend']) + +__all__.append('empty_npu') + +zeros_op = Zeros().set_device('Ascend') +def zeros_npu(*args): + return pyboost_zeros(zeros_op, args) + +__all__.append('zeros_npu') + +ones_op = Ones().set_device('Ascend') +def ones_npu(*args): + return pyboost_ones(ones_op, args) + +__all__.append('ones_npu') + + +squeeze_op = Squeeze().set_device('Ascend') +def squeeze_npu(*args): + return pyboost_squeeze(squeeze_op, args) + +__all__.append('squeeze_npu') + +stack_ext_op = StackExt().set_device('Ascend') +def stack_ext_npu(*args): + return pyboost_stack_ext(stack_ext_op, args) + +__all__.append('stack_ext_npu') + +tile_op = Primitive('Tile').set_device('Ascend') +def tile_npu(*args): + return pyboost_tile(tile_op, args) + +__all__.append('tile_npu') + +greater_equal_op = GreaterEqual().set_device('Ascend') +def greater_equal_npu(*args): + return pyboost_greater_equal(greater_equal_op, args) + +__all__.append('greater_equal_npu') + +isclose_op = IsClose().set_device('Ascend') +def isclose_npu(*args): + return pyboost_isclose(isclose_op, args) + +__all__.append('isclose_npu') + +reduce_max_op = ReduceMax().set_device('Ascend') +def reduce_max_npu(*args): + return pyboost_reduce_max(reduce_max_op, args) + +__all__.append('reduce_max_npu') + +reduce_min_op = ReduceMin().set_device('Ascend') +def reduce_min_npu(*args): + return pyboost_reduce_min(reduce_min_op, args) + +__all__.append('reduce_min_npu') + +reduce_all_op = ReduceAll().set_device('Ascend') +def reduce_all_npu(*args): + return pyboost_reduce_all(reduce_all_op, args) + +__all__.append('reduce_all_npu') + +reduce_any_op = ReduceAny().set_device('Ascend') +def reduce_any_npu(*args): + return pyboost_reduce_any(reduce_any_op, args) + +__all__.append('reduce_any_npu') + +unique_consecutive_op = UniqueConsecutive().set_device('Ascend') +def unique_consecutive_npu(*args): + return pyboost_unique_consecutive(unique_consecutive_op, args) + +__all__.append('unique_consecutive_npu') + +nan_to_num_op = NanToNum().set_device('Ascend') +def nan_to_num_npu(*args): + return pyboost_nan_to_num(nan_to_num_op, args) + +__all__.append('nan_to_num_npu') + + +softmax_op = Softmax().set_device('Ascend') +def softmax_npu(*args): + return pyboost_softmax(softmax_op, args) + +__all__.append('softmax_npu') + +broadcast_to_op = Primitive('BroadcastTo').set_device('Ascend') +def broadcast_to_npu(*args): + return pyboost_broadcast_to(broadcast_to_op, args) + +__all__.append('broadcast_to_npu') + +triu_op = Triu().set_device('Ascend') +def triu_npu(*args): + return pyboost_triu(triu_op, args) + +__all__.append('triu_npu') + +tril_ext_op = TrilExt().set_device('Ascend') +def tril_ext_npu(*args): + return pyboost_tril_ext(triu_op, args) + +__all__.append('tril_ext_npu') + +search_sorted_op = SearchSorted().set_device('Ascend') +def search_sorted_npu(*args): + return pyboost_searchsorted(search_sorted_op, args) + +__all__.append('search_sorted_npu') + +roll_op = Primitive('Roll').set_device('Ascend') +def roll_npu(*args): + return pyboost_roll(roll_op, args) + +__all__.append('roll_npu') + +meshgrid_op = Meshgrid().set_device('Ascend') +def meshgrid_npu(*args): + return pyboost_meshgrid(meshgrid_op, args) + +__all__.append('meshgrid_npu') + +reverse_v2_op = Primitive('ReverseV2').set_device('Ascend') +def reverse_v2_npu(*args): + return pyboost_reverse_v2(reverse_v2_op, args) + +__all__.append('reverse_v2_npu') + +hard_shrink_op = HShrink().set_device('Ascend') +def hard_shrink_npu(*args): + return pyboost_hshrink(hard_shrink_op, args) + +__all__.append('hard_shrink_npu') + +concat_op = Concat().set_device('Ascend') +def concat_npu(*args): + return pyboost_concat(concat_op, args) + +__all__.append('concat_npu') + +rms_norm_op = RmsNorm().set_device('Ascend') +def rms_norm_npu(*args): + return pyboost_rms_norm(rms_norm_op, args) + +__all__.append('rms_norm_npu') + +flash_attention_score_op = Primitive('FlashAttentionScore').set_device('Ascend') +def flash_attention_score_npu(*args): + return pyboost_flash_attention_score(flash_attention_score_op, args) + +__all__.append('flash_attention_score_npu') + +argmax_with_value_op = ArgMaxWithValue().set_device('Ascend') +def argmax_with_value_npu(*args): + return pyboost_argmax_with_value(argmax_with_value_op, args) + +__all__.append('argmax_with_value_npu') + +argmin_with_value_op = ArgMinWithValue().set_device('Ascend') +def argmin_with_value_npu(*args): + return pyboost_argmin_with_value(argmin_with_value_op, args) + +__all__.append('argmin_with_value_npu') diff --git a/mindnlp/core/_prims/cpu/__init__.py b/mindnlp/core/_prims/cpu/__init__.py new file mode 100644 index 000000000..4192374af --- /dev/null +++ b/mindnlp/core/_prims/cpu/__init__.py @@ -0,0 +1,249 @@ +from mindspore.common.api import _pynative_executor +from mindspore.ops.auto_generate import gen_ops_prim +from mindspore.ops.auto_generate.gen_ops_prim import * +from mindspore._c_expression import Tensor as MSTensor +from mindspore._c_expression import pyboost_cast, pyboost_empty, pyboost_zeros, pyboost_ones +from mindspore.ops.operations.manually_defined.ops_def import Cast, Zeros, Ones +from mindspore.ops._primitive_cache import _get_cache_prim +from mindspore.ops import StopGradient, Primitive, ApplyAdadelta, Adam, ApplyAdamWithAmsgradV2, SGD +from mindspore.ops import FillV2, UniformReal, Stack, StandardNormal, TensorScatterUpdate +from mindspore.ops.operations import identity, TensorShape +from mindspore.ops.operations._grad_ops import StridedSliceGrad + + +pyboost_list = list(filter(lambda s: s.startswith("pyboost"), dir(gen_ops_prim))) +pyboost_op_list = [op.replace('pyboost_', '') + '_op' for op in pyboost_list] +aclop_list = list(filter(lambda s: s.endswith("_op") and not s in pyboost_op_list, dir(gen_ops_prim))) + + +pyboost_func = ''' +def {name}(*args): + return {pyboost}({op}, args) +''' + +aclop_func = ''' +def {name}(*args): + return _pynative_executor.run_op_async({obj}, {obj}.name, args) +''' + +__all__ = [] + +for op_name in pyboost_list: + op = getattr(gen_ops_prim, op_name) + func_name = op_name.replace('pyboost_', '') + '_cpu' + prim_op = func_name.replace('_cpu', '_op') + if not hasattr(gen_ops_prim, prim_op): + continue + __all__.append(func_name) + globals()[prim_op] = getattr(gen_ops_prim, prim_op).__class__().set_device('CPU') + exec(pyboost_func.format(name=func_name, pyboost=op_name, op=prim_op), globals()) + + +for op_name in aclop_list: + func_name = op_name.replace('_op', '_cpu') + __all__.append(func_name) + prim_op = func_name + '_prim' + globals()[prim_op] = getattr(gen_ops_prim, op_name).__class__().set_device('CPU') + exec(aclop_func.format(name=func_name, obj=prim_op), globals()) + +cast_op = Cast().set_device('CPU') +def cast_cpu(*args): + return pyboost_cast(cast_op, args) + +__all__.append('cast_cpu') + +def empty_cpu(size, dtype): + return pyboost_empty([size, dtype, 'CPU']) + +__all__.append('empty_cpu') + +zeros_op = Zeros().set_device('CPU') +def zeros_cpu(*args): + return pyboost_zeros(zeros_op, args) + +__all__.append('zeros_cpu') + +ones_op = Ones().set_device('CPU') +def ones_cpu(*args): + return pyboost_ones(ones_op, args) + +__all__.append('ones_cpu') + + +squeeze_op = Squeeze().set_device('CPU') +def squeeze_cpu(*args): + return pyboost_squeeze(squeeze_op, args) + +__all__.append('squeeze_cpu') + +stack_ext_op = StackExt().set_device('CPU') +def stack_ext_cpu(*args): + return pyboost_stack_ext(stack_ext_op, args) + +__all__.append('stack_ext_cpu') + +tile_op = Primitive('Tile').set_device('CPU') +def tile_cpu(*args): + return pyboost_tile(tile_op, args) + +__all__.append('tile_cpu') + +greater_equal_op = GreaterEqual().set_device('CPU') +def greater_equal_cpu(*args): + return pyboost_greater_equal(greater_equal_op, args) + +__all__.append('greater_equal_cpu') + +isclose_op = IsClose().set_device('CPU') +def isclose_cpu(*args): + return pyboost_isclose(isclose_op, args) + +__all__.append('isclose_cpu') + +range_op = Range().set_device('CPU') +def range_cpu(*args): + return _pynative_executor.run_op_async(range_op, range_op.name, args) + +__all__.append('range_cpu') + +linspace_op = LinSpace().set_device('CPU') +def linspace_cpu(*args): + return _pynative_executor.run_op_async(linspace_op, linspace_op.name, args) + +__all__.append('linspace_cpu') + +full_op = FillV2().set_device('CPU') +def full_cpu(shape, value): + return _pynative_executor.run_op_async(full_op, full_op.name, [shape, MSTensor(value)]) + +__all__.append('full_cpu') + +stop_gradient_op = StopGradient().set_device('CPU') +def stop_gradient_cpu(*args): + return _pynative_executor.run_op_async(stop_gradient_op, stop_gradient_op.name, args) + +__all__.append('stop_gradient_cpu') + +identity_op = identity().set_device('CPU') +def identity_cpu(*args): + return _pynative_executor.run_op_async(identity_op, identity_op.name, args) + +__all__.append('identity_cpu') + + +tensor_shape_op = TensorShape().set_device('CPU') +def tensor_shape_cpu(*args): + return _pynative_executor.run_op_async(tensor_shape_op, tensor_shape_op.name, args) + +__all__.append('stop_gradient_cpu') + +adadelta_op = ApplyAdadelta().set_device('CPU') +def raw_adadelta_cpu(param, square_avg, acc_delta, lr, rho, eps, grad): + args = (param, square_avg, acc_delta, lr, rho, eps, grad) + return _pynative_executor.run_op_async(adadelta_op, adadelta_op.name, args) + +adam_op = Adam().set_device('CPU') +def raw_adam_cpu(param, exp_avg, exp_avg_sq, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad): + # var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad + args = (param, exp_avg, exp_avg_sq, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad) + return _pynative_executor.run_op_async(adam_op, adam_op.name, args) + +adam_amsgrad_op = ApplyAdamWithAmsgradV2().set_device('CPU') +def raw_adam_amsgrad_cpu(param, exp_avg, exp_avg_sq, max_exp_avg_sq, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad): + # var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad + args = (param, exp_avg, exp_avg_sq, max_exp_avg_sq, + beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad) + return _pynative_executor.run_op_async(adam_amsgrad_op, adam_amsgrad_op.name, args) + + +def raw_sgd_cpu(param, grad, lr, dampening, weight_decay, nesterov, accum, momentum, stat): + sgd_op = _get_cache_prim(SGD)(dampening, weight_decay, nesterov).set_device('CPU') + args = (param, grad, lr, accum, momentum, stat) + return _pynative_executor.run_op_async(sgd_op, sgd_op.name, args) + +__all__.extend( + [ + 'raw_adadelta_cpu', + 'raw_adam_cpu', + 'raw_adam_amsgrad_cpu', + 'raw_sgd_cpu' + ] +) + +uniform_real_op = UniformReal().set_device('CPU') +def uniform_real_cpu(*args): + return _pynative_executor.run_op_async(uniform_real_op, uniform_real_op.name, args) + +__all__.append('uniform_real_cpu') + +def stack_cpu(tensors, dim): + stack_op = _get_cache_prim(Stack)(dim).set_device('CPU') + return _pynative_executor.run_op_async(stack_op, stack_op.name, tensors) + +__all__.append('stack_cpu') + +argmax_with_value_op = ArgMaxWithValue().set_device('CPU') +def argmax_with_value_cpu(*args): + return pyboost_argmax_with_value(argmax_with_value_op, args) + +__all__.append('argmax_with_value_cpu') + +argmin_with_value_op = ArgMinWithValue().set_device('CPU') +def argmin_with_value_cpu(*args): + return pyboost_argmin_with_value(argmin_with_value_op, args) + +__all__.append('argmin_with_value_cpu') + +log_softmax_op = LogSoftmax().set_device('CPU') +def log_softmax_cpu(*args): + return pyboost_log_softmax(log_softmax_op, args) + +__all__.append('log_softmax_cpu') + +strided_slice_op = StridedSlice().set_device('CPU') +def strided_slice_cpu(*args): + return _pynative_executor.run_op_async(strided_slice_op, strided_slice_op.name, args) + +__all__.append('strided_slice_cpu') + +hard_shrink_op = HShrink().set_device('CPU') +def hard_shrink_cpu(*args): + return pyboost_hshrink(hard_shrink_op, args) + +__all__.append('hard_shrink_cpu') + +normal_op = StandardNormal().set_device('CPU') +def normal_cpu(*args): + return _pynative_executor.run_op_async(normal_op, normal_op.name, args) + +__all__.append('normal_cpu') + +reduce_any_op = ReduceAny().set_device('CPU') +def reduce_any_cpu(*args): + return pyboost_reduce_any(reduce_any_op, args) + +__all__.append('reduce_any_cpu') + +def strided_slice_grad_cpu(input, begin, end, strides, update, begin_mask=0, end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=0): + strided_slice_grad = _get_cache_prim(StridedSliceGrad)(begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask).set_device('CPU') + return _pynative_executor.run_op_async(strided_slice_grad, strided_slice_grad.name, [update, input.shape, begin, end, strides]) + +__all__.append('strided_slice_grad_cpu') + +tensor_scatter_update_op = TensorScatterUpdate().set_device('CPU') +def tensor_scatter_update_cpu(*args): + return _pynative_executor.run_op_async(tensor_scatter_update_op, tensor_scatter_update_op.name, args) + +__all__.append('tensor_scatter_update_cpu') + +broadcast_to_op = Primitive('BroadcastTo').set_device('CPU') +def broadcast_to_cpu(*args): + return pyboost_broadcast_to(broadcast_to_op, args) + +__all__.append('broadcast_to_cpu') + +concat_op = Concat().set_device('CPU') +def concat_cpu(*args): + return pyboost_concat(concat_op, args) + +__all__.append('concat_cpu') diff --git a/mindnlp/core/_prims/gpu/__init__.py b/mindnlp/core/_prims/gpu/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/mindnlp/core/_prims_common/__init__.py b/mindnlp/core/_prims_common/__init__.py new file mode 100644 index 000000000..15a5a85b1 --- /dev/null +++ b/mindnlp/core/_prims_common/__init__.py @@ -0,0 +1,18 @@ +from typing import ( + Any, + Callable, + cast, + NamedTuple, + Optional, + overload, + TYPE_CHECKING, + TypeVar, + Union, +) + +from typing_extensions import deprecated, TypeAlias + +from mindnlp import core + +ShapeType: TypeAlias = Union[core.Size, list[int], tuple[int, ...]] +DeviceLikeType: TypeAlias = Union[str, core.device, int] diff --git a/mindnlp/core/_six.py b/mindnlp/core/_six.py new file mode 100644 index 000000000..62a0d09b4 --- /dev/null +++ b/mindnlp/core/_six.py @@ -0,0 +1 @@ +string_classes = (str, bytes) diff --git a/mindnlp/core/_subclasses/__init__.py b/mindnlp/core/_subclasses/__init__.py new file mode 100644 index 000000000..f1616ee29 --- /dev/null +++ b/mindnlp/core/_subclasses/__init__.py @@ -0,0 +1 @@ +from .fake_tensor import FakeTensorMode \ No newline at end of file diff --git a/mindnlp/core/_subclasses/fake_tensor.py b/mindnlp/core/_subclasses/fake_tensor.py new file mode 100644 index 000000000..476f3cf84 --- /dev/null +++ b/mindnlp/core/_subclasses/fake_tensor.py @@ -0,0 +1,2 @@ +class FakeTensorMode: + pass diff --git a/mindnlp/core/_tensor.py b/mindnlp/core/_tensor.py new file mode 100644 index 000000000..f4251f01b --- /dev/null +++ b/mindnlp/core/_tensor.py @@ -0,0 +1,117 @@ +import numpy as np +import mindspore +from mindspore import Tensor +from mindspore.common.tensor import _TensorMeta +try: + from mindspore.common._stub_tensor import StubTensor +except: + class StubTensor: pass + +from ._dtype import dtype2np +from ._bind import get_default_device + +from ._dtype import * + + +class TypedTensorMeta(_TensorMeta): + def __isinstancecheck__(self, instance): + if not isinstance(instance, Tensor): + return False + return instance.dtype == self.dtype + +class LongTensor(Tensor, metaclass=TypedTensorMeta): + dtype = long + def __init__(self, data, device=None): + super().__init__(data, dtype=long) + +class FloatTensor(Tensor, metaclass=TypedTensorMeta): + dtype = float32 + def __init__(self, data, device=None): + super().__init__(data, dtype=float32) + + +class HalfTensor(Tensor, metaclass=TypedTensorMeta): + dtype = float16 + def __init__(self, data, device=None): + super().__init__(data, dtype=float16) + +class BFloat16Tensor(Tensor, metaclass=TypedTensorMeta): + dtype = float16 + def __init__(self, data, device=None): + super().__init__(data, dtype=bfloat16) + + +class BoolTensor(Tensor, metaclass=TypedTensorMeta): + dtype = bool + def __init__(self, data, device=None): + super().__init__(data, dtype=bool) + +def tensor(data, *, dtype=None, device=None, requires_grad=False): + if isinstance(data, Tensor): + UserWarning("To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than core.tensor(sourceTensor).") + return Tensor(data) + + if device is None: + device = get_default_device() + + data_np = np.array(data, order='C') # must be C for mindspore Tensor + if dtype is not None: + data_np = data_np.astype(dtype2np[dtype]) + + tensor = Tensor(data_np).to(device) + return tensor + +def is_tensor(x): + return isinstance(x, Tensor) + +def enable_mindspore_patch(): + def to_(self, *args, **kwargs): + dtype_to = None + if len(args) == 1: + if isinstance(args[0], Type): + dtype_to = args[0] + elif isinstance(args[0], Tensor): + dtype_to = args[0].dtype + elif len(args) == 2: + _, dtype_to = args + else: + dtype_to = kwargs.get("dtype", None) + if dtype_to is not None: + return mindspore.ops.cast(self, dtype_to) + return self + + Tensor.to = to_ + StubTensor.to = to_ + + def size(self, dim=None): + if dim is None: + return self.shape + assert isinstance(dim, int), f'`dim` must be int but got {type(dim)}' + return self.shape[dim] + + Tensor.size = size + StubTensor.size = size + + @property + def is_meta(self): + return False + + Tensor.is_meta = is_meta + StubTensor.is_meta = is_meta + + def data_ptr(self): + return self._data_ptr() + + Tensor.data_ptr = data_ptr + StubTensor.data_ptr = data_ptr + + Tensor.device = None + StubTensor.device = None + + def _expand(self, *size): + if len(size) == 1: + size = size[0] + return self.broadcast_to(size) + + Tensor.expand = _expand + StubTensor.expand = _expand diff --git a/mindnlp/core/_utils.py b/mindnlp/core/_utils.py new file mode 100644 index 000000000..c2e4a912d --- /dev/null +++ b/mindnlp/core/_utils.py @@ -0,0 +1,205 @@ +import sys +import traceback + + +from mindnlp import core + +element_size_map = { + core.float16: 2, + core.float32: 4, + core.float64: 8, + core.bfloat16: 2, + core.int64: 8, + core.int32: 4, + core.int16: 2, + core.int8: 1, + core.uint8: 1, + core.bool: 1 +} + +def _element_size(dtype): + return element_size_map[dtype] + +def _flatten_dense_tensors(tensors): + """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of + same dense type. + + Since inputs are dense, the resulting tensor will be a concatenated 1D + buffer. Element-wise operation on this buffer will be equivalent to + operating individually. + + Args: + tensors (Iterable[Tensor]): dense tensors to flatten. + + Returns: + A contiguous 1D buffer containing input tensors. + """ + tensors = [tensor.view(-1) for tensor in tensors] + return core.cat(tensors) + + +def _unflatten_dense_tensors(flat, tensors): + """View a flat buffer using the sizes of tensors. Assume that tensors are of + same dense type, and that flat is given by _flatten_dense_tensors. + + Args: + flat (Tensor): flattened dense tensors to unflatten. + tensors (Iterable[Tensor]): dense tensors whose sizes will be used to + unflatten flat. + + Returns: + Unflattened dense tensors with sizes same as tensors and values from + flat. + """ + outputs = [] + offset = 0 + for tensor in tensors: + numel = tensor.numel() + if numel == 0: + outputs.append(core.empty(0, flat.dtype)) + else: + outputs.append(core.narrow(flat, 0, offset, numel).view(tensor.shape)) + offset += numel + return outputs + +def _rebuild_tensor_v2( + storage, + storage_offset, + size, + stride, + requires_grad, + backward_hooks, + metadata=None, +): + return core.Tensor(storage) + +class KeyErrorMessage(str): + r"""str subclass that returns itself in repr""" + + def __repr__(self): + return self + +class ExceptionWrapper: + r"""Wraps an exception plus traceback to communicate across threads""" + + def __init__(self, exc_info=None, where="in background"): + # It is important that we don't store exc_info, see + # NOTE [ Python Traceback Reference Cycle Problem ] + if exc_info is None: + exc_info = sys.exc_info() + self.exc_type = exc_info[0] + self.exc_msg = "".join(traceback.format_exception(*exc_info)) + self.where = where + + def reraise(self): + r"""Reraises the wrapped exception in the current thread""" + # Format a message such as: "Caught ValueError in DataLoader worker + # process 2. Original Traceback:", followed by the traceback. + msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}" + if self.exc_type == KeyError: + # KeyError calls repr() on its argument (usually a dict key). This + # makes stack traces unreadable. It will not be changed in Python + # (https://bugs.python.org/issue2651), so we work around it. + msg = KeyErrorMessage(msg) + elif getattr(self.exc_type, "message", None): + # Some exceptions have first argument as non-str but explicitly + # have message field + raise self.exc_type(message=msg) + try: + exception = self.exc_type(msg) + except Exception: + # If the exception takes multiple arguments or otherwise can't + # be constructed, don't try to instantiate since we don't know how to + raise RuntimeError(msg) from None + raise exception + +def set_device_address(tensor): + core._prims.cpu.tensor_shape_cpu(tensor) + + +def _type(self, dtype=None, non_blocking=False, **kwargs): + """Returns the type if `dtype` is not provided, else casts this object to + the specified type. + + If this is already of the correct type, no copy is performed and the + original object is returned. + + Args: + dtype (type or string): The desired type + non_blocking (bool): If ``True``, and the source is in pinned memory + and destination is on the GPU or vice versa, the copy is performed + asynchronously with respect to the host. Otherwise, the argument + has no effect. + **kwargs: For compatibility, may contain the key ``async`` in place of + the ``non_blocking`` argument. The ``async`` arg is deprecated. + """ + non_blocking = _get_async_or_non_blocking("type", non_blocking, kwargs) + if dtype is None: + return self.__module__ + "." + self.__class__.__name__ + + if isinstance(dtype, str): + dtype = _import_dotted_name(dtype) + if dtype == type(self): + return self + if self.is_sparse: + if not dtype.is_sparse: + raise RuntimeError("Cannot cast sparse tensor to dense tensor") + new_module_name = dtype.__module__.replace(".sparse", "") + new_values_type_name = new_module_name + "." + dtype.__name__ + new_values = core.Tensor._values(self).type(new_values_type_name, non_blocking) + new_indices_type_name = new_module_name + ".LongTensor" + new_indices = core.Tensor._indices(self).type( + new_indices_type_name, non_blocking + ) + return dtype(new_indices, new_values, self.size()) + if dtype.is_sparse: + raise RuntimeError("Cannot cast dense tensor to sparse tensor") + return dtype(self.size()).copy_(self, non_blocking) + +def _to(self, device, non_blocking=False): + """Returns a copy of this object in device memory. + + If this object is already on the correct device, then no copy is performed + and the original object is returned. + + Args: + device (int): The destination device. + non_blocking (bool): If ``True`` and the source is in pinned memory, + the copy will be asynchronous with respect to the host. Otherwise, + the argument has no effect. + """ + if self.device == device: + return self + + if device.type == "cpu": + pin_memory = non_blocking and self.device.type in ( + "cuda", + core._C._get_privateuse1_backend_name(), + ) + untyped_storage = core.empty( + self.nbytes(), dtype=core.uint8, device=device, pin_memory=pin_memory + ).untyped_storage() + untyped_storage.copy_(self, non_blocking) + return untyped_storage + + device_module = getattr(mindtorch, device.type, None) + assert device_module is not None, ( + f"{device.type.upper()} device module is not loaded" + ) + with device_module.device(device): + if self.is_sparse and hasattr(device_module, "sparse"): + new_type = getattr(device_module.sparse, self.__class__.__name__) + indices = getattr(core.Tensor._indices(self), device.type)( + device, non_blocking + ) + values = getattr(core.Tensor._values(self), device.type)( + device, non_blocking + ) + return new_type(indices, values, self.size()) + else: + assert not self.is_sparse, ( + f"sparse storage is not supported for {device.type.upper()} tensors" + ) + untyped_storage = core.UntypedStorage(self.size(), device=device) + untyped_storage.copy_(self, non_blocking) + return untyped_storage \ No newline at end of file diff --git a/mindnlp/core/amp/__init__.py b/mindnlp/core/amp/__init__.py new file mode 100644 index 000000000..f2859d21e --- /dev/null +++ b/mindnlp/core/amp/__init__.py @@ -0,0 +1,7 @@ +from .autocast_mode import ( + autocast, + custom_bwd, + custom_fwd, + is_autocast_available, +) +from .grad_scaler import GradScaler \ No newline at end of file diff --git a/mindnlp/core/amp/autocast_mode.py b/mindnlp/core/amp/autocast_mode.py new file mode 100644 index 000000000..ff2888c48 --- /dev/null +++ b/mindnlp/core/amp/autocast_mode.py @@ -0,0 +1,165 @@ +# mypy: allow-untyped-defs +import collections +import functools +import warnings +from typing import Any, Optional + +from mindnlp import core + +from mindspore._c_expression.amp import pop_amp_strategy, push_amp_strategy, AmpLevel +from mindspore.common.dtype import TensorType as _dtype, float32 +from mindspore.train.amp import AMP_AUTO_BLACK_LIST, AMP_AUTO_WHITE_LIST, AMP_PRIM_ARG_TABLE + +try: + import numpy as np + + HAS_NUMPY = True +except ModuleNotFoundError: + HAS_NUMPY = False + np = None # type: ignore[assignment] + +__all__ = [ + "autocast_decorator", + "autocast", + "is_autocast_available", + "custom_fwd", + "custom_bwd", +] + + +def is_autocast_available(device_type: str) -> bool: + r""" + Return a bool indicating if autocast is available on :attr:`device_type`. + + Args: + device_type(str): Device type to use. Possible values are: 'cuda', 'cpu', 'xpu' and so on. + The type is the same as the `type` attribute of a :class:`core.device`. + Thus, you may obtain the device type of a tensor using `Tensor.device.type`. + """ + return True + + +def autocast_decorator(autocast_instance, func): + @functools.wraps(func) + def decorate_autocast(*args, **kwargs): + with autocast_instance: + return func(*args, **kwargs) + + return decorate_autocast + + +class autocast: + + def __init__( + self, + device_type: str, + dtype: Optional[_dtype] = None, + enabled: bool = True, + cache_enabled: Optional[bool] = None, + ): + if not isinstance(device_type, str): + raise ValueError( + f"Expected `device_type` of type `str`, got: `{type(device_type)}`" + ) + self.device_type = device_type + if dtype is None: + dtype = float32 + self.dtype = dtype + self.amp_level = AmpLevel.AmpAuto if enabled else AmpLevel.AmpO0 + + def __enter__(self): + self.prev_fastdtype = core.get_autocast_dtype(self.device_type) + core.set_autocast_dtype(self.device_type, self.dtype) + white_list = [(prim.__name__, AMP_PRIM_ARG_TABLE[prim]) for prim in AMP_AUTO_WHITE_LIST] + black_list = [(prim.__name__, AMP_PRIM_ARG_TABLE[prim]) for prim in AMP_AUTO_BLACK_LIST] + push_amp_strategy(self.amp_level, self.dtype, white_list, black_list) + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): # type: ignore[override] + pop_amp_strategy() + core.set_autocast_dtype(self.device_type, self.prev_fastdtype) + return False + + def __call__(self, func): + return autocast_decorator(self, func) + +def custom_fwd( + fwd=None, + *, + device_type: str, + cast_inputs: Optional[_dtype] = None, +): + """ + Create a helper decorator for ``forward`` methods of custom autograd functions. + + Autograd functions are subclasses of :class:`core.autograd.Function`. + See the :ref:`example page` for more detail. + + Args: + device_type(str): Device type to use. 'cuda', 'cpu', 'xpu' and so on. + The type is the same as the `type` attribute of a :class:`core.device`. + Thus, you may obtain the device type of a tensor using `Tensor.device.type`. + cast_inputs (:class:`core.dtype` or None, optional, default=None): If not ``None``, + when ``forward`` runs in an autocast-enabled region, casts incoming + floating-point Tensors to the target dtype (non-floating-point Tensors are not affected), + then executes ``forward`` with autocast disabled. + If ``None``, ``forward``'s internal ops execute with the current autocast state. + + .. note:: + If the decorated ``forward`` is called outside an autocast-enabled region, + :func:`custom_fwd` is a no-op and ``cast_inputs`` has no effect. + """ + if fwd is None: + return functools.partial( + custom_fwd, device_type=device_type, cast_inputs=cast_inputs + ) + + @functools.wraps(fwd) + def decorate_fwd(*args, **kwargs): + args[0]._dtype = core.get_autocast_dtype(device_type) + if cast_inputs is None: + args[0]._fwd_used_autocast = core.is_autocast_enabled(device_type) + return fwd(*args, **kwargs) + else: + autocast_context = core.is_autocast_enabled(device_type) + args[0]._fwd_used_autocast = False + if autocast_context: + with autocast(device_type=device_type, enabled=False): + return fwd( + *_cast(args, device_type, cast_inputs), + **_cast(kwargs, device_type, cast_inputs), + ) + else: + return fwd(*args, **kwargs) + + return decorate_fwd + + +# Autograd ensures incoming gradients are the same type as forward outputs. Allowing a separate +# cast_inputs argument on custom_bwd is unnecessary and could cause errors if it doesn't match +# cast_inputs supplied to custom_fwd. +def custom_bwd(bwd=None, *, device_type: str): + """Create a helper decorator for backward methods of custom autograd functions. + + Autograd functions are subclasses of :class:`core.autograd.Function`. + Ensures that ``backward`` executes with the same autocast state as ``forward``. + See the :ref:`example page` for more detail. + + Args: + device_type(str): Device type to use. 'cuda', 'cpu', 'xpu' and so on. + The type is the same as the `type` attribute of a :class:`core.device`. + Thus, you may obtain the device type of a tensor using `Tensor.device.type`. + """ + + if bwd is None: + return custom_bwd + + @functools.wraps(bwd) + def decorate_bwd(*args, **kwargs): + with autocast( + device_type=device_type, + enabled=args[0]._fwd_used_autocast, + dtype=args[0]._dtype, + ): + return bwd(*args, **kwargs) + + return decorate_bwd diff --git a/mindnlp/core/amp/grad_scaler.py b/mindnlp/core/amp/grad_scaler.py new file mode 100644 index 000000000..88f79c6d1 --- /dev/null +++ b/mindnlp/core/amp/grad_scaler.py @@ -0,0 +1,685 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import inspect +import warnings +from collections import abc, defaultdict +from enum import Enum +from typing import Any, cast, Dict, Iterable, List, Optional, overload, Tuple, Union + +from mindnlp import core + + +__all__ = ["OptState", "GradScaler"] + + +class _MultiDeviceReplicator: + """Lazily serves copies of a tensor to requested devices. + + Copies are cached per-device. + """ + + def __init__(self, master_tensor: core.Tensor) -> None: + self.master = master_tensor + self._per_device_tensors: Dict[core.device, core.Tensor] = {} + + def get(self, device: core.device) -> core.Tensor: + retval = self._per_device_tensors.get(device, None) + if retval is None: + retval = self.master.to(device=device, non_blocking=True, copy=True) + self._per_device_tensors[device] = retval + return retval + + +# Defines default_factory for GradScaler's _per_optimizer_states defaultdict, +# as well as associated "enum" values. Prefers defining these at top level because +# - Lambdas can't be pickled, so we don't want to supply a lambda as the factory. +# - Defining READY, UNSCALED, STEPPED and _refresh_per_optimizer_state within GradScaler +# causes a circular reference, which we'd rather avoid. +class OptState(Enum): + READY = 0 + UNSCALED = 1 + STEPPED = 2 + + +def _refresh_per_optimizer_state() -> Dict[str, Any]: + return {"stage": OptState.READY, "found_inf_per_device": {}} + + +class GradScaler: + """An instance ``scaler`` of :class:`GradScaler`. + + Helps perform the steps of gradient scaling + conveniently. + + * ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor. + * ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``. + * ``scaler.update()`` updates ``scaler``'s scale factor. + + Example:: + + # Creates a GradScaler once at the beginning of training. + scaler = GradScaler() + + for epoch in epochs: + for input, target in data: + optimizer.zero_grad() + output = model(input) + loss = loss_fn(output, target) + + # Scales loss. Calls backward() on scaled loss to create scaled gradients. + scaler.scale(loss).backward() + + # scaler.step() first unscales gradients of the optimizer's params. + # If gradients don't contain infs/NaNs, optimizer.step() is then called, + # otherwise, optimizer.step() is skipped. + scaler.step(optimizer) + + # Updates the scale for next iteration. + scaler.update() + + See the :ref:`Automatic Mixed Precision examples` for usage + (along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty, + and multiple losses/optimizers. + + ``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow, + a large scale factor should be used. However, ``float16`` values can "overflow" (become inf or NaN) if + the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used + without incurring inf or NaN gradient values. + ``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every + ``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`). + + * If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params + themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``. + + * If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual. + If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by + ``growth_factor``. + + The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its + value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these + iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations). + + Args: + device (str, optional, default="cuda"): Device type to use. Possible values are: 'cuda' and 'cpu'. + The type is the same as the `type` attribute of a :class:`core.device`. + Thus, you may obtain the device type of a tensor using `Tensor.device.type`. + init_scale (float, optional, default=2.**16): Initial scale factor. + growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during + :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations. + backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during + :meth:`update` if inf/NaN gradients occur in an iteration. + growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients + that must occur for the scale to be multiplied by ``growth_factor``. + enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply + invokes the underlying ``optimizer.step()``, and other methods become no-ops. + Default: ``True`` + """ + + def __init__( + self, + device: str = "cuda", + init_scale: float = 2.0**16, + growth_factor: float = 2.0, + backoff_factor: float = 0.5, + growth_interval: int = 2000, + enabled: bool = True, + ) -> None: + self._device = device + self._enabled = enabled + if self._device == "cuda": + if enabled and core.cuda.amp.common.amp_definitely_not_available(): + warnings.warn( + "core.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling." + ) + self._enabled = False + + if self._enabled: + assert growth_factor > 1.0, "The growth factor must be > 1.0." + assert backoff_factor < 1.0, "The backoff factor must be < 1.0." + + self._init_scale = init_scale + # self._scale will be lazily initialized during the first call to scale() + self._scale: Optional[core.Tensor] = None + self._growth_factor = growth_factor + self._backoff_factor = backoff_factor + self._growth_interval = growth_interval + self._init_growth_tracker = 0 + # self._growth_tracker will be lazily initialized during the first call to scale() + self._growth_tracker: Optional[core.Tensor] = None + self._per_optimizer_states: Dict[int, Dict[str, Any]] = defaultdict( + _refresh_per_optimizer_state + ) + + def _check_scale_growth_tracker( + self, funcname: str + ) -> Tuple[core.Tensor, core.Tensor]: + fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration." + assert self._scale is not None, ( + f"Attempted {funcname} but _scale is None. " + fix + ) + assert self._growth_tracker is not None, ( + f"Attempted {funcname} but _growth_tracker is None. " + fix + ) + return (self._scale, self._growth_tracker) + + def _lazy_init_scale_growth_tracker(self, dev: core.device) -> None: + assert self._growth_tracker is None, "_growth_tracker initialized before _scale" + self._scale = core.full((), self._init_scale, dtype=core.float32, device=dev) + self._growth_tracker = core.full( + (), self._init_growth_tracker, dtype=core.int32, device=dev + ) + + @overload + def scale(self, outputs: core.Tensor) -> core.Tensor: + ... + + @overload + def scale(self, outputs: List[core.Tensor]) -> List[core.Tensor]: + ... + + @overload + def scale(self, outputs: Tuple[core.Tensor, ...]) -> Tuple[core.Tensor, ...]: + ... + + @overload + def scale(self, outputs: Iterable[core.Tensor]) -> Iterable[core.Tensor]: + ... + + def scale( + self, + outputs: Union[core.Tensor, Iterable[core.Tensor]], + ) -> Union[core.Tensor, Iterable[core.Tensor]]: + """ + Multiplies ('scales') a tensor or list of tensors by the scale factor. + + Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned + unmodified. + + Args: + outputs (Tensor or iterable of Tensors): Outputs to scale. + """ + if not self._enabled: + return outputs + + # Short-circuit for the common case. + if isinstance(outputs, core.Tensor): + if self._scale is None: + self._lazy_init_scale_growth_tracker(outputs.device) + assert self._scale is not None + return outputs * self._scale.to(device=outputs.device, non_blocking=True) + + # Invoke the more complex machinery only if we're treating multiple outputs. + stash: List[ + _MultiDeviceReplicator + ] = [] # holds a reference that can be overwritten by apply_scale + + def apply_scale(val: Union[core.Tensor, Iterable[core.Tensor]]): + if isinstance(val, core.Tensor): + if len(stash) == 0: + if self._scale is None: + self._lazy_init_scale_growth_tracker(val.device) + assert self._scale is not None + stash.append(_MultiDeviceReplicator(self._scale)) + return val * stash[0].get(val.device) + if isinstance(val, abc.Iterable): + iterable = map(apply_scale, val) + if isinstance(val, (list, tuple)): + return type(val)(iterable) + return iterable + raise ValueError("outputs must be a Tensor or an iterable of Tensors") + + return apply_scale(outputs) + + def _unscale_grads_( + self, + optimizer: core.optim.Optimizer, + inv_scale: core.Tensor, + found_inf: core.Tensor, + allow_fp16: bool, + ) -> Dict[core.device, core.Tensor]: + per_device_inv_scale = _MultiDeviceReplicator(inv_scale) + per_device_found_inf = _MultiDeviceReplicator(found_inf) + + # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype. + # There could be hundreds of grads, so we'd like to iterate through them just once. + # However, we don't know their devices or dtypes in advance. + + # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict + # Google says mypy struggles with defaultdicts type annotations. + per_device_and_dtype_grads: Dict[ + core.device, Dict[core.dtype, List[core.Tensor]] + ] = defaultdict(lambda: defaultdict(list)) + with core.no_grad(): + for group in optimizer.param_groups: + for param in group["params"]: + assert isinstance(param, core.Tensor) + if param.grad is None: + continue + if (not allow_fp16) and param.grad.dtype == core.float16: + raise ValueError("Attempting to unscale FP16 gradients.") + if param.grad.is_sparse: + # is_coalesced() == False means the sparse grad has values with duplicate indices. + # coalesce() deduplicates indices and adds all values that have the same index. + # For scaled fp16 values, there's a good chance coalescing will cause overflow, + # so we should check the coalesced _values(). + if param.grad.dtype is core.float16: + param.grad = param.grad.coalesce() + to_unscale = param.grad._values() + else: + to_unscale = param.grad + + # TODO: is there a way to split by device and dtype without appending in the inner loop? + per_device_and_dtype_grads[to_unscale.device][ + to_unscale.dtype + ].append(to_unscale) + + for device, per_dtype_grads in per_device_and_dtype_grads.items(): + for grads in per_dtype_grads.values(): + core._amp_foreach_non_finite_check_and_unscale_( + grads, + per_device_found_inf.get(device), + per_device_inv_scale.get(device), + ) + + return per_device_found_inf._per_device_tensors + + def unscale_(self, optimizer: core.optim.Optimizer) -> None: + """ + Divides ("unscales") the optimizer's gradient tensors by the scale factor. + + :meth:`unscale_` is optional, serving cases where you need to + :ref:`modify or inspect gradients` + between the backward pass(es) and :meth:`step`. + If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`. + + Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients:: + + ... + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + core.nn.utils.clip_grad_norm_(model.parameters(), max_norm) + scaler.step(optimizer) + scaler.update() + + Args: + optimizer (core.optim.Optimizer): Optimizer that owns the gradients to be unscaled. + + .. note:: + :meth:`unscale_` does not incur a CPU-GPU sync. + + .. warning:: + :meth:`unscale_` should only be called once per optimizer per :meth:`step` call, + and only after all gradients for that optimizer's assigned parameters have been accumulated. + Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError. + + .. warning:: + :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute. + """ + if not self._enabled: + return + + self._check_scale_growth_tracker("unscale_") + + optimizer_state = self._per_optimizer_states[id(optimizer)] + + if optimizer_state["stage"] is OptState.UNSCALED: + raise RuntimeError( + "unscale_() has already been called on this optimizer since the last update()." + ) + elif optimizer_state["stage"] is OptState.STEPPED: + raise RuntimeError("unscale_() is being called after step().") + + # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. + assert self._scale is not None + inv_scale = self._scale.double().reciprocal().float() + found_inf = core.full((), 0.0, dtype=core.float32, device=self._scale.device) + + optimizer_state["found_inf_per_device"] = self._unscale_grads_( + optimizer, inv_scale, found_inf, False + ) + optimizer_state["stage"] = OptState.UNSCALED + + def _maybe_opt_step( + self, + optimizer: core.optim.Optimizer, + optimizer_state: Dict[str, Any], + *args: Any, + **kwargs: Any, + ) -> Optional[float]: + retval: Optional[float] = None + if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()): + retval = optimizer.step(*args, **kwargs) + return retval + + def step( + self, optimizer: core.optim.Optimizer, *args: Any, **kwargs: Any + ) -> Optional[float]: + """Invoke ``unscale_(optimizer)`` followed by parameter update, if gradients are not infs/NaN. + + :meth:`step` carries out the following two operations: + + 1. Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer`` + earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs. + 2. If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled + gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params. + + ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``. + + Returns the return value of ``optimizer.step(*args, **kwargs)``. + + Args: + optimizer (core.optim.Optimizer): Optimizer that applies the gradients. + args: Any arguments. + kwargs: Any keyword arguments. + + .. warning:: + Closure use is not currently supported. + """ + if not self._enabled: + return optimizer.step(*args, **kwargs) + + if "closure" in kwargs: + raise RuntimeError( + "Closure use is not currently supported if GradScaler is enabled." + ) + + self._check_scale_growth_tracker("step") + + optimizer_state = self._per_optimizer_states[id(optimizer)] + + if optimizer_state["stage"] is OptState.STEPPED: + raise RuntimeError( + "step() has already been called since the last update()." + ) + + retval: Optional[float] = None + + if getattr(optimizer, "_step_supports_amp_scaling", False): + # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly. + # The contract with custom optimizers is that their step() should accept an additional, + # optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information: + # it can query its own state, invoke unscale_ on itself, etc + # The contract above is being deprecated to avoid introducing `grad_scaler: GradScaler` argument + # to `Optimizer.step`. The new behavior is going to add two Tensor attributes of `grad_scale` + # and `found_inf` to the passed optimizer so that the optimizer can utilize those + # to skip the parameter updates or unscale gradients before updating parameters in + # the fused kernel, e.g. `FusedAdamMathFunctor`. + # In this behavior, `GradScaler._check_inf_per_device` is called if `OptState.READY`, + # while the method is expected to be called by users side, i.e. their optimizers. + kwargs_ = kwargs + has_grad_scaler_kwarg = ( + "grad_scaler" in inspect.signature(optimizer.step).parameters + ) + if has_grad_scaler_kwarg: + warnings.warn( + "GradScaler is going to stop passing itself as a keyword argument to the passed " + "optimizer. In the near future GradScaler registers `grad_scale: Tensor` and " + "`found_inf: Tensor` to the passed optimizer and let the optimizer use them directly.", + FutureWarning, + ) + kwargs_.update({"grad_scaler": self}) + else: + if optimizer_state["stage"] is OptState.READY: + self._check_inf_per_device(optimizer) + scaler = self._get_scale_async() + assert scaler is not None + found_inf = cast( + core.Tensor, + sum( + [ # noqa: C419 + t.to(scaler.device, non_blocking=True) + for t in optimizer_state["found_inf_per_device"].values() + ] + ), + ) + # Take the product of the scales, if the user has already set `optimizer.grad_scale`. + optimizer.grad_scale = ( # type: ignore[attr-defined] + getattr(optimizer, "grad_scale", None) + if optimizer_state["stage"] == OptState.UNSCALED + else scaler * getattr(optimizer, "grad_scale", 1) + ) + optimizer.found_inf = found_inf # type: ignore[attr-defined] + retval = optimizer.step(*args, **kwargs_) + optimizer_state["stage"] = OptState.STEPPED + if not has_grad_scaler_kwarg: + del optimizer.grad_scale # type: ignore[attr-defined] + del optimizer.found_inf # type: ignore[attr-defined] + return retval + + if optimizer_state["stage"] is OptState.READY: + self.unscale_(optimizer) + + assert ( + len(optimizer_state["found_inf_per_device"]) > 0 + ), "No inf checks were recorded for this optimizer." + + retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs) + + optimizer_state["stage"] = OptState.STEPPED + + return retval + + def update(self, new_scale: Optional[Union[float, core.Tensor]] = None) -> None: + """Update the scale factor. + + If any optimizer steps were skipped the scale is multiplied by ``backoff_factor`` + to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively, + the scale is multiplied by ``growth_factor`` to increase it. + + Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not + used directly, it's used to fill GradScaler's internal scale tensor. So if + ``new_scale`` was a tensor, later in-place changes to that tensor will not further + affect the scale GradScaler uses internally.) + + Args: + new_scale (float or :class:`core.Tensor`, optional, default=None): New scale factor. + + .. warning:: + :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has + been invoked for all optimizers used this iteration. + + .. warning:: + For performance reasons, we do not check the scale factor value to avoid synchronizations, + so the scale factor is not guaranteed to be above 1. If the scale falls below 1 and/or + you are seeing NaNs in your gradients or loss, something is likely wrong. For example, + bf16-pretrained models are often incompatible with AMP/fp16 due to differing dynamic ranges. + """ + if not self._enabled: + return + + _scale, _growth_tracker = self._check_scale_growth_tracker("update") + + if new_scale is not None: + assert self._scale is not None + # Accept a new user-defined scale. + if isinstance(new_scale, float): + self._scale.fill_(new_scale) + else: + reason = "new_scale should be a float or a 1-element core.cuda.FloatTensor or \ + core.FloatTensor with requires_grad=False." + assert new_scale.device.type == self._device, reason + assert new_scale.numel() == 1, reason + assert new_scale.requires_grad is False, reason + self._scale.copy_(new_scale) + else: + # Consume shared inf/nan data collected from optimizers to update the scale. + # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. + found_infs = [ + found_inf.to(device=_scale.device, non_blocking=True) + for state in self._per_optimizer_states.values() + for found_inf in state["found_inf_per_device"].values() + ] + + assert len(found_infs) > 0, "No inf checks were recorded prior to update." + + found_inf_combined = found_infs[0] + if len(found_infs) > 1: + for i in range(1, len(found_infs)): + found_inf_combined += found_infs[i] + + core._amp_update_scale_( + _scale, + _growth_tracker, + found_inf_combined, + self._growth_factor, + self._backoff_factor, + self._growth_interval, + ) + + # To prepare for next iteration, clear the data collected from optimizers this iteration. + self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) + + def _get_scale_async(self) -> Optional[core.Tensor]: + return self._scale + + def get_scale(self) -> float: + """Return a Python float containing the current scale, or 1.0 if scaling is disabled. + + .. warning:: + :meth:`get_scale` incurs a CPU-GPU sync. + """ + if self._enabled: + return ( + self._init_scale + if (scale := self._get_scale_async()) is None + else cast(float, scale.item()) + ) + return 1.0 + + def get_growth_factor(self) -> float: + r"""Return a Python float containing the scale growth factor.""" + return self._growth_factor + + def set_growth_factor(self, new_factor: float) -> None: + r"""Set a new scale growth factor. + + Args: + new_scale (float): Value to use as the new scale growth factor. + """ + self._growth_factor = new_factor + + def get_backoff_factor(self) -> float: + r"""Return a Python float containing the scale backoff factor.""" + return self._backoff_factor + + def set_backoff_factor(self, new_factor: float) -> None: + r"""Set a new scale backoff factor. + + Args: + new_scale (float): Value to use as the new scale backoff factor. + """ + self._backoff_factor = new_factor + + def get_growth_interval(self) -> int: + r"""Return a Python int containing the growth interval.""" + return self._growth_interval + + def set_growth_interval(self, new_interval: int) -> None: + r"""Set a new growth interval. + + Args: + new_interval (int): Value to use as the new growth interval. + """ + self._growth_interval = new_interval + + def _get_growth_tracker(self) -> int: + if self._enabled: + return ( + self._init_growth_tracker + if self._growth_tracker is None + else cast(int, self._growth_tracker.item()) + ) + return 0 + + def is_enabled(self) -> bool: + r"""Return a bool indicating whether this instance is enabled.""" + return self._enabled + + def state_dict(self) -> Dict[str, Any]: + r"""Return the state of the scaler as a :class:`dict`. + + It contains five entries: + + * ``"scale"`` - a Python float containing the current scale + * ``"growth_factor"`` - a Python float containing the current growth factor + * ``"backoff_factor"`` - a Python float containing the current backoff factor + * ``"growth_interval"`` - a Python int containing the current growth interval + * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps. + + If this instance is not enabled, returns an empty dict. + + .. note:: + If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict` + should be called after :meth:`update`. + """ + if self._enabled: + return { + "scale": self.get_scale(), + "growth_factor": self._growth_factor, + "backoff_factor": self._backoff_factor, + "growth_interval": self._growth_interval, + "_growth_tracker": self._get_growth_tracker(), + } + return {} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + r"""Load the scaler state. + + If this instance is disabled, :meth:`load_state_dict` is a no-op. + + Args: + state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`. + """ + if not self._enabled: + return + + if len(state_dict) == 0: + raise RuntimeError( + "The source state dict is empty, possibly because it was saved " + "from a disabled instance of GradScaler." + ) + + self._init_scale = cast(float, state_dict["scale"]) + if self._scale is not None: + self._scale.fill_(state_dict["scale"]) + self._growth_factor = cast(float, state_dict["growth_factor"]) + self._backoff_factor = cast(float, state_dict["backoff_factor"]) + self._growth_interval = cast(int, state_dict["growth_interval"]) + self._init_growth_tracker = cast(int, state_dict["_growth_tracker"]) + if self._growth_tracker is not None: + self._growth_tracker.fill_(state_dict["_growth_tracker"]) + + def __getstate__(self) -> Dict[str, Any]: + state = self.__dict__.copy() + if self._enabled: + assert len(self._per_optimizer_states) == 0, ( + "A GradScaler instance may only be pickled at the beginning " + "of an iteration, or at the end after scaler.update()." + ) + # Pickling _scale and _growth_tracker Tensors directly triggers + # "warnings.warn("pickle support for Storage will be removed in 1.5..." + # so instead, we set the unpickled instance up to reinitialize them lazily. + state["_init_scale"] = self.get_scale() + state["_init_growth_tracker"] = self._get_growth_tracker() + state["_scale"] = None + state["_growth_tracker"] = None + return state + + def __setstate__(self, state: Dict[str, Any]) -> None: + self.__dict__.update(state) + + def _check_inf_per_device(self, optimizer: core.optim.Optimizer) -> Dict[str, Any]: + _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device") + + dummy_inv_scale = core.full((), 1.0, dtype=core.float32, device=_scale.device) + found_inf = core.full((), 0.0, dtype=core.float32, device=_scale.device) + + self._per_optimizer_states[id(optimizer)][ + "found_inf_per_device" + ] = self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True) + + return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] + + def _found_inf_per_device(self, optimizer: core.optim.Optimizer) -> Dict[str, Any]: + return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] \ No newline at end of file diff --git a/mindnlp/core/ao/__init__.py b/mindnlp/core/ao/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/mindnlp/core/ao/quantization/__init__.py b/mindnlp/core/ao/quantization/__init__.py new file mode 100644 index 000000000..adbaf5c02 --- /dev/null +++ b/mindnlp/core/ao/quantization/__init__.py @@ -0,0 +1 @@ +from .stubs import * \ No newline at end of file diff --git a/mindnlp/core/ao/quantization/stubs.py b/mindnlp/core/ao/quantization/stubs.py new file mode 100644 index 000000000..42cf63385 --- /dev/null +++ b/mindnlp/core/ao/quantization/stubs.py @@ -0,0 +1,36 @@ +from mindnlp.core import nn + +class QuantStub(nn.Module): + r"""Quantize stub module, before calibration, this is same as an observer, + it will be swapped as `nnq.Quantize` in `convert`. + + Args: + qconfig: quantization configuration for the tensor, + if qconfig is not provided, we will get qconfig from parent modules + """ + + def __init__(self, qconfig=None): + super().__init__() + if qconfig: + self.qconfig = qconfig + + def forward(self, x): + return x + + +class DeQuantStub(nn.Module): + r"""Dequantize stub module, before calibration, this is same as identity, + this will be swapped as `nnq.DeQuantize` in `convert`. + + Args: + qconfig: quantization configuration for the tensor, + if qconfig is not provided, we will get qconfig from parent modules + """ + + def __init__(self, qconfig=None): + super().__init__() + if qconfig: + self.qconfig = qconfig + + def forward(self, x): + return x \ No newline at end of file diff --git a/mindnlp/core/autograd/__init__.py b/mindnlp/core/autograd/__init__.py new file mode 100644 index 000000000..f579cc99b --- /dev/null +++ b/mindnlp/core/autograd/__init__.py @@ -0,0 +1,4 @@ +"""autograd""" +from .node import Node +from .function import Function +from .grad_mode import no_grad, enable_grad, inference_mode diff --git a/mindnlp/core/autograd/function.py b/mindnlp/core/autograd/function.py new file mode 100644 index 000000000..e7e35a925 --- /dev/null +++ b/mindnlp/core/autograd/function.py @@ -0,0 +1,57 @@ +"""functional autograd""" +from collections.abc import Generator +from dataclasses import dataclass +from typing import Tuple, Any, Optional, Type, Sequence +import functools + +@dataclass(unsafe_hash=True) +class Context: + """ + Context class is used by `Function` to store information during the forward pass. + """ + + no_grad: bool = False + saved_values: Tuple[Any, ...] = () + + def save_for_backward(self, *values: Any) -> None: + "Store the given `values` if they need to be used during backpropagation." + if self.no_grad: + return + self.saved_values = values + + @property + def saved_tensors(self) -> Tuple[Any, ...]: + return self.saved_values + +# Constructors +class Function: + @classmethod + def _backward(cls, ctx: Context, *grad_out): + return cls.backward(ctx, *grad_out) # type: ignore + + @classmethod + def _forward(cls, ctx: Context, *inps, **kwargs): + return cls.forward(ctx, *inps, **kwargs) # type: ignore + + @classmethod + def apply(cls, *vals, **kwargs): + # Create the context. + ctx = Context(not requires_grad) + # Call forward with the variables. + results = cls._forward(ctx, *vals, **kwargs) + requires_grad = any([x.requires_grad for x in vals]) + + if requires_grad: # cut useless nodes + generation = max([x.generation for x in vals]) + ctx.outputs = [weakref.ref(output) for output in outputs] + back = History(cls, ctx, generation) + for output in outputs: + output.set_creator(back) + + return outputs if len(outputs) > 1 else outputs[0] + + def forward(self, xs): + raise NotImplementedError() + + def backward(self, gys): + raise NotImplementedError() diff --git a/mindnlp/core/autograd/functions/custom.py b/mindnlp/core/autograd/functions/custom.py new file mode 100644 index 000000000..5f53b3bd6 --- /dev/null +++ b/mindnlp/core/autograd/functions/custom.py @@ -0,0 +1,59 @@ +from mindspore._c_expression import TensorPy as MSTensor + +from mindnlp import core +from core._prims.ascend import cast_npu +from core._prims.cpu import cast_cpu +from ..node import Node + + +class AccumulateGrad(Node): + def __init__(self): + super().__init__('AccumulateGrad') + self._post_hook = None + + def construct(self, input): + return input + + def bprop(self, input, output, grad): + if input.grad is None: + input.grad = grad + else: + input.grad += grad + + if self._post_hook is not None: + self._post_hook(input) + return grad + + def register_post_hook(self, hook): + self._post_hook = hook + + +class Cast(Node): + def __init__(self): + super().__init__('Cast') + self.used_bprop_inputs = [] + + def construct(self, input, dtype, device): + self.device = input.device + self.dtype = input.dtype + if device.type == 'cpu': + out = cast_cpu(input, dtype).get_value() + else: + out = cast_npu(input, dtype).get_value() + + output = core.Tensor.__new__(core.Tensor) + MSTensor.__init__(output, out) + output.device = device + return output + + def bprop(self, *args): + grad = args[-1] + if self.device.type == 'cpu': + out = cast_cpu(grad, self.dtype).get_value() + else: + out = cast_npu(grad, self.dtype).get_value() + + output = core.Tensor.__new__(core.Tensor) + MSTensor.__init__(output, out) + output.device = self.device + return output, None, None diff --git a/mindnlp/core/autograd/grad_mode.py b/mindnlp/core/autograd/grad_mode.py new file mode 100644 index 000000000..41884969e --- /dev/null +++ b/mindnlp/core/autograd/grad_mode.py @@ -0,0 +1,77 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""core module""" +from mindspore.common.api import _pynative_executor +from ..utils._contextlib import _NoParamDecoratorContextManager, _DecoratorContextManager + +class no_grad(_NoParamDecoratorContextManager): + """ + Context Manager to disable gradient calculation. When enter this context, we will disable calculate + gradient. When exit this context, we will resume its prev state. + Currently, it can only use in Pynative mode. It also can be used as decorator. + """ + def __init__(self) -> None: + super().__init__() + self.prev_state = False + + def __enter__(self): + self.prev_state = _pynative_executor.enable_grad() + _pynative_executor.set_enable_grad(False) + + def __exit__(self, exc_type, exc_val, exc_tb): + _pynative_executor.set_enable_grad(self.prev_state) + + +class enable_grad(_NoParamDecoratorContextManager): + """ + Context Manager to disable gradient calculation. When enter this context, we will disable calculate + gradient. When exit this context, we will resume its prev state. + Currently, it can only use in Pynative mode. It also can be used as decorator. + """ + + def __enter__(self): + self.prev_state = _pynative_executor.enable_grad() + _pynative_executor.set_enable_grad(True) + + def __exit__(self, exc_type, exc_val, exc_tb): + _pynative_executor.set_enable_grad(self.prev_state) + +class inference_mode(_DecoratorContextManager): + """ + Context Manager to enable or disable inference mode. + Currently, when enter this context, it is equivalent to enable_grad or no_grad. + When exit this context, we will resume its prev state. + Currently, it can only use in Pynative mode. It also can be used as decorator. + """ + + def __init__(self, mode: bool = True) -> None: + super().__init__() + self.mode = mode + self.prev_state = False + + def __new__(cls, mode: bool = True): + if isinstance(mode, bool): + return super().__new__(cls) + return cls()(mode) + + def __enter__(self) -> None: + self.prev_state = _pynative_executor.enable_grad() + _pynative_executor.set_enable_grad(not self.mode) + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + _pynative_executor.set_enable_grad(self.prev_state) + + def clone(self) -> "inference_mode": + return self.__class__(self.mode) \ No newline at end of file diff --git a/mindnlp/core/autograd/node.py b/mindnlp/core/autograd/node.py new file mode 100644 index 000000000..e6e4c7bd9 --- /dev/null +++ b/mindnlp/core/autograd/node.py @@ -0,0 +1,39 @@ +class Node: + def __init__(self, grad_fn, next_functions, name=""): + """ + A class representing a gradient function node in the computational graph. + Gradient function nodes encapsulate the gradient computation and propagation + for a specific operation in the graph. + + Args: + grad_fn: The gradient function. + next_functions: A tuple of next gradient function nodes. + name: The name of the gradient function node (optional). + """ + self.grad_fn = grad_fn + self.next_functions = next_functions + self.name = name + + def __call__(self, grad): + """ + Call the gradient function with the given gradient. + + Args: + grad: The gradient to be passed to the gradient function. + + Returns: + The result of the gradient function. + """ + if self.grad_fn: + return self.grad_fn(grad) + else: + raise RuntimeError("Trying to backward through the graph a second time.") + + def __repr__(self): + """ + Return a string representation of the gradient function node. + + Returns: + A string representation of the gradient function node. + """ + return f"Node={self.name}" \ No newline at end of file diff --git a/mindnlp/core/autograd/tape.py b/mindnlp/core/autograd/tape.py new file mode 100644 index 000000000..dd23d22aa --- /dev/null +++ b/mindnlp/core/autograd/tape.py @@ -0,0 +1,210 @@ +import logging +from mindspore.common.api import _pynative_executor +from mindspore.ops import GradOperation + +from mindnlp import core + +grad_ = GradOperation(False, True, True) + +def tape_func(): pass + +class GradientTape(object): + """Record operations for automatic differentiation. + + Operations are recorded if they are executed within this context manager and + at least one of their inputs is being "watched". + + Trainable variables (created by `tf.Variable` or `tf.compat.v1.get_variable`, + where `trainable=True` is default in both cases) are automatically watched. + Tensors can be manually watched by invoking the `watch` method on this context + manager. + + Note that only tensors with real or complex dtypes are differentiable. + """ + + def __init__(self, persistent=False, watch_accessed_variables=True): + """Creates a new GradientTape. + + Args: + persistent: Boolean controlling whether a persistent gradient tape + is created. False by default, which means at most one call can + be made to the gradient() method on this object. + watch_accessed_variables: Boolean controlling whether the tape will + automatically `watch` any (trainable) variables accessed while the tape + is active. Defaults to True meaning gradients can be requested from any + result computed in the tape derived from reading a trainable `Variable`. + If False users must explicitly `watch` any `Variable`s they want to + request gradients from. + """ + self._tape = None + self._persistent = persistent + self._watch_accessed_variables = watch_accessed_variables + self._watched_variables = () + self._recording = False + + def __enter__(self): + """Enters a context inside which operations are recorded on this tape.""" + self._push_tape() + return self + + def __exit__(self, typ, value, traceback): + """Exits the recording context, no further operations are traced.""" + if self._recording: + self._pop_tape() + + def _push_tape(self): + """Pushes a new tape onto the tape stack.""" + if self._recording: + raise ValueError( + "Tape is still recording, This can happen if you try to " + "re-enter an already-active tape." + ) + _pynative_executor.set_grad_flag(True) + _pynative_executor.new_graph(tape_func) + self._recording = True + + def _pop_tape(self): + if not self._recording: + raise ValueError("Tape is not recording.") + self._recording = False + + def _ensure_recording(self): + """Ensures that this tape is recording.""" + if not self._recording: + try: + self._push_tape() + yield + finally: + self._pop_tape() + else: + yield + + def stop_recording(self): + """Temporarily stops recording operations on this tape. + + Operations executed while this context manager is active will not be + recorded on the tape. This is useful for reducing the memory used by tracing + all computations. + + For example: + + >>> x = tf.constant(4.0) + >>> with tf.GradientTape() as tape: + ... with tape.stop_recording(): + ... y = x ** 2 + >>> dy_dx = tape.gradient(y, x) + >>> print(dy_dx) + None + + Yields: + None + Raises: + RuntimeError: if the tape is not currently recording. + """ + if self._tape is None: + raise RuntimeError( + "Trying to stop recording a tape which is not recording." + ) + self._pop_tape() + try: + yield + finally: + self._push_tape() + + def reset(self): + """Clears all information stored in this tape. + + Equivalent to exiting and reentering the tape context manager with a new + tape. For example, the two following code blocks are equivalent: + + ``` + with tf.GradientTape() as t: + loss = loss_fn() + with tf.GradientTape() as t: + loss += other_loss_fn() + t.gradient(loss, ...) # Only differentiates other_loss_fn, not loss_fn + + + # The following is equivalent to the above + with tf.GradientTape() as t: + loss = loss_fn() + t.reset() + loss += other_loss_fn() + t.gradient(loss, ...) # Only differentiates other_loss_fn, not loss_fn + ``` + + This is useful if you don't want to exit the context manager for the tape, + or can't because the desired reset point is inside a control flow construct: + + ``` + with tf.GradientTape() as t: + loss = ... + if loss > k: + t.reset() + ``` + """ + self._pop_tape() + self._tape = None + self._push_tape() + + def watched_variables(self): + """Returns variables watched by this tape in order of construction.""" + if self._tape is not None: + self._watched_variables = self._tape.watched_variables() + return self._watched_variables + + def gradient( + self, + target, + sources, + output_gradients=None, + ): + """Computes the gradient using operations recorded in context of this tape. + + Note: Unless you set `persistent=True` a GradientTape can only be used to + compute one set of gradients (or jacobians). + + In addition to Tensors, gradient also supports RaggedTensors. For example, + + >>> x = tf.ragged.constant([[1.0, 2.0], [3.0]]) + >>> with tf.GradientTape() as g: + ... g.watch(x) + ... y = x * x + >>> g.gradient(y, x) + + + Args: + target: a list or nested structure of Tensors or Variables or + CompositeTensors to be differentiated. + sources: a list or nested structure of Tensors or Variables or + CompositeTensors. `target` will be differentiated against elements in + `sources`. + output_gradients: a list of gradients, one for each differentiable + element of target. Defaults to None. + unconnected_gradients: a value which can either hold 'none' or 'zero' and + alters the value which will be returned if the target and sources are + unconnected. The possible values and effects are detailed in + 'UnconnectedGradients' and it defaults to 'none'. + + Returns: + a list or nested structure of Tensors (or IndexedSlices, or None, or + CompositeTensor), one for each element in `sources`. Returned structure + is the same as the structure of `sources`. + + Raises: + RuntimeError: If called on a used, non-persistent tape. + RuntimeError: If called inside the context of the tape. + TypeError: If the target is a None object. + ValueError: If the target is a variable or if unconnected gradients is + called with an unknown value. + """ + if target.shape == (): + gradient = core.tensor(1, dtype=target.dtype, device=target.device) + else: + raise RuntimeError("grad must specified for non-0-tensor") + + _pynative_executor.end_graph(tape_func, target.data) + weights = list(sources) + _pynative_executor.check_run(grad_, tape_func, weights, None, gradient) + grads = _pynative_executor.grad(tape_func, grad_, weights, None, gradient) + return grads diff --git a/mindnlp/core/autograd/variable.py b/mindnlp/core/autograd/variable.py new file mode 100644 index 000000000..cd534fb55 --- /dev/null +++ b/mindnlp/core/autograd/variable.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +from mindnlp.core import Tensor +import logging + +class Variable(Tensor): + def __new__(cls, data, requires_grad=None, volatile=None): + logging.warning("The Variable API has been deprecated, use Tensor instead.") + obj = Tensor.__new__(cls) + return obj + + def __init__(self, data, requires_grad=None, volatile=None): + if volatile: + logging.warning("UserWarning:volatile was removed (Variable.volatile is always False), " + "please use with core.no_grad() instead.") + Tensor.__init__(self, data, requires_grad=requires_grad) diff --git a/mindnlp/core/compiler/__init__.py b/mindnlp/core/compiler/__init__.py new file mode 100644 index 000000000..2a62b0a92 --- /dev/null +++ b/mindnlp/core/compiler/__init__.py @@ -0,0 +1,13 @@ +import functools + +def disable(fn=None, recursive=True, *, reason=None): + def wrap_func(func): + @functools.wraps(func) + def staging_specialize(*args, **kwargs): + return func(*args, **kwargs) + + return staging_specialize + + if fn is not None: + return wrap_func(fn) + return wrap_func \ No newline at end of file diff --git a/mindnlp/core/configs.py b/mindnlp/core/configs.py new file mode 100644 index 000000000..c046d3ec7 --- /dev/null +++ b/mindnlp/core/configs.py @@ -0,0 +1,21 @@ +from packaging import version +import mindspore +from mindspore._c_expression import MSContext # pylint: disable=no-name-in-module, import-error + +SOC = MSContext.get_instance().get_ascend_soc_version() +DEVICE_TARGET = mindspore.get_context('device_target') +SUPPORT_BF16 = SOC in ["ascend910b", "ascend910_93"] +ON_ORANGE_PI = '310b' in SOC +USE_PYBOOST = DEVICE_TARGET == 'Ascend' +DEFAULT_DTYPE = mindspore.float32 +MS27 = '.'.join(mindspore.__version__.split('.')[:2]) >= '2.7' + + +def set_pyboost(mode: bool): + """set global pyboost""" + global USE_PYBOOST + USE_PYBOOST = mode + +def use_pyboost(): + """set global pyboost""" + return USE_PYBOOST \ No newline at end of file diff --git a/mindnlp/core/cuda/__init__.py b/mindnlp/core/cuda/__init__.py new file mode 100644 index 000000000..98e949b8c --- /dev/null +++ b/mindnlp/core/cuda/__init__.py @@ -0,0 +1,43 @@ +from typing import Any + +import mindspore +from mindspore import get_rng_state, set_rng_state, manual_seed +from mindspore.hal import * + +from mindnlp import core + +FloatTensor = core.FloatTensor +HalfTensor = core.FloatTensor +BFloat16Tensor = core.BFloat16Tensor + +def manual_seed_all(seed: int): + manual_seed(seed) + +def current_device(): + return core.device('cuda', 0) + +def is_available(): + return mindspore.get_context('device_target') == 'GPU' + +def set_device(device): + pass + +def _lazy_call(callable, **kwargs): + callable() + +class device: + r"""Context-manager that changes the selected device. + + Args: + device (core.device or int): device index to select. It's a no-op if + this argument is a negative integer or ``None``. + """ + + def __init__(self, device: Any): + self.prev_idx = -1 + + def __enter__(self): + self.prev_idx = -1 + + def __exit__(self, type: Any, value: Any, traceback: Any): + return False diff --git a/mindnlp/core/cuda/amp/__init__.py b/mindnlp/core/cuda/amp/__init__.py new file mode 100644 index 000000000..34b9bfbaa --- /dev/null +++ b/mindnlp/core/cuda/amp/__init__.py @@ -0,0 +1,9 @@ +from .autocast_mode import autocast, custom_bwd, custom_fwd +from .grad_scaler import GradScaler + +__all__ = [ + "autocast", + "custom_bwd", + "custom_fwd", + "GradScaler" +] \ No newline at end of file diff --git a/mindnlp/core/cuda/amp/autocast_mode.py b/mindnlp/core/cuda/amp/autocast_mode.py new file mode 100644 index 000000000..cd22ce019 --- /dev/null +++ b/mindnlp/core/cuda/amp/autocast_mode.py @@ -0,0 +1,90 @@ +# mypy: allow-untyped-defs +import functools +from typing import Any +from typing_extensions import deprecated + +from mindnlp import core + + +__all__ = ["autocast", "custom_fwd", "custom_bwd"] + + +class autocast(core.amp.autocast_mode.autocast): + r"""See :class:`core.autocast`. + + ``core.cuda.amp.autocast(args...)`` is deprecated. Please use ``core.amp.autocast("cuda", args...)`` instead. + """ + + @deprecated( + "`core.cuda.amp.autocast(args...)` is deprecated. " + "Please use `core.amp.autocast('cuda', args...)` instead.", + category=FutureWarning, + ) + def __init__( + self, + enabled: bool = True, + dtype: core.dtype = core.float16, + cache_enabled: bool = True, + ): + if core._jit_internal.is_scripting(): + self._enabled = enabled + self.device = "cuda" + self.fast_dtype = dtype + return + super().__init__( + "cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled + ) + + def __enter__(self): + if core._jit_internal.is_scripting(): + return self + return super().__enter__() + + # TODO: discuss a unified TorchScript-friendly API for autocast + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): # type: ignore[override] + if core._jit_internal.is_scripting(): + return + return super().__exit__(exc_type, exc_val, exc_tb) + + def __call__(self, func): + if core._jit_internal.is_scripting(): + return func + return super().__call__(func) + + +# Preserved only for BC reasons +@deprecated( + "`core.cuda.amp.autocast_mode._cast(value, dtype)` is deprecated. " + "Please use `core.amp.autocast_mode._cast(value, 'cuda', dtype)` instead.", + category=FutureWarning, +) +def _cast(value, dtype): + return core.amp.autocast_mode._cast(value, "cuda", dtype) + + +@deprecated( + "`core.cuda.amp.custom_fwd(args...)` is deprecated. " + "Please use `core.amp.custom_fwd(args..., device_type='cuda')` instead.", + category=FutureWarning, +) +def custom_fwd(fwd=None, *, cast_inputs=None): + """ + ``core.cuda.amp.custom_fwd(args...)`` is deprecated. Please use + ``core.amp.custom_fwd(args..., device_type='cuda')`` instead. + """ + return functools.partial(core.amp.custom_fwd, device_type="cuda")( + fwd=fwd, cast_inputs=cast_inputs + ) + + +@deprecated( + "`core.cuda.amp.custom_bwd(args...)` is deprecated. " + "Please use `core.amp.custom_bwd(args..., device_type='cuda')` instead.", + category=FutureWarning, +) +def custom_bwd(bwd): + """ + ``core.cuda.amp.custom_bwd(args...)`` is deprecated. Please use + ``core.amp.custom_bwd(args..., device_type='cuda')`` instead. + """ + return functools.partial(core.amp.custom_bwd, device_type="cuda")(bwd) diff --git a/mindnlp/core/cuda/amp/grad_scaler.py b/mindnlp/core/cuda/amp/grad_scaler.py new file mode 100644 index 000000000..3fd4e4916 --- /dev/null +++ b/mindnlp/core/cuda/amp/grad_scaler.py @@ -0,0 +1,38 @@ +from typing_extensions import deprecated + +from mindnlp import core + +# We need to keep this unused import for BC reasons +from ...amp.grad_scaler import OptState # noqa: F401 + + +__all__ = ["GradScaler"] + + +class GradScaler(core.amp.GradScaler): + r""" + See :class:`torch.amp.GradScaler`. + ``torch.cuda.amp.GradScaler(args...)`` is deprecated. Please use ``torch.amp.GradScaler("cuda", args...)`` instead. + """ + + @deprecated( + "`torch.cuda.amp.GradScaler(args...)` is deprecated. " + "Please use `torch.amp.GradScaler('cuda', args...)` instead.", + category=FutureWarning, + ) + def __init__( + self, + init_scale: float = 2.0**16, + growth_factor: float = 2.0, + backoff_factor: float = 0.5, + growth_interval: int = 2000, + enabled: bool = True, + ) -> None: + super().__init__( + "cuda", + init_scale=init_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + enabled=enabled, + ) \ No newline at end of file diff --git a/mindnlp/core/dispatcher.py b/mindnlp/core/dispatcher.py new file mode 100644 index 000000000..165dd5617 --- /dev/null +++ b/mindnlp/core/dispatcher.py @@ -0,0 +1,48 @@ +from mindnlp import core +from mindnlp.core.types import device as device_ +from mindnlp.core._prims import ascend, cpu + +device_map = { + 'cpu': 'CPU', + 'npu': 'Ascend', + 'cuda': 'GPU' +} + +class SingletonMeta(type): + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + instance = super().__call__(*args, **kwargs) + cls._instances[cls] = instance + return cls._instances[cls] + +class Dispatcher(metaclass=SingletonMeta): + def __init__(self): + self._registry = { + 'cpu': {}, + 'npu': {}, + 'gpu': {} + } + + def register(self, func_name, device, func): + self._registry[device][func_name] = func + + def dispatch(self, func_name, *args, **kwargs): + device = kwargs.pop('device', None) + if isinstance(device, str): + device = device_(device) + + if device is None: + device = args[0].device + + func = self._registry[device.type].get(func_name, None) + if func is None: + raise RuntimeError(f"No implementation for function: {func_name} on {device.type}.") + return func(*args), device + +dispatcher = Dispatcher() +for func_name in ascend.__all__: + dispatcher.register(func_name.replace('_npu', ''), 'npu', getattr(ascend, func_name)) +for func_name in cpu.__all__: + dispatcher.register(func_name.replace('_cpu', ''), 'cpu', getattr(cpu, func_name)) diff --git a/mindnlp/core/distributed/__init__.py b/mindnlp/core/distributed/__init__.py new file mode 100644 index 000000000..758d0feec --- /dev/null +++ b/mindnlp/core/distributed/__init__.py @@ -0,0 +1,98 @@ +# mypy: allow-untyped-defs +import logging +import pdb +import sys +import traceback +import typing + +from mindnlp import core + + +log = logging.getLogger(__name__) + + +def is_available() -> bool: + """ + Return ``True`` if the distributed package is available. + + Otherwise, + ``core.distributed`` does not expose any other APIs. Currently, + ``core.distributed`` is available on Linux, MacOS and Windows. Set + ``USE_DISTRIBUTED=1`` to enable it when building PyTorch from source. + Currently, the default value is ``USE_DISTRIBUTED=1`` for Linux and Windows, + ``USE_DISTRIBUTED=0`` for MacOS. + """ + return True + +# Custom Runtime Errors thrown from the distributed package +DistError = RuntimeError +DistBackendError = RuntimeError +DistNetworkError = RuntimeError +DistStoreError = RuntimeError + +if is_available(): + from .c10d import ( + # _broadcast_coalesced, + # _compute_bucket_assignment_by_size, + # _ControlCollectives, + # _DEFAULT_FIRST_BUCKET_BYTES, + # _make_nccl_premul_sum, + # _register_builtin_comm_hook, + # _register_comm_hook, + # _StoreCollectives, + # _test_python_store, + # _verify_params_across_processes, + # Backend as _Backend, + # BuiltinCommHookType, + # DebugLevel, + # FileStore, + # get_debug_level, + # GradBucket, + # Logger, + PrefixStore, + ProcessGroup as ProcessGroup, + # Reducer, + # set_debug_level, + # set_debug_level_from_env, + Store, + # TCPStore, + Work as _Work, + ) + + + # from .device_mesh import DeviceMesh, init_device_mesh + + # Variables prefixed with underscore are not auto imported + # See the comment in `distributed_c10d.py` above `_backend` on why we expose + # this. + from .distributed_c10d import * # noqa: F403 + from .distributed_c10d import ( + _all_gather_base, + _coalescing_manager, + _CoalescingManager, + _create_process_group_wrapper, + _get_process_group_name, + _rank_not_in_group, + _reduce_scatter_base, + get_node_local_rank, + ) + from .remote_device import _remote_device + # from .rendezvous import ( + # _create_store_from_options, + # register_rendezvous_handler, + # rendezvous, + # ) + + # set_debug_level_from_env() + +else: + # This stub is sufficient to get + # python test/test_public_bindings.py -k test_correct_module_names + # working even when USE_DISTRIBUTED=0. Feel free to add more + # stubs as necessary. + # We cannot define stubs directly because they confuse pyre + + class _ProcessGroupStub: + pass + + sys.modules["core.distributed"].ProcessGroup = _ProcessGroupStub # type: ignore[attr-defined] \ No newline at end of file diff --git a/mindnlp/core/distributed/_checkpointable.py b/mindnlp/core/distributed/_checkpointable.py new file mode 100644 index 000000000..04a903d20 --- /dev/null +++ b/mindnlp/core/distributed/_checkpointable.py @@ -0,0 +1,38 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +from typing import Any, Protocol, runtime_checkable + +from mindnlp import core + + +@runtime_checkable +class _Checkpointable(Protocol): # noqa: PYI046 + """ + Interface for checkpointable objects. + Implemented as a protocol, implicit subtyping is supported so subclasses do not need to inherit this explicitly. + This is to allow arbitrary objects/tensor subclasses to hook into DCP seamlessly through implementing the interface. + """ + + def __create_write_items__(self, fqn: str, object: Any): + """ + Return a list of WriteItems based on object's contents. + """ + raise NotImplementedError( + "_Checkpointable._create_write_items is not implemented" + ) + + def __create_chunk_list__(self): + """ + Return a list of `ChunkStorageMetadata` based on object's contents. + """ + raise NotImplementedError( + "_Checkpointable._create_chunk_list is not implemented" + ) + + def __get_tensor_shard__(self, index) -> core.Tensor: + """ + Return a 'core.Tensor' shard based on 'MetadataIndex'. + """ + raise NotImplementedError( + "_Checkpointable._get_tensor_shard is not implemented" + ) diff --git a/mindnlp/core/distributed/_composable/__init__.py b/mindnlp/core/distributed/_composable/__init__.py new file mode 100644 index 000000000..f347203e6 --- /dev/null +++ b/mindnlp/core/distributed/_composable/__init__.py @@ -0,0 +1,4 @@ +from .checkpoint_activation import checkpoint +from .contract import _get_registry, contract +from .fully_shard import fully_shard +from .replicate import replicate diff --git a/mindnlp/core/distributed/_composable/checkpoint_activation.py b/mindnlp/core/distributed/_composable/checkpoint_activation.py new file mode 100644 index 000000000..fe4195cdb --- /dev/null +++ b/mindnlp/core/distributed/_composable/checkpoint_activation.py @@ -0,0 +1,126 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +from contextlib import contextmanager, nullcontext +from typing import Any, ContextManager, Dict, Optional, Tuple + +from mindnlp import core +from mindnlp import core.nn as nn +from core.utils.checkpoint import ( + _checkpoint_without_reentrant_generator, + _DEFAULT_DETERMINISM_MODE, +) + +from .contract import contract + + +@contextmanager +def _no_hook(module: nn.Module, user_ctx: Optional[ContextManager] = None): + r""" + Disable hooks installed by checkpoint to avoid unintentional recursion + during backward recomputation. + """ + + with user_ctx if user_ctx else nullcontext(): + orig_enable_hook = checkpoint.state(module).enable_hook + checkpoint.state(module).enable_hook = False + try: + yield + finally: + checkpoint.state(module).enable_hook = orig_enable_hook + + +@contract() +def checkpoint(module: nn.Module, **kwargs) -> nn.Module: + r""" + This is a composable activation checkpointing API. Unlike functional + activation checkpointing APIs, this one does not require changing model + source code. Unlike ``nn.Module`` wrapper activation checkpointing APIs, + this one does not modify model structure or fully-qualified names either. + Under the hood, it registers activation checkpointing logic as pre- and + post-forward hooks. Hence, this API can be easily applied to any model or + sub-modules in the model. + + Args: + module (nn.Module): the target model or sub-module to apply activation + checkpointing. + + Example:: + >>> # xdoctest: +SKIP + >>> from mindnlp import core.nn as nn + >>> + >>> class MyModel(nn.Module): + >>> def __init__(self) -> None: + >>> super().__init__() + >>> self.l1 = nn.Linear(10, 10) + >>> self.l2 = nn.Linear(10, 10) + >>> + >>> def forward(self, x): + >>> return self.l2(self.l1(x)) + >>> + >>> model = MyModel() + >>> checkpoint(model.l1) # apply activation checkpointing only to l1 + >>> model(core.zeros(2, 10)).sum().backward() + + """ + core._C._log_api_usage_once("core.distributed.checkpoint") + + use_reentrant = kwargs.pop("use_reentrant", False) + if use_reentrant: + raise NotImplementedError( + "use_reentrant=True is not supported in composable checkpoint. " + "Please use core.utils.checkpoint.checkpoint instead." + ) + preserve_rng_state = kwargs.pop("preserve_rng_state", True) + user_context_fns = kwargs.pop("context_fn", None) + determinism_check = kwargs.pop("determinism_check", _DEFAULT_DETERMINISM_MODE) + debug = kwargs.pop("debug", False) + + if kwargs: + raise ValueError( + "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs) + ) + + def forward_pre_hook( + module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any] + ) -> None: + if checkpoint.state(module).enable_hook: + + def context_fns(): + if user_context_fns is not None: + ctx1, ctx2 = user_context_fns() + return ctx1, _no_hook(module, ctx2) + else: + return nullcontext(), _no_hook(module) + + checkpoint.state( + module + )._ac_generator = _checkpoint_without_reentrant_generator( + module, + preserve_rng_state, + context_fns, + determinism_check, + debug, + *args, + **kwargs, + ) + next(checkpoint.state(module)._ac_generator) + + def forward_hook(module: nn.Module, inputs: Tuple[Any, ...], output: Any) -> Any: + if checkpoint.state(module).enable_hook: + try: + next(checkpoint.state(module)._ac_generator) + except StopIteration: + pass + else: + raise RuntimeError( + "Expected non-reentrant activation checkpoint generator to be exhausted, but it was not!" + ) + + # Ensure that we no longer hold on to the generator. always_call=True helps ensure we + # clear this even in the case of exception in fwd pass. + checkpoint.state(module)._ac_generator = None + + checkpoint.state(module).enable_hook = True + module.register_forward_pre_hook(forward_pre_hook, with_kwargs=True) + module.register_forward_hook(forward_hook, prepend=True, always_call=True) + return module diff --git a/mindnlp/core/distributed/_composable/contract.py b/mindnlp/core/distributed/_composable/contract.py new file mode 100644 index 000000000..038cefad0 --- /dev/null +++ b/mindnlp/core/distributed/_composable/contract.py @@ -0,0 +1,224 @@ +# mypy: allow-untyped-defs +import uuid +from collections import OrderedDict +from functools import wraps +from typing import Callable, Dict, List, Optional, Sequence, Type, Union + +from mindnlp import core +from mindnlp import core.nn as nn +from core.distributed._composable_state import _State +from core.distributed.utils import _get_root_modules + + +def generate_state_key(string="__composable_api_state_key"): + return f"{string}_{str(uuid.uuid4())}" + + +STATE_KEY = generate_state_key() +REGISTRY_KEY = generate_state_key() + + +# TODO: we can add additional info to RegistryItem to share across APIs. E.g., +# we can add args and kwargs here, and then we can detect whether fully_shard +# is combined with reentrant activation checkpointing and error out with a clear +# message. +class RegistryItem: + pass + + +def contract(state_cls: Type[_State] = _State): + r""" + Decorate a function as a composable distributed API, where the first + argument of the function must be an :class:`nn.Module` instance or sequence + of :class:`nn.Module` instances. + + The decorator verifies that the decorated function does not modify + fully-qualified names (FQNs) for parameters, buffers, or modules. The + decorated function can return different module instances than the input + modules; the FQN invariant will be enforced following the input order. + + When a function ``func`` is decorated by ``@contract()``, a + ``.state(module: nn.Module)`` method will be installed to the decorated + function. Then you can retrieve and modify the state on a module by calling + ``func.state(module)``. + + Example:: + >>> # xdoctest: +SKIP + >>> from mindnlp import core.nn as nn + >>> + >>> class MyModel(nn.Module): + >>> def __init__(self) -> None: + >>> super().__init__() + >>> self.l1 = nn.Linear(10, 10) + >>> self.l2 = nn.Linear(10, 10) + >>> + >>> def forward(self, x): + >>> return self.l2(self.l1(x)) + >>> + >>> @contract() + >>> def my_feature(module: nn.Module) -> nn.Module: + >>> my_feature.state(module).some_state = "any value" + >>> return module + >>> + >>> model = MyModel() + >>> my_feature(model.l1) + >>> assert my_feature.state(model.l1).some_state == "any value" + >>> my_feature(model.l2) + >>> model(core.randn(2, 10)).sum().backward() + """ + + # wraps will make functions decorated with contract() pickleable - needed for integration with core.package + @wraps(state_cls) + def inner(func): + @wraps(func) + def wrapper( + module: Union[nn.Module, Sequence[nn.Module]], *args, **kwargs + ) -> Optional[nn.Module]: + inp_module = module + if isinstance(module, nn.Module): + modules = [module] + else: + # If the user passes a sequence of modules, then we assume that + # we only need to insert the state object on the root modules + # (i.e. those without a parent) among the passed-in modules. + modules = _get_root_modules(list(module)) + state = state_cls() # shared across all modules + registry_item = RegistryItem() # shared across all modules + + # `func` is allowed to return different module instances than the + # input modules as long as FQNs are preserved following the input + # module order + all_orig_named_params: List[Dict[str, nn.Parameter]] = [] + all_orig_named_buffers: List[Dict[str, core.Tensor]] = [] + all_orig_named_modules: List[Dict[str, nn.Module]] = [] + + for module in modules: + default_all_state: Dict[Callable, _State] = OrderedDict() + default_registry: Dict[str, RegistryItem] = OrderedDict() + all_state: Dict[Callable, _State] = module.__dict__.setdefault( # type: ignore[call-overload] + STATE_KEY, default_all_state + ) + if not isinstance(all_state, dict): + raise AssertionError( + f"Distributed composable API states corrupted: {all_state}" + ) + registry: Dict[str, RegistryItem] = module.__dict__.setdefault( # type: ignore[call-overload] + REGISTRY_KEY, default_registry + ) + if not isinstance(registry, dict): + raise AssertionError( + f"Distributed composable API registry corrupted: {registry}" + ) + if func in all_state or func.__name__ in registry: + raise AssertionError( + "Each distinct composable distributed API can only be applied to a " + f"module once. {func.__name__} has already been applied to the " + f"following module:\n{module}" + ) + all_state.setdefault(func, state) + registry.setdefault(func.__name__, registry_item) + + all_orig_named_params.append(OrderedDict(module.named_parameters())) + all_orig_named_buffers.append(OrderedDict(module.named_buffers())) + all_orig_named_modules.append(OrderedDict(module.named_modules())) + + updated = func(inp_module, *args, **kwargs) + if updated is None: + updated = inp_module + if isinstance(updated, nn.Module): + updated_modules = [updated] + else: + updated_modules = _get_root_modules(list(inp_module)) # type: ignore[arg-type] + + all_new_named_params: List[Dict[str, nn.Parameter]] = [] + all_new_named_buffers: List[Dict[str, core.Tensor]] = [] + all_new_named_modules: List[Dict[str, nn.Module]] = [] + for module in updated_modules: + all_new_named_params.append(OrderedDict(module.named_parameters())) + all_new_named_buffers.append(OrderedDict(module.named_buffers())) + all_new_named_modules.append(OrderedDict(module.named_modules())) + + num_orig_modules = len(all_orig_named_modules) + num_new_modules = len(all_new_named_modules) + if num_orig_modules != num_new_modules: + raise AssertionError( + f"{func.__name__} should return the same number of modules as input modules" + f"Inputs: {num_orig_modules} modules\n" + f"Outputs: {num_new_modules} modules" + ) + + def check_fqn(orig_fqns: List[str], new_fqns: List[str], check_key: str): + if orig_fqns == new_fqns: + return + + orig_fqn_set, new_fqn_set = set(orig_fqns), set(new_fqns) + orig_only = orig_fqn_set - new_fqn_set + new_only = new_fqn_set - orig_fqn_set + if len(orig_only) or len(new_only): + raise RuntimeError( + f"{check_key}" + "Composable distributed API implementations cannot modify FQNs.\n" + f"FQNs only in original: {orig_only}\n" + f"FQNs only in new: {new_only}" + ) + else: + raise RuntimeError( + f"{check_key}" + "Composable distributed API implementations cannot modify " + "the order of FQNs.\n" + f"Original FQNs: {orig_only}\n" + f"New FQNs: {new_only}" + ) + + for orig_named_params, new_named_params in zip( + all_orig_named_params, all_new_named_params + ): + check_fqn( + list(orig_named_params.keys()), + list(new_named_params.keys()), + "Checking parameters: ", + ) + for orig_named_buffers, new_named_buffers in zip( + all_orig_named_buffers, all_new_named_buffers + ): + check_fqn( + list(orig_named_buffers.keys()), + list(new_named_buffers.keys()), + "Checking buffers: ", + ) + for orig_named_modules, new_named_modules in zip( + all_orig_named_modules, all_new_named_modules + ): + check_fqn( + list(orig_named_modules.keys()), + list(new_named_modules.keys()), + "Checking modules: ", + ) + + # TODO: verify that installed distributed paradigms are compatible with + # each other. + + return updated + + def get_state(module: nn.Module) -> Optional[_State]: + return module.__dict__.setdefault( # type: ignore[call-overload] + STATE_KEY, + {}, # TODO(@yhcharles): this is a temporary fix, need a better way + ).get( + func + ) # type: ignore[call-overload] + + wrapper.state = get_state # type: ignore[attr-defined] + + return wrapper + + return inner + + +def _get_registry(module: nn.Module) -> Optional[Dict[str, RegistryItem]]: + r""" + Get an ``OrderedDict`` of composable APIs that have been applied to the + ``module``, indexed by the API name. If no API has been applied, then this + returns ``None``. + """ + return getattr(module, REGISTRY_KEY, None) diff --git a/mindnlp/core/distributed/_composable/fsdp/__init__.py b/mindnlp/core/distributed/_composable/fsdp/__init__.py new file mode 100644 index 000000000..476ce91ed --- /dev/null +++ b/mindnlp/core/distributed/_composable/fsdp/__init__.py @@ -0,0 +1,2 @@ +from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy +from .fully_shard import FSDPModule, fully_shard, register_fsdp_forward_method diff --git a/mindnlp/core/distributed/_composable/fsdp/_fsdp_api.py b/mindnlp/core/distributed/_composable/fsdp/_fsdp_api.py new file mode 100644 index 000000000..e7c3a519f --- /dev/null +++ b/mindnlp/core/distributed/_composable/fsdp/_fsdp_api.py @@ -0,0 +1,80 @@ +# mypy: allow-untyped-defs +from dataclasses import dataclass +from typing import Optional + +from mindnlp import core + + +@dataclass(frozen=True) +class MixedPrecisionPolicy: + """ + This configures FSDP's mixed precision. Unlike autocast, this applies mixed + precision at the module level, not op level, which means low-precision + activations are saved for backward and high-to-low-precision casts are + incurred only at module boundaries. + + FSDP works well with module-level mixed precision since it keeps the + high-precision sharded parameters in memory anyway. In other words, FSDP + does not require any extra memory to keep a high-precision copy of the + parameters for the optimizer step. + + Attributes: + param_dtype (Optional[core.dtype]): This specifies the dtype for + the unsharded parameter and hence the dtype for forward/backward + computation and the parameter all-gather. If this is ``None``, then + the unsharded parameter uses the original dtype. The optimizer step + uses the sharded parameter in the original dtype. (Default: + ``None``) + reduce_dtype (Optional[core.dtype]): This specifies the dtype for + gradient reduction (i.e. reduce-scatter or all-reduce). If this is + ``None`` but ``param_dtype`` is not ``None``, then the reduction + uses the compute dtype. This can be used to run gradient reduction + in full precision while using low precision for compute. If also + gradient reduction is disabled via :meth:`set_requires_gradient_sync`, + then FSDP will accumulate gradients using ``reduce_dtype``. + (Default: ``None``) + output_dtype (Optional[core.dtype]): This specifies the dtype for + casting floating-point forward outputs. This can be used to + help implement cases where different modules have different mixed + precision policies. (Default: ``None``) + cast_forward_inputs (bool): This specifies whether FSDP should cast the + forward's floating-point input tensors to ``param_dtype`` or not. + """ + + param_dtype: Optional[core.dtype] = None + reduce_dtype: Optional[core.dtype] = None + output_dtype: Optional[core.dtype] = None + cast_forward_inputs: bool = True + + def __post_init__(self): + # Clamp `reduce_dtype` to `None` if no casting is required: since + # gradients are computed in `param_dtype`, if `reduce_dtype` matches, + # then we do not need extra casting + if self.param_dtype == self.reduce_dtype: + # Bypass the frozen dataclass checks + object.__setattr__(self, "reduce_dtype", None) + + +@dataclass +class OffloadPolicy: + """This base class represents the policy of no offloading.""" + + +@dataclass +class CPUOffloadPolicy(OffloadPolicy): + """ + This offload policy offloads parameters, gradients, and optimizer states to + CPU. Sharded parameters are copied host-to-device before all-gather. The + all-gathered parameters are freed according to ``reshard_after_forward``. + Sharded gradients are copied device-to-host in backward, and the optimizer + step runs on CPU with CPU optimizer states. + + Attributes: + pin_memory (bool): Whether to pin sharded parameter and gradient + memory. Pinning memory allows H2D/D2H copying without blocking the + CPU and in turn, overlap with compute, but pinned memory cannot be + used by other processes. Set this to ``False`` if you have + insufficient CPU memory. (Default: ``True``) + """ + + pin_memory: bool = True diff --git a/mindnlp/core/distributed/_composable/fsdp/_fsdp_collectives.py b/mindnlp/core/distributed/_composable/fsdp/_fsdp_collectives.py new file mode 100644 index 000000000..5fc0a11fd --- /dev/null +++ b/mindnlp/core/distributed/_composable/fsdp/_fsdp_collectives.py @@ -0,0 +1,547 @@ +# mypy: allow-untyped-decorators +from typing import cast, List, NamedTuple, Optional, Tuple, Union + +from mindnlp import core +from mindnlp import core.distributed as dist +from core.distributed.device_mesh import _get_device_handle +from core.distributed.distributed_c10d import ReduceOp +from core.distributed.tensor import DTensor + +from ._fsdp_common import ( + _get_dim0_padded_size, + _raise_assert_with_print, + _to_dtype_if_needed, + compiled_autograd_enabled, +) +from ._fsdp_param import FSDPParam, ShardedState + + +class AllGatherResult(NamedTuple): + all_gather_output: core.Tensor + all_gather_event: Optional[core.Event] + all_gather_work: Optional[dist.distributed_c10d.Work] + # For each parameter, the all-gather input dtype for each input + param_all_gather_input_dtypes: List[List[core.dtype]] + # For each parameter, the all-gather input numel for each input + param_all_gather_input_numels: List[List[int]] + # 1D flattened version of `param_all_gather_input_numels` saved to avoid + # CPU overhead from recomputing + all_gather_input_split_sizes: List[int] + + +lib = core.library.Library("fsdp", "FRAGMENT") # noqa: TOR901 + +lib.define( + """ + all_gather_copy_in( + Tensor[] all_gather_inputs, + SymInt[] inp_split_sizes, + SymInt all_gather_input_numel, + SymInt world_size, + SymInt rank, + ScalarType dtype, + Device device + ) -> (Tensor, Tensor) + """ +) + + +@core.library.impl(lib, "all_gather_copy_in", "Meta") +def all_gather_copy_in_meta( + all_gather_inputs: List[core.Tensor], + inp_split_sizes: List[int], + all_gather_input_numel: int, + world_size: int, + rank: int, + dtype: core.dtype, + device: core.device, +) -> Tuple[core.Tensor, core.Tensor]: + all_gather_output = core.empty( + (all_gather_input_numel * world_size,), dtype=dtype, device="meta" + ) + all_gather_input = all_gather_output.narrow( + 0, all_gather_input_numel * rank, all_gather_input_numel + ) + return all_gather_input, all_gather_output + + +@core.library.impl(lib, "all_gather_copy_in", "CUDA") +@core.library.impl(lib, "all_gather_copy_in", "CPU") +def all_gather_copy_in_cuda( + all_gather_inputs: List[core.Tensor], + inp_split_sizes: List[int], + all_gather_input_numel: int, + world_size: int, + rank: int, + dtype: core.dtype, + device: core.device, +) -> Tuple[core.Tensor, core.Tensor]: + all_gather_output = core.empty( + (all_gather_input_numel * world_size,), dtype=dtype, device=device + ) + all_gather_input = all_gather_output.narrow( + 0, all_gather_input_numel * rank, all_gather_input_numel + ) + foreach_copy_dsts = core.split(all_gather_input, inp_split_sizes) + with core.no_grad(): + core._foreach_copy_(foreach_copy_dsts, all_gather_inputs) + return all_gather_input, all_gather_output + + +lib.define( + "split_with_sizes_copy(Tensor all_gather_output, SymInt[] all_gather_input_split_sizes, int dim=0, *, Tensor(a!)[] out) -> ()" +) + + +@core.library.impl(lib, "split_with_sizes_copy", "Meta") +@core.library.impl(lib, "split_with_sizes_copy", "CUDA") +@core.library.impl(lib, "split_with_sizes_copy", "CPU") +def split_with_sizes_copy( + all_gather_output: core.Tensor, + all_gather_input_split_sizes: List[int], + dim: int, + out: List[core.Tensor], +) -> None: + core.split_with_sizes_copy( + all_gather_output, all_gather_input_split_sizes, dim=dim, out=out + ) + + +lib.define( + "chunk_cat(Tensor[] tensors, int dim, int num_chunks, *, Tensor(a!) out) -> ()" +) + + +@core.library.impl(lib, "chunk_cat", "Meta") +@core.library.impl(lib, "chunk_cat", "CUDA") +@core.library.impl(lib, "chunk_cat", "CPU") +def chunk_cat( + tensors: List[core.Tensor], + dim: int, + num_chunks: int, + out: core.Tensor, +) -> None: + core._chunk_cat(tensors, dim, num_chunks, out=out) + + +@core.no_grad() +def foreach_all_gather( + fsdp_params: List[FSDPParam], + group: dist.ProcessGroup, + async_op: bool, + all_gather_copy_in_stream: core.Stream, + all_gather_stream: core.Stream, + device: core.device, +) -> Optional[AllGatherResult]: + world_size, rank = group.size(), group.rank() + device_handle = _get_device_handle(device.type) + with device_handle.stream(all_gather_copy_in_stream): + param_all_gather_inputs = _get_param_all_gather_inputs(fsdp_params) + ( + param_all_gather_input_dtypes, + param_all_gather_input_numels, + dtype, + ) = _get_all_gather_input_metadatas(param_all_gather_inputs) + if dtype == core.uint8: + all_gather_inputs = [ + t.view(core.uint8) for ts in param_all_gather_inputs for t in ts + ] + else: + all_gather_inputs = [t for ts in param_all_gather_inputs for t in ts] + inp_split_sizes = [t.numel() for t in all_gather_inputs] + all_gather_input_numel = sum(inp_split_sizes) + all_gather_input, all_gather_output = core.ops.fsdp.all_gather_copy_in( + all_gather_inputs, + inp_split_sizes, + all_gather_input_numel, + world_size, + rank, + dtype, + device, + ) + del param_all_gather_inputs + all_gather_stream.wait_stream(all_gather_copy_in_stream) + with device_handle.stream(all_gather_stream): + all_gather_work = dist.all_gather_into_tensor( + output_tensor=all_gather_output, + input_tensor=all_gather_input, + group=group, + async_op=async_op, + ) + all_gather_event = all_gather_stream.record_event() + return AllGatherResult( + all_gather_output, + all_gather_event, + all_gather_work, + param_all_gather_input_dtypes, + param_all_gather_input_numels, + inp_split_sizes, + ) + + +@core.no_grad() +def _get_param_all_gather_inputs( + fsdp_params: List[FSDPParam], +) -> List[List[core.Tensor]]: + if compiled_autograd_enabled(): + return [fsdp_param.all_gather_inputs for fsdp_param in fsdp_params] + + # Intentionally try to run a fast-path that bypasses abstractions for the + # common FSDP case of bf16/fp32 mixed precision in order to use foreach + # copy for lower CPU overhead and more efficient copying in eager + def use_foreach_copy(fsdp_param: FSDPParam) -> bool: + return ( + fsdp_param.param_dtype is not None + and not fsdp_param.offload_to_cpu + and not hasattr(fsdp_param._sharded_local_tensor, "fsdp_pre_all_gather") + ) + + param_all_gather_inputs: List[List[core.Tensor]] = [[] for _ in fsdp_params] + foreach_copy_indices: List[int] = [] + foreach_copy_inputs: List[core.Tensor] = [] + foreach_copy_input_numels: List[int] = [] + + # 1st pass: for foreach-copy parameters, get inputs and metadata for the + # foreach copy, and for the others, actually get their all-gather inputs + for i, fsdp_param in enumerate(fsdp_params): + if use_foreach_copy(fsdp_param): + foreach_copy_indices.append(i) + all_gather_input = ( + fsdp_param._sharded_param_data + if fsdp_param.sharded_state == ShardedState.SHARDED + else cast(core.Tensor, fsdp_param._sharded_post_forward_param_data) + ) + foreach_copy_inputs.append(all_gather_input) + foreach_copy_input_numels.append(all_gather_input.numel()) + else: + param_all_gather_inputs[i] = fsdp_param.all_gather_inputs + + # 2nd pass: use foreach copy to compute the remaining all-gather inputs + if foreach_copy_inputs: + fsdp_param_0 = fsdp_params[foreach_copy_indices[0]] + param_dtype, device = fsdp_param_0.param_dtype, fsdp_param_0.device + flat_foreach_copy_input = core.empty( + (sum(foreach_copy_input_numels),), device=device, dtype=param_dtype + ) + splits = core.split(flat_foreach_copy_input, foreach_copy_input_numels) + core._foreach_copy_(splits, foreach_copy_inputs) + for i, split in zip(foreach_copy_indices, splits): + param_all_gather_inputs[i] = [split] + + return param_all_gather_inputs + + +@core.no_grad() +def foreach_all_gather_copy_out( + all_gather_result: AllGatherResult, + fsdp_params: List[FSDPParam], + group: dist.ProcessGroup, +) -> None: + ( + all_gather_output, + all_gather_event, + all_gather_work, + param_all_gather_input_dtypes, + param_all_gather_input_numels, + all_gather_input_split_sizes, + ) = all_gather_result + _dtype, device = all_gather_output.dtype, all_gather_output.device + device_handle = _get_device_handle(device.type) + if all_gather_event is not None: # sync op + device_handle.current_stream().wait_event(all_gather_event) + if isinstance(all_gather_work, dist.distributed_c10d.Work): # async op + all_gather_work.wait() + world_size, device = group.size(), all_gather_output.device + + split_with_sizes_out: List[core.Tensor] = [] + shard_i_copy_infos: List[Tuple[FSDPParam, List[core.Tensor]]] = [] + for all_gather_input_numels, all_gather_input_dtypes, fsdp_param in zip( + param_all_gather_input_numels, param_all_gather_input_dtypes, fsdp_params + ): + # NOTE: Under compile, make sure we always recreate all_gather_outputs + # per AllGather. See [Note: Invariants for core.compile Traceable FSDP2]. + force_recreate = compiled_autograd_enabled() + fsdp_param.init_all_gather_outputs( + all_gather_input_numels, + all_gather_input_dtypes, + world_size, + device, + force_recreate=force_recreate, + ) + if not force_recreate: + fsdp_param.alloc_all_gather_outputs() + param_all_gather_outputs = fsdp_param.all_gather_outputs + if fsdp_param.fsdp_placement.dim != 0: + # Copy to a temporary and then chunk-cat into the final all-gather + # output tensors + param_all_gather_outputs = [ + core.empty_like(t) for t in param_all_gather_outputs + ] + shard_i_copy_infos.append((fsdp_param, param_all_gather_outputs)) + split_with_sizes_out.extend(param_all_gather_outputs) + + all_gather_output = all_gather_output.view(world_size, -1) + if all_gather_output.dtype == core.uint8: + out = [t.view(world_size, -1).view(core.uint8) for t in split_with_sizes_out] + else: + out = [t.view(world_size, -1) for t in split_with_sizes_out] + core.ops.fsdp.split_with_sizes_copy( + all_gather_output, all_gather_input_split_sizes, dim=1, out=out + ) + + for fsdp_param, param_all_gather_outputs in shard_i_copy_infos: + # Chunk-cat from the temporary to the final all-gather output tensors + shard_dim = fsdp_param.fsdp_placement.dim + for param_all_gather_output, target_all_gather_output in zip( + param_all_gather_outputs, fsdp_param.all_gather_outputs + ): + padded_sharded_size = ( + fsdp_param.padded_sharded_param_size + if fsdp_param.sharded_state == ShardedState.SHARDED + else cast( + core.Tensor, fsdp_param._sharded_post_forward_param_data + ).size() + ) + pre_param_size = list(padded_sharded_size) + pre_param_size[0] *= world_size + chunks = core.chunk( + param_all_gather_output.view(pre_param_size), world_size, dim=0 + ) + post_param_size = list(padded_sharded_size) + post_param_size[shard_dim] *= world_size + cat_out = target_all_gather_output.view(post_param_size) + core.cat(chunks, dim=shard_dim, out=cat_out) + core._C._autograd._unsafe_set_version_counter( + target_all_gather_output, target_all_gather_output._version - 1 + ) + + +@core.no_grad() +def foreach_reduce( + fsdp_params: List[FSDPParam], + unsharded_grads: List[core.Tensor], + reduce_scatter_group: dist.ProcessGroup, + reduce_scatter_stream: core.Stream, + orig_dtype: core.dtype, + reduce_dtype: Optional[core.dtype], + device: core.device, + reduce_scatter_reduce_op: Optional[Union[dist.ReduceOp, dist.ReduceOp.RedOpType]], + all_reduce_group: Optional[dist.ProcessGroup], # not `None` iff HSDP + all_reduce_stream: core.Stream, + all_reduce_grads: bool, + partial_reduce_output: Optional[core.Tensor], # only used for HSDP +) -> Tuple[ + core.Tensor, + core.Event, + core.Event, + Optional[core.Tensor], + Optional[core.Event], + Optional[core.Tensor], +]: + """ + ``unsharded_grads`` owns the references to the gradients computed by + autograd, so clearing the list frees the gradients. + """ + grad_dtypes = {grad.dtype for grad in unsharded_grads} + if len(grad_dtypes) != 1: + # Check this at runtime since it could be a real runtime error if e.g. + # fp8 weights do not produce the correct higher precision gradients + _raise_assert_with_print( + f"FSDP reduce-scatter expects uniform gradient dtype but got {grad_dtypes}" + ) + grad_dtype = unsharded_grads[0].dtype + reduce_dtype = reduce_dtype or grad_dtype + predivide_factor, postdivide_factor = _get_gradient_divide_factors( + reduce_scatter_group, all_reduce_group, reduce_dtype + ) + world_size = reduce_scatter_group.size() + for i, (fsdp_param, unsharded_grad) in enumerate(zip(fsdp_params, unsharded_grads)): + if (shard_dim := fsdp_param.fsdp_placement.dim) == 0: + continue + assert ( + unsharded_grad.size(shard_dim) % world_size == 0 + ), f"Shard({shard_dim}) requires even sharding: {unsharded_grad.size()=} {world_size=}" + chunks = core.chunk(unsharded_grad, world_size, dim=shard_dim) + unsharded_grads[i] = core.cat(chunks, dim=0) + padded_unsharded_sizes = tuple( + _get_dim0_padded_size(grad.size(), world_size) for grad in unsharded_grads + ) + reduce_scatter_input_numel = sum(s.numel() for s in padded_unsharded_sizes) + reduce_scatter_output_numel = reduce_scatter_input_numel // world_size + reduce_scatter_input = core.empty( + (reduce_scatter_input_numel,), dtype=reduce_dtype, device=device + ) + device_handle = _get_device_handle(device.type) + foreach_reduce_scatter_copy_in(unsharded_grads, reduce_scatter_input, world_size) + current_stream = device_handle.current_stream() + # Only after the copy-in finishes can we free the gradients + unsharded_grads.clear() + reduce_scatter_stream.wait_stream(current_stream) + all_reduce_input = None + all_reduce_event = None + with device_handle.stream(reduce_scatter_stream): + reduce_output = reduce_scatter_input.new_empty((reduce_scatter_output_numel,)) + _div_if_needed(reduce_scatter_input, predivide_factor) + if reduce_scatter_reduce_op is None: + if predivide_factor is None: + reduce_scatter_reduce_op = ReduceOp.AVG + else: + reduce_scatter_reduce_op = ReduceOp.SUM + dist.reduce_scatter_tensor( + output=reduce_output, + input=reduce_scatter_input, + group=reduce_scatter_group, + op=reduce_scatter_reduce_op, + ) + reduce_scatter_event = reduce_scatter_stream.record_event() + post_reduce_stream = reduce_scatter_stream + if all_reduce_group is not None: # HSDP + # Accumulations must run in the reduce-scatter stream + if not all_reduce_grads: + if partial_reduce_output is not None: + partial_reduce_output += reduce_output + else: + partial_reduce_output = reduce_output + return ( + reduce_scatter_input, + reduce_scatter_event, + post_reduce_stream.record_event(), + all_reduce_input, + all_reduce_event, + partial_reduce_output, + ) + if partial_reduce_output is not None: + reduce_output += partial_reduce_output + post_reduce_stream = all_reduce_stream + all_reduce_stream.wait_stream(reduce_scatter_stream) + with device_handle.stream(all_reduce_stream): + dist.all_reduce( + reduce_output, + group=all_reduce_group, + op=ReduceOp.AVG if predivide_factor is None else ReduceOp.SUM, + ) + all_reduce_input = reduce_output + all_reduce_event = all_reduce_stream.record_event() + with device_handle.stream(post_reduce_stream): + _div_if_needed(reduce_output, postdivide_factor) + reduce_output = _to_dtype_if_needed(reduce_output, orig_dtype) + # View out and accumulate sharded gradients + flat_grad_offset = 0 # [0, reduce_scatter_output_numel - 1] + for padded_unsharded_size, fsdp_param in zip( + padded_unsharded_sizes, fsdp_params + ): + # Assume even sharding for Shard(i), i > 0; otherwise would require + # copy-out for contiguous strides + new_sharded_grad = core.as_strided( + reduce_output, + size=fsdp_param.sharded_size, + stride=fsdp_param.contiguous_sharded_stride, + storage_offset=flat_grad_offset, + ) + to_accumulate_grad = fsdp_param.sharded_param.grad is not None + if fsdp_param.offload_to_cpu: + # Only overlap the D2H copy (copying to pinned memory) if not + # accumulating gradients since the CPU add kernel depends on + # the copy result and we cannot run the add as a callback + non_blocking = fsdp_param.pin_memory and not to_accumulate_grad + # Since the GPU sharded gradient is allocated in the RS stream, + # we can free it here by not keeping a ref without waiting for + # the D2H copy since future RS-stream ops run after the copy + new_sharded_grad = new_sharded_grad.to( + core.device("cpu"), non_blocking=non_blocking + ) + if non_blocking: + # Record an event on which to block the CPU thread to + # ensure that the D2H copy finishes before the optimizer + fsdp_param.grad_offload_event = reduce_scatter_stream.record_event() + if to_accumulate_grad: + assert isinstance(fsdp_param.sharded_param.grad, DTensor) + fsdp_param.sharded_param.grad._local_tensor += new_sharded_grad + else: + new_sharded_dtensor_grad = fsdp_param.to_sharded_dtensor( + new_sharded_grad + ) + fsdp_param.sharded_param.grad = new_sharded_dtensor_grad + if not compiled_autograd_enabled(): + for hook in ( + getattr(fsdp_param.sharded_param, "_post_accumulate_grad_hooks", {}) + or {} + ).values(): + hook(fsdp_param.sharded_param) + padded_sharded_numel = padded_unsharded_size.numel() // world_size + flat_grad_offset += padded_sharded_numel + post_reduce_event = post_reduce_stream.record_event() + # The RS output is allocated in the RS stream and used in the default + # stream (for optimizer). To ensure its memory is not reused for later + # RSs, we do not need extra synchronization since the sharded parameters + # hold refs through the end of backward. + return ( + reduce_scatter_input, + reduce_scatter_event, + post_reduce_event, + all_reduce_input, + all_reduce_event, + None, + ) + + +def foreach_reduce_scatter_copy_in( + unsharded_grads: List[core.Tensor], + reduce_scatter_input: core.Tensor, + world_size: int, +) -> None: + reduce_scatter_input = reduce_scatter_input.view(world_size, -1) + core.ops.fsdp.chunk_cat( + unsharded_grads, dim=0, num_chunks=world_size, out=reduce_scatter_input + ) + + +def _get_all_gather_input_metadatas( + param_all_gather_inputs: List[List[core.Tensor]], +) -> Tuple[List[List[core.dtype]], List[List[int]], core.dtype]: + param_all_gather_input_dtypes: List[List[core.dtype]] = [] + param_all_gather_input_numels: List[List[int]] = [] + all_gather_dtype = param_all_gather_inputs[0][0].dtype + for all_gather_inputs in param_all_gather_inputs: + input_dtypes: List[core.dtype] = [] + input_numels: List[int] = [] + for all_gather_input in all_gather_inputs: + if all_gather_input.dtype != all_gather_dtype: + all_gather_dtype = core.uint8 + input_dtypes.append(all_gather_input.dtype) + input_numels.append(all_gather_input.numel()) + param_all_gather_input_dtypes.append(input_dtypes) + param_all_gather_input_numels.append(input_numels) + return ( + param_all_gather_input_dtypes, + param_all_gather_input_numels, + all_gather_dtype, + ) + + +def _get_gradient_divide_factors( + reduce_scatter_group: dist.ProcessGroup, + all_reduce_group: Optional[dist.ProcessGroup], + reduce_dtype: core.dtype, +) -> Union[Tuple[None, None], Tuple[float, float]]: + # For fp32/bf16, we do not need to worry about overflow/underflow, so we + # use NCCL's built-in division to avoid separate div kernels + if reduce_dtype in (core.float32, core.bfloat16): + return None, None + data_parallel_size = reduce_scatter_group.size() + if all_reduce_group is not None: + data_parallel_size *= all_reduce_group.size() + # Since fp16 has smaller dynamic range than fp32/bf16, we want to avoid + # overflow/underflow. For N data parallel workers, each worker computes + # g_i, and they collectively reduce (g_1 + ... + g_N) / N. To avoid + # overflow/underflow, we divide by ~sqrt(N) before/after the reduction. + factor: int = 1 + while data_parallel_size % factor == 0 and data_parallel_size / factor > factor: + factor *= 2 + factor = float(factor) + return (factor, data_parallel_size / factor) + + +def _div_if_needed(tensor: core.Tensor, div_factor: Optional[float]) -> None: + if div_factor is not None and div_factor > 1: + tensor.div_(div_factor) diff --git a/mindnlp/core/distributed/_composable/fsdp/_fsdp_common.py b/mindnlp/core/distributed/_composable/fsdp/_fsdp_common.py new file mode 100644 index 000000000..fb9527537 --- /dev/null +++ b/mindnlp/core/distributed/_composable/fsdp/_fsdp_common.py @@ -0,0 +1,183 @@ +# mypy: allow-untyped-defs +import math +import traceback +from dataclasses import dataclass +from enum import auto, Enum +from typing import Any, cast, List, Optional + +from mindnlp import core +from mindnlp import core.distributed as dist +from mindnlp import core.nn as nn +from core.distributed._composable.contract import _get_registry +from core.distributed.tensor import DeviceMesh, DTensor +from core.distributed.tensor._dtensor_spec import DTensorSpec + + +_compiled_autograd_enabled: bool = False + +if core._running_with_deploy(): + + def detect_compiled_autograd(): + pass + + def compiled_autograd_enabled(): + return False + +else: + + def detect_compiled_autograd(): + assert ( + not core.compiler.is_compiling() + ), "`detect_compiled_autograd()` is designed to be called in eager mode" + global _compiled_autograd_enabled + from mindnlp import core._dynamo.compiled_autograd as ca + + _compiled_autograd_enabled = ( + ca.compiled_autograd_enabled + or ca.compiled_autograd_enabled_force_eager + or ca.in_compiled_autograd_region + ) + + def compiled_autograd_enabled(): + global _compiled_autograd_enabled + return _compiled_autograd_enabled + + +@dataclass +class DataParallelMeshInfo: + mesh: DeviceMesh + shard_mesh_dim: Optional[int] = None + replicate_mesh_dim: Optional[int] = None + + def __post_init__(self): + if self.shard_mesh_dim is None and self.replicate_mesh_dim is None: + raise AssertionError( + "At least one of shard_mesh_dim and replicate_mesh_dim must not be None" + ) + + +@dataclass +class FSDPMeshInfo(DataParallelMeshInfo): + def __post_init__(self): + super().__post_init__() + if self.shard_mesh_dim is None: + raise AssertionError("Expects non-None shard_mesh_dim") + self.shard_mesh_size: int = self.mesh.size(self.shard_mesh_dim) + self.shard_process_group = self.mesh.get_group(self.shard_mesh_dim) + self.shard_mesh_rank: int = self.shard_process_group.rank() + + +@dataclass +class DDPMeshInfo(DataParallelMeshInfo): + def __post_init__(self): + super().__post_init__() + if self.replicate_mesh_dim is None: + raise AssertionError("Expects non-None replicate_mesh_dim") + self.replicate_mesh_size: int = self.mesh.size(self.replicate_mesh_dim) + self.replicate_process_group = self.mesh.get_group(self.replicate_mesh_dim) + self.replicate_mesh_rank: int = self.replicate_process_group.rank() + + +@dataclass +class HSDPMeshInfo(FSDPMeshInfo, DDPMeshInfo): + def __post_init__(self): + # Calls `FSDPMeshInfo` -> `DDPMeshInfo` -> `DataParallelMeshInfo` + super().__post_init__() + + +class TrainingState(Enum): + """Describes the training state of one FSDP state / parameter group.""" + + # Transition to forward starting pre-forward until post-forward + FORWARD = auto() + # Transition to pre-backward when unsharding in backward + PRE_BACKWARD = auto() + # Transition to post-backward when resharding and reducing gradients + POST_BACKWARD = auto() + # Idle before/after forward or before pre-backward/after post-backward + IDLE = auto() + + +def _raise_assert_with_print(*args: Any, **kwargs: Any): + print(f"[Rank {dist.get_rank()}] ", end="") + print(*args, **kwargs) + traceback.print_stack() + raise AssertionError(*args, **kwargs) + + +def _is_composable_with_fsdp(module: nn.Module) -> bool: + registry = _get_registry(module) + if registry is None: + return True + # Registry keys by function name + return "replicate" not in registry + + +def _get_dim0_padded_size(tensor_size: core.Size, dim0_factor: int) -> core.Size: + padded_dim0 = math.ceil(tensor_size[0] / dim0_factor) * dim0_factor + return cast(core.Size, core.Size([padded_dim0]) + tensor_size[1:]) + + +def _chunk_with_empty( + tensor: core.Tensor, num_chunks: int, dim: int +) -> List[core.Tensor]: + chunks = list(core.chunk(tensor, num_chunks, dim=dim)) + while len(chunks) < num_chunks: + chunks.append(chunks[0].new_empty(0)) + return chunks + + +def _get_dim_chunked_size( + chunk: core.Tensor, unchunked_size: core.Size, dim: int +) -> core.Size: + if chunk.numel() > 0: + return chunk.size() + # For 0 numel, we need to preserve nonzero-sized dims for DTensor APIs + return cast( + core.Size, unchunked_size[:dim] + core.Size([0]) + unchunked_size[dim + 1 :] + ) + + +def _from_local_no_grad( + local_tensor: core.Tensor, + sharding_spec: DTensorSpec, +) -> DTensor: + """ + This method is similar to ``DTensor.from_local()`` except that in eager mode + it avoids some CPU overhead by avoiding default args and not being differentiable. + """ + + if not compiled_autograd_enabled(): + return DTensor( + # Use the local tensor directly instead of constructing a new tensor + # variable, e.g. with `view_as()`, since this is not differentiable + local_tensor, + sharding_spec, + requires_grad=local_tensor.requires_grad, + ) + else: + return DTensor.from_local( + local_tensor, + sharding_spec.mesh, + sharding_spec.placements, + shape=sharding_spec.shape, + stride=sharding_spec.stride, + ) + + +def _to_dtype_if_needed( + tensor: core.Tensor, dtype: Optional[core.dtype] +) -> core.Tensor: + if dtype is not None and tensor.dtype != dtype: + return tensor.to(dtype) + return tensor + + +def _cast_fp_tensor(dtype: core.dtype, x: core.Tensor) -> core.Tensor: + if ( + not isinstance(x, core.Tensor) + or not core.is_floating_point(x) + or x.dtype == dtype + ): + return x + return x.to(dtype) diff --git a/mindnlp/core/distributed/_composable/fsdp/_fsdp_init.py b/mindnlp/core/distributed/_composable/fsdp/_fsdp_init.py new file mode 100644 index 000000000..532c060f3 --- /dev/null +++ b/mindnlp/core/distributed/_composable/fsdp/_fsdp_init.py @@ -0,0 +1,168 @@ +import itertools +from typing import List, Optional, Set, Tuple, Union + +from mindnlp import core +from mindnlp import core.distributed as dist +from mindnlp import core.nn as nn +from core.distributed.device_mesh import _get_device_handle +from core.distributed.tensor import DeviceMesh, DTensor, init_device_mesh +from core.utils._python_dispatch import is_traceable_wrapper_subclass + +from ._fsdp_common import _is_composable_with_fsdp, FSDPMeshInfo, HSDPMeshInfo +from ._fsdp_state import _get_module_fsdp_state + + +def _get_post_forward_mesh_info( + reshard_after_forward: Union[bool, int], mesh_info: FSDPMeshInfo +) -> Optional[FSDPMeshInfo]: + shard_mesh_size = mesh_info.shard_mesh_size + if not isinstance(reshard_after_forward, (bool, int)): + raise ValueError( + "reshard_after_forward should be a bool or an int representing the " + f"group size to reshard to, not {reshard_after_forward}" + ) + # NOTE: `isinstance(False, int)` returns `True`. + if not isinstance(reshard_after_forward, bool) and isinstance( + reshard_after_forward, int + ): + if ( + reshard_after_forward < 1 + or reshard_after_forward > shard_mesh_size + or shard_mesh_size % reshard_after_forward != 0 + ): + raise ValueError( + "If passing reshard_after_forward as an int, it should be a " + f"factor of {shard_mesh_size}, not {reshard_after_forward}" + ) + elif reshard_after_forward == 1: + reshard_after_forward = False + elif reshard_after_forward == shard_mesh_size: + reshard_after_forward = True + post_forward_mesh_info = None + if reshard_after_forward is True: + post_forward_mesh_info = mesh_info + elif reshard_after_forward is not False: # int case + # For HSDP, we can flatten the two replicate dims into the 0th dim + post_forward_mesh_tensor = mesh_info.mesh.mesh.view(-1, reshard_after_forward) + post_forward_mesh = DeviceMesh( + mesh_info.mesh.device_type, post_forward_mesh_tensor + ) + post_forward_mesh_info = HSDPMeshInfo( + post_forward_mesh, shard_mesh_dim=1, replicate_mesh_dim=0 + ) + return post_forward_mesh_info + + +def _init_default_fully_shard_mesh() -> DeviceMesh: + """Default to global CUDA mesh if possible else global CPU mesh.""" + if not dist.distributed_c10d.is_initialized(): + dist.distributed_c10d.init_process_group() + default_pg = dist.distributed_c10d._get_default_group() + device = core._C._get_accelerator() + mesh = init_device_mesh(device.type, mesh_shape=(default_pg.size(),)) + return mesh + + +def _get_device_from_mesh(mesh: DeviceMesh) -> core.device: + if mesh.device_type == "cpu": + return core.device("cpu") + device_handle = _get_device_handle(mesh.device_type) + return core.device(mesh.device_type, device_handle.current_device()) + + +def _get_managed_modules(root_modules: Tuple[nn.Module, ...]) -> List[nn.Module]: + modules: List[nn.Module] = [] + root_modules_set = set(root_modules) + # Track visisted modules to avoid visiting shared modules multiple times + visited_modules: Set[nn.Module] = set() + + def dfs(module: nn.Module) -> None: + """ + Runs a DFS to collect managed modules, not recursing into modules with + a non-composable API or ``fully_shard`` already applied. + """ + if not _is_composable_with_fsdp(module): + return + elif ( + module not in root_modules_set + and _get_module_fsdp_state(module) is not None + ): + return # nested `fully_shard` module + visited_modules.add(module) + for submodule in module.children(): + if submodule not in visited_modules: + dfs(submodule) + modules.append(module) + + for root_module in root_modules: + dfs(root_module) + return modules + + +def _verify_managed_param(name: str, param: nn.Parameter) -> None: + """ + Verify if the parameter is accepted by fully_shard. The only restriction now + is that the parameter cannot be a scalar tensor (param.numel == 0) since we + need at least one dim to shard. + """ + if len(param.shape) == 0: + raise ValueError( + "fully_shard doesn't support salar parameters. " + f"Change {name} to a 1D tensor with numel equal to 1." + ) + + +def _get_managed_states( + modules: List[nn.Module], +) -> Tuple[List[nn.Parameter], List[core.Tensor]]: + params: List[nn.Parameter] = [] + buffers: List[core.Tensor] = [] + # Track visited parameters/buffers to avoid visiting shared parameters and + # buffers multiple times + visited_params: Set[nn.Parameter] = set() + visited_buffers: Set[core.Tensor] = set() + for module in modules: + for name, param in module.named_parameters(recurse=False): + if param not in visited_params: + _verify_managed_param(name, param) + params.append(param) + visited_params.add(param) + for buffer in module.buffers(recurse=False): + if buffer not in visited_buffers: + buffers.append(buffer) + visited_buffers.add(buffer) + return params, buffers + + +def _move_states_to_device( + params: List[nn.Parameter], + buffers: List[core.Tensor], + device: core.device, +) -> None: + """ + We have FSDP move states to device for simpler and faster initialization + since FSDP almost always uses CUDA for training. We move parameters/buffers + rather than modules since modules to support ignoring parameters/buffers in + the future. + """ + # Follow the logic in `nn.Module._apply` + for tensor in itertools.chain(params, buffers): + if tensor.device == device or tensor.device.type == "meta": + # Keep meta-device tensors on meta device for deferred init + continue + if isinstance(tensor, DTensor): + if (dtensor_mesh_type := tensor.device_mesh.device_type) != device.type: + raise ValueError( + "Requires DTensor to have mesh of the same type as the FSDP mesh " + f"but got {dtensor_mesh_type} for DTensor and {device.type} for FSDP" + ) + raise AssertionError( + f"Expects DTensor to be moved to {dtensor_mesh_type} but got {tensor.device}" + ) + tensor_ = tensor + if is_traceable_wrapper_subclass(tensor_): + with core.no_grad(): # avoid autograd increasing C++ refcount by 1 + tensor_on_device = nn.Parameter(tensor.to(device)) + core.utils.swap_tensors(tensor, tensor_on_device) + else: + tensor.data = tensor.to(device) diff --git a/mindnlp/core/distributed/_composable/fsdp/_fsdp_param.py b/mindnlp/core/distributed/_composable/fsdp/_fsdp_param.py new file mode 100644 index 000000000..0bd8050da --- /dev/null +++ b/mindnlp/core/distributed/_composable/fsdp/_fsdp_param.py @@ -0,0 +1,880 @@ +# mypy: allow-untyped-defs +import inspect +import itertools +from dataclasses import dataclass, field +from enum import auto, Enum +from typing import Any, Callable, cast, List, Optional, Sequence, Tuple + +from mindnlp import core +from mindnlp import core.nn as nn +from core._prims_common import make_contiguous_strides_for +from core.distributed._functional_collectives import AsyncCollectiveTensor +from core.distributed.tensor import DTensor, Replicate, Shard +from core.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from core.distributed.tensor.device_mesh import _mesh_resources +from core.distributed.tensor.placement_types import _StridedShard, Placement + +from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy +from ._fsdp_common import ( + _chunk_with_empty, + _from_local_no_grad, + _get_dim_chunked_size, + _raise_assert_with_print, + _to_dtype_if_needed, + compiled_autograd_enabled, + FSDPMeshInfo, + HSDPMeshInfo, +) + + +""" +[Note: FSDP tensors] +FSDP considers the following tensors: +- Original parameter: parameter passed to :class:`FSDPParam`, i.e. the one + on the module when applying FSDP +- Sharded parameter: sharding the original parameter on dim-0 (or a + user-specified dim) as a DTensor over the main mesh +- All-gather inputs: the ``core.Tensor`` or ``Tensor`` s passed to all-gather, + derived from the sharded parameter +- All-gather output: the ``core.Tensor`` or ``Tensor`` s resulting from + all-gathering the all-gather inputs +- Unsharded parameter: parameter used for forward/backward computation, derived + from the all-gather output; autograd leaf + +We define these tensors to describe the general framework that can accomodate +extensions, where: +- all-gather-inputs = pre-all-gather-transform(sharded-parameter) +- unsharded-parameter = post-all-gather-transform(all-gather-outputs) + +For the default ``core.Tensor`` case, there is only one all-gather input, and +it shares the same underlying tensor data as the sharded parameter, meaning +that they can be thought of as the same tensors. The same applies for the +all-gather output and unsharded parameter. For non-``core.Tensor`` extensions, +these equivalences may no longer hold due to the pre/post-all-gather +transforms, and some may have multiple all-gather inputs/outputs (e.g. +quantized data and scales). + +[Note: FSDP and autograd] +FSDP dynamically frees and allocates the unsharded parameter. Since autograd +can pack a reference to it or a view to save for backward, we use storage +resizing to implement the freeing/allocation since that preserves the aliasing. +This implies that we construct the unsharded parameter object once and write to +it in-place thereafter. For the default ``core.Tensor` original parameter +case, the all-gather output and unsharded parameter share the same +data, so we use storage resizing on the all-gather output. +""" + +lib = core.library.Library("fsdp", "FRAGMENT") # noqa: TOR901 + +lib.define("copy_(Tensor(a!) tensor, Tensor data) -> ()") + + +@core.library.impl(lib, "copy_", "Meta") +@core.library.impl(lib, "copy_", "CUDA") +@core.library.impl(lib, "copy_", "CPU") +def copy_(tensor, data): + tensor.copy_(data) + + +""" +[Note: Avoiding functionalization for fsdp.copy_ and inductor.resize_storage_bytes_] + +Currently we don't functionalize `fsdp.copy_` op or `inductor.resize_storage_bytes_` op +(i.e. they show up as a mutation op in the middle of the AOT joint graph). + +Reason: +Traceable FSDP2 compiled autograd BWD graph have the following traits: +(1) Two inputs of the graph were aliased to each other (one from hook closed-over tensors, one from FWD saved tensors). +(2) One of them is mutated (copy_ and resize_ to handle the all-gathered param). +(3) They are both subclasses. +The combination of these traits is not supported by AOTAutograd (it's difficult to reason about subclass aliasing). +So this doesn't work at all for Traceable FSDP2. + +The compromise we use is to avoid functionalization for the FSDP2 copy_ and resize_ ops. +This avoids the problem above, because from AOTAutograd point-of-view there are no mutations +that functionalization needs to handle. (Although we need to be careful not to DCE those mutable ops.) + +We can avoid this functionalization because: +(1) The nn.Parameter is never used before its .copy_() is called in eager code (i.e. no alias of it is created), +so it's safe to call .copy_() in the middle of the graph to update its content and start using the nn.Parameter downstream. +(2) We always re-allocate the buffer for nn.Parameter to store the AllGather output and to be used in downstream user ops. +So calling resize-to-0 in the middle of the graph to free nn.Parameter memory after use should always be okay +(since we always allocate anew next time we need it, we strictly don't need to keep the old tensor storage around anymore). + +Q: Wouldn't the extra resize_ and copy_ ops hurt both memory usage and performance? +A: Yes it would. As an optimization, we have an Inductor post-grad FX pass to remove those resize_ and copy_ ops +for unsharded params that have this pattern: resize_(full) -> copy_ -> resize_(0). + +TODO: +Now that we are maintaining the invariant of "no aliased + mutated graph inputs" in both the forward and backward, +it is now more feasible to functionalize all of the mutable FSDP ops. Some of the pros and cons are: + +Cons (of functionalizing those ops): +(1) By not functionalizing them as we are today, we are making it more likely that they will run at the "correct" time +in the generated code. If we start to functionalize them, we will need to make sure that Inductor reinplaces them +in a way where it properly moves the mutations back to exactly where they should have run, or we risk suffering worse +peak memory than eager. (We probably already need to do something similar in Inductor's reinplacing for copy_: +https://github.com/pytorch/pytorch/issues/135305#issuecomment-2334888089) + +Pros (of functionalizing): +(1) Better safety, we don't need to worry about the graph passes in inductor/partitioning handling input mutations +mid-graph quite as much (to be fair we've already done some amount of auditing, but we might have to do some more). +(2) Better perf: each mutation midway through the graph prevents Inductor from pattern matching across it. +But maybe there are few enough mutations induced by FSDP for this to matter. +""" + + +@core.library.impl(lib, "copy_", "Functionalize") +def copy__functionalize(tensor, data): + core._sync(tensor) + core._sync(data) + tensor_inner = core._from_functional_tensor(tensor) + data_inner = core._from_functional_tensor(data) + with core._C._ExcludeDispatchKeyGuard( + core._C.DispatchKeySet(core._C.DispatchKey.Functionalize) + ): + core.ops.fsdp.copy_.default(tensor_inner, data_inner) + + +core.fx.node.has_side_effect(core.ops.fsdp.copy_.default) + + +class ShardedState(Enum): + """ + - ``SHARDED``: The sharded parameter is registered to the module. It is the + only contributor to parameter memory. + - ``SHARDED_POST_FORWARD``: The unsharded parameter is resharded to a + smaller world size. Since this data should not be used for computation, + we do not register it to the module. Users should reshard the module + before any in-place modifications. Both it and the sharded parameter + contribute to parameter memory. + - ``UNSHARDED``: The unsharded parameter is registered to the module. Both + it and the sharded parameter contribute to parameter memory. + """ + + SHARDED = auto() + SHARDED_POST_FORWARD = auto() + UNSHARDED = auto() + + +@dataclass +class ParamModuleInfo: + """ + For a parameter, this stores the module and the parameter name to be able + to do a parameter swap via ``setattr(module, param_name, ...)`` or to get + the parameter via ``getattr(module, param_name)``. We additionally save + shared modules and shared parameter names to update them accordingly. + """ + + # Parameter names are unprefixed, e.g. "weight", not "lin.weight" + module: nn.Module + param_name: str + shared_modules: List[nn.Module] = field(default_factory=list) + shared_param_names: List[str] = field(default_factory=list) + + +@dataclass +class ExtensionsData: + # User-defined metadata passed from pre to post-all-gather + all_gather_metadata: Optional[Any] = None + # Save the all-gather input sizes to unflatten the all-gather outputs to ND + all_gather_input_sizes: Sequence[core.Size] = () # ND + + def clear(self): + self.all_gather_metadata = None + self.all_gather_input_sizes = () + + +class FSDPParam: + """ + This class manages a parameter with FSDP or FSDP variants applied, + implementing dim-0 per-parameter sharding. + """ + + orig_dtype: core.dtype + param_dtype: Optional[core.dtype] + reduce_dtype: Optional[core.dtype] + _orig_size: core.Size # ND + sharded_size: core.Size # ND + contiguous_sharded_stride: Tuple[int, ...] + padded_sharded_param_size: core.Size # ND + sharded_post_forward_size: core.Size # ND + contiguous_sharded_post_forward_stride: Tuple[int, ...] + _sharded_param_data: core.Tensor # 1D + sharded_param: nn.Parameter # ND + _sharded_post_forward_param_data: Optional[core.Tensor] # 1D + _sharded_post_forward_param: Optional[nn.Parameter] # ND + _unsharded_param: nn.Parameter # ND + unsharded_accumulated_grad: Optional[core.Tensor] # ND + _sharding_spec: DTensorSpec + # DTensor attributes (only defined for DTensor `param`): + _tp_spec: DTensorSpec + all_gather_outputs: List[core.Tensor] # 1D + # All-gather extension attributes + _extensions_data: ExtensionsData + _unsharded_inner_tensors: List[core.Tensor] + + def __init__( + self, + param: nn.Parameter, + module_info: ParamModuleInfo, + mesh_info: FSDPMeshInfo, + post_forward_mesh_info: Optional[FSDPMeshInfo], + device: core.device, + shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]], + mp_policy: MixedPrecisionPolicy, + offload_policy: OffloadPolicy, + ): + self._module_info: ParamModuleInfo = module_info + self.mesh_info = mesh_info + self.post_forward_mesh_info = post_forward_mesh_info + self.device = device + self.mp_policy = mp_policy + self.offload_to_cpu: bool = isinstance(offload_policy, CPUOffloadPolicy) + self.pin_memory = ( + self.offload_to_cpu and cast(CPUOffloadPolicy, offload_policy).pin_memory + ) + self.grad_offload_event: Optional[core.Event] = None + self._init_sharded_param(param, device, shard_placement_fn) + if self.post_forward_mesh_info: + self._init_sharded_post_forward_param_metadata(param) + self._init_extensions() + self.all_gather_outputs: List[core.Tensor] = [] + self.unsharded_accumulated_grad = None + self._param_fqn: Optional[str] = None # prefixed from root module + # TODO: Remove this padding logic once DTensor pads the local tensor: + # https://github.com/pytorch/pytorch/issues/113045 + self._post_load_hook_handle = ( + module_info.module.register_load_state_dict_post_hook( + lambda *args, **kwargs: self.reset_sharded_param() + ) + ) + + @core.no_grad() + def _init_sharded_param( + self, + param: nn.Parameter, + device: core.device, + shard_placement_fn: Optional[Callable], + ): + if param.device != device and param.device.type != "meta": + raise AssertionError( + f"Expects the parameter to already be moved to device {device} but got {param.device}" + ) + if not param.is_contiguous(): + raise NotImplementedError( + f"FSDP does not support non-contiguous parameters yet: {param.shape=} {param.stride()=}" + ) + fsdp_placement = shard_placement_fn(param) if shard_placement_fn else None + if fsdp_placement is None: + fsdp_placement = Shard(0) + elif fsdp_placement.dim < 0: + fsdp_placement = Shard(fsdp_placement.dim + param.ndim) + assert isinstance(fsdp_placement, Shard), f"{fsdp_placement}" + self.fsdp_placement = fsdp_placement + shard_dim = fsdp_placement.dim + # TODO: Replace the sharded DTensor parameter construction logic with + # `distribute_tensor` after https://github.com/pytorch/pytorch/issues/116101 + # TODO: Simplify the following sharded parameter padding logic after + # https://github.com/pytorch/pytorch/issues/113045 + self.is_dtensor = isinstance(param, DTensor) + if self.is_dtensor: + self._tp_spec = cast(DTensor, param)._spec + dp_mesh, tp_mesh = (self.mesh_info.mesh, self._tp_spec.mesh) + dp_global_mesh = _mesh_resources.get_root_mesh(dp_mesh) + tp_global_mesh = _mesh_resources.get_root_mesh(tp_mesh) + if dp_global_mesh != tp_global_mesh or ( + dp_global_mesh is None or tp_global_mesh is None + ): + raise AssertionError( + "FSDP requires the DP and TP mesh to have the same parent mesh but got: \n" + f"DP's global mesh: {dp_global_mesh}\nTP's global mesh: {tp_global_mesh}" + ) + name_dims_error = "FSDP requires named DeviceMesh dims for ND parallelism" + assert dp_mesh.mesh_dim_names is not None, name_dims_error + assert tp_mesh.mesh_dim_names is not None, name_dims_error + submesh_names = dp_mesh.mesh_dim_names + tp_mesh.mesh_dim_names + self._spmd_mesh = dp_global_mesh[submesh_names] + if len(self._tp_spec.placements) != 1: + raise NotImplementedError( + f"FSDP only supports 1D TP, not {self._tp_spec.placements}" + ) + split_factor = self._tp_spec.num_shards_map[shard_dim] + assert ( + 2 <= self._spmd_mesh.ndim <= 3 + ), f"_spmd_mesh.ndim can only be 2 or 3 but got {self._spmd_mesh.ndim}." + self._spmd_placements: Tuple[Placement, ...] + dp_shard_tp_placement = ( + ( + _StridedShard(shard_dim, split_factor=split_factor) + if split_factor > 1 + else fsdp_placement + ), + self._tp_spec.placements[0], + ) + if self._spmd_mesh.ndim == 2: + self._spmd_placements = dp_shard_tp_placement + else: + assert self.mesh_info.replicate_mesh_dim == 0 + self._spmd_placements = (Replicate(),) + dp_shard_tp_placement + self._sharding_spec = DTensorSpec( + self._spmd_mesh, + self._spmd_placements, + tensor_meta=self._tp_spec.tensor_meta, + ) + # TODO: Enable uneven sharding for FSDP+TP. + if split_factor > 1: # FSDP has strided sharding on tensor dim 0 + num_shards = self._sharding_spec.num_shards_map[0] + tensor_size_dim_0 = self._sharding_spec.shape[0] + if tensor_size_dim_0 % num_shards != 0: + raise NotImplementedError( + "FSDP+TP sharding does not support uneven sharding for now: " + f"tensor dim 0 has size {tensor_size_dim_0} which cannot be " + f"evenly sharded into {num_shards} shards." + ) + param_data = cast(DTensor, param)._local_tensor + else: + self._spmd_mesh = self.mesh_info.mesh + if isinstance(self.mesh_info, HSDPMeshInfo): + self._spmd_placements = (Replicate(), fsdp_placement) + else: + self._spmd_placements = (fsdp_placement,) + self._sharding_spec = DTensorSpec( + self._spmd_mesh, + self._spmd_placements, + tensor_meta=TensorMeta(param.size(), param.stride(), param.dtype), + ) + param_data = param + assert param_data.is_contiguous(), f"{param_data.shape=} {param_data.stride()=}" + shard_dim = fsdp_placement.dim + if shard_dim >= param_data.ndim: + raise AssertionError( + f"Shard dim {shard_dim} is invalid for {param_data.ndim}D tensor: {param.shape}" + ) + self._orig_size = param_data.size() + self._contiguous_orig_stride = make_contiguous_strides_for(self._orig_size) + shard_rank = self.mesh_info.shard_mesh_rank + shard_world_size = self.mesh_info.shard_mesh_size + if shard_dim > 0 and param_data.size(shard_dim) % shard_world_size != 0: + # If sharding on nonzero dim, require even sharding for now because + # the uneven sharding (1) requires extra copies before/after FSDP + # collectives and (2) introduces extra complexity to handle padding + # and unpadding + raise NotImplementedError( + f"FSDP does not support uneven sharding on dim {shard_dim}: " + f"{param_data.size()} (world size: {shard_world_size})" + ) + chunks = _chunk_with_empty(param_data, shard_world_size, dim=shard_dim) + sharded_param = chunks[shard_rank] + self.sharded_size = _get_dim_chunked_size( + sharded_param, param_data.size(), dim=shard_dim + ) + self.contiguous_sharded_stride = make_contiguous_strides_for(self.sharded_size) + padded_sharded_size = chunks[0].size() # 0th always padded + self.padded_sharded_param_size = padded_sharded_size + # Pre-pad the sharded parameter to avoid padding before all-gather + padded_sharded_param = param_data.new_zeros(padded_sharded_size) + if sharded_param.numel() > 0: + padded_sharded_param.narrow( + dim=shard_dim, start=0, length=sharded_param.size(shard_dim) + ).copy_(sharded_param) + if self.offload_to_cpu and not padded_sharded_param.is_meta: + padded_sharded_param = padded_sharded_param.cpu() + if self.pin_memory: + padded_sharded_param = padded_sharded_param.pin_memory( + device=self.device + ) + self._sharded_param_data = padded_sharded_param.view(-1) + length = sharded_param.size(shard_dim) if sharded_param.numel() > 0 else 0 + sharded_param = padded_sharded_param.narrow( + dim=shard_dim, start=0, length=length + ) + assert sharded_param.is_contiguous(), f"{self.fsdp_placement=}" + self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param)) + self.sharded_param.requires_grad_(param.requires_grad) + # Let `param_data` be freed normally when its ref count reaches 0 when + # the `fully_shard` call returns to allow provided parameters to alias + self._setattr_on_modules(self.sharded_param) + self.sharded_state = ShardedState.SHARDED + + def _init_sharded_post_forward_param_metadata(self, param: core.Tensor) -> None: + mesh_info = self.post_forward_mesh_info + assert mesh_info is not None # mypy + param_data = param._local_tensor if isinstance(param, DTensor) else param + chunks = _chunk_with_empty(param_data, mesh_info.shard_mesh_size, dim=0) + self.sharded_post_forward_size = _get_dim_chunked_size( + chunks[mesh_info.shard_mesh_rank], + param_data.size(), + dim=self.fsdp_placement.dim, + ) + self.contiguous_sharded_post_forward_stride = make_contiguous_strides_for( + self.sharded_post_forward_size + ) + + def init_dtype_attrs(self, mp_policy: MixedPrecisionPolicy): + param_dtype, reduce_dtype = (mp_policy.param_dtype, mp_policy.reduce_dtype) + self.orig_dtype = self.sharded_param.dtype + # Clamp `param_dtype` to `None` if no casting is required + if param_dtype == self.orig_dtype: + param_dtype = None + self.param_dtype = param_dtype + self.reduce_dtype = reduce_dtype + # None indicates that the mixed precision is not enabled + + def _init_extensions(self) -> None: + inner_tensor = self._sharded_local_tensor + has_fsdp_pre_all_gather = hasattr(inner_tensor, "fsdp_pre_all_gather") + has_fsdp_post_all_gather = hasattr(inner_tensor, "fsdp_post_all_gather") + if has_fsdp_pre_all_gather != has_fsdp_post_all_gather: + raise AssertionError( + "Both fsdp_pre_all_gather and fsdp_post_all_gather should be defined " + f"if using all-gather extensions: {inner_tensor}" + ) + if has_fsdp_pre_all_gather: + self._extensions_data = ExtensionsData() + self._unsharded_inner_tensors: List[core.Tensor] = [] + + def init_all_gather_outputs( + self, + all_gather_input_numels: List[int], + all_gather_input_dtypes: List[core.dtype], + world_size: int, + device: core.device, + force_recreate: bool = False, + ): + if not force_recreate and len(self.all_gather_outputs) > 0: + return # already initialized + self.all_gather_outputs = [ + core.empty(core.Size([numel * world_size]), dtype=dtype, device=device) + for numel, dtype in zip(all_gather_input_numels, all_gather_input_dtypes) + ] + + def init_unsharded_param(self): + """ + [Note: Invariants for core.compile Traceable FSDP2] + 1. Under compile, we always re-populate the content of `self._unsharded_param` + per AllGather using the slow path. + 2. Under compile, we always recreate `self.all_gather_outputs` per AllGather. + This is to ensure the buffer creation is internal to the graph and + avoid `self.all_gather_outputs` being captured as a graph input. + 3. Under compile, at the end of `free_unsharded_param()`, we always clean up + `self.all_gather_outputs` and `self._unsharded_inner_tensors`, + to avoid them being captured as graph output. + + With these invariants, only these tensors will be inputs to the graph: + - Sharded parameters + - Placeholders for the `self._unsharded_param` nn.Parameter + """ + if not compiled_autograd_enabled() and hasattr( + self, "_unsharded_param" + ): # after the 1st all-gather + inner_tensor = self._sharded_local_tensor + if not hasattr(inner_tensor, "fsdp_post_all_gather"): + return # already initialized + for tensor in self._unsharded_inner_tensors: + alloc_storage(tensor) + all_gather_outputs = self._unflatten_all_gather_outputs() + inner_tensor.fsdp_post_all_gather( + all_gather_outputs, + self._extensions_data.all_gather_metadata, + self.param_dtype or self.orig_dtype, + out=self._unsharded_param, + ) + self._extensions_data.clear() + return + inner_tensor = self._sharded_local_tensor + if not compiled_autograd_enabled() and hasattr( + inner_tensor, "fsdp_post_all_gather" + ): + all_gather_outputs = self._unflatten_all_gather_outputs() + ( + unsharded_tensor, + self._unsharded_inner_tensors, + ) = inner_tensor.fsdp_post_all_gather( + all_gather_outputs, + self._extensions_data.all_gather_metadata, + self.param_dtype or self.orig_dtype, + ) + self._extensions_data.clear() + else: + # For the default path (no post-all-gather), the all-gather output + # gives the unsharded parameter data directly + assert len(self.all_gather_outputs) == 1, f"{len(self.all_gather_outputs)}" + unsharded_tensor = self.all_gather_outputs[0] + unsharded_param = core.as_strided( + unsharded_tensor, + self._orig_size, + self._contiguous_orig_stride, + storage_offset=0, + ) + if self.is_dtensor: + unsharded_param = _from_local_no_grad(unsharded_param, self._tp_spec) + if hasattr(self, "_unsharded_param"): + assert compiled_autograd_enabled() + with core.no_grad(), core.autograd._unsafe_preserve_version_counter( + self._unsharded_param + ): + # NOTE: Under compile, if an unsharded param goes through + # resize_(full) -> copy_ -> resize_(0) pattern, we will remove those + # resize_ and copy_ ops in a compiler graph pass + # `remove_fsdp2_unsharded_param_graph_input_usage` to recover performance. + self._unsharded_param.untyped_storage().resize_( + self._unsharded_param.numel() * self._unsharded_param.itemsize + ) + core.ops.fsdp.copy_(self._unsharded_param, unsharded_param) + else: + self._unsharded_param = nn.Parameter( + unsharded_param, requires_grad=self.sharded_param.requires_grad + ) + + def _unflatten_all_gather_outputs(self) -> Tuple[core.Tensor, ...]: + return tuple( + t.view(-1, *s[1:]) + for t, s in zip( + self.all_gather_outputs, self._extensions_data.all_gather_input_sizes + ) + ) + + def to_sharded(self) -> None: + self._setattr_on_modules(self.sharded_param) + self.free_unsharded_param() + self.sharded_state = ShardedState.SHARDED + + def to_sharded_post_forward(self) -> None: + if self.is_dtensor: + raise NotImplementedError( + "Resharding to smaller mesh with TP is not supported yet" + ) + self._assert_in_states(ShardedState.UNSHARDED) + assert self.post_forward_mesh_info is not None # mypy + assert len(self.all_gather_outputs) == 1 + shard_world_size = self.post_forward_mesh_info.shard_mesh_size + if (numel := self.all_gather_outputs[0].numel()) % shard_world_size != 0: + _raise_assert_with_print( + f"All-gather output size ({numel}) must be divisible by the shard " + f"world size ({shard_world_size})" + ) + shard_rank = self.post_forward_mesh_info.shard_mesh_rank + sharded_numel = numel // shard_world_size + self._sharded_post_forward_param_data = ( + self.all_gather_outputs[0].narrow( + 0, sharded_numel * shard_rank, sharded_numel + ) + ).clone() # clone to be able to free all-gather output + sharded_post_forward_tensor = core.as_strided( + self._sharded_post_forward_param_data, + size=self.sharded_post_forward_size, + stride=self.contiguous_sharded_post_forward_stride, + storage_offset=0, + ) + self._sharded_post_forward_param = nn.Parameter( + self.to_sharded_post_forward_dtensor(sharded_post_forward_tensor) + ) + self._setattr_on_modules(self._sharded_post_forward_param) + self.free_unsharded_param() + self.sharded_state = ShardedState.SHARDED_POST_FORWARD + + def to_unsharded(self) -> None: + # Assume that the data has been allocated and all-gathered + set_requires_grad_if_needed(self.sharded_param, self._unsharded_param) + self._setattr_on_modules(self._unsharded_param) + if self.sharded_state == ShardedState.SHARDED_POST_FORWARD: + # The data is allocated in the default stream via the post-forward + # reshard and must be kept alive for the next all-gather copy-in. + # Since we call this method after the copy-out, the data's lifetime + # is ensured without further synchronization. + self._sharded_post_forward_param = None + self._sharded_post_forward_param_data = None # free + self.sharded_state = ShardedState.UNSHARDED + + def _setattr_on_modules(self, param: nn.Parameter) -> None: + unsafe_setattr_param( + self._module_info.module, self._module_info.param_name, param + ) + for shared_module, shared_param_name in zip( + self._module_info.shared_modules, self._module_info.shared_param_names + ): + unsafe_setattr_param(shared_module, shared_param_name, param) + + def to_sharded_dtensor(self, tensor: core.Tensor) -> DTensor: + """ + Converts a local tensor representing either the sharded parameter or + sharded gradient to DTensor. + """ + if tensor.shape != self.sharded_size: + _raise_assert_with_print( + f"Expects size {self.sharded_size} but got {tensor.shape}" + ) + return _from_local_no_grad( + tensor, + self._sharding_spec, + ) + + def to_sharded_post_forward_dtensor(self, tensor: core.Tensor) -> DTensor: + if tensor.shape != self.sharded_post_forward_size: + _raise_assert_with_print( + f"Expects size {self.sharded_post_forward_size} but got {tensor.shape}" + ) + assert isinstance(self.post_forward_mesh_info, HSDPMeshInfo) + # TODO: Prefer this DTensor to be read-only and generalize the + # placement once we support TP. + post_forward_sharding_spec = DTensorSpec( + self.post_forward_mesh_info.mesh, + (Replicate(), Shard(0)), + tensor_meta=self._sharding_spec.tensor_meta, + ) + return _from_local_no_grad(tensor, post_forward_sharding_spec) + + def to_accumulated_grad_if_needed(self) -> None: + # Access `_unsharded_param` to bypass the sharded state check since we + # prefer to reshard before upcasting the gradient to save memory + if ( + self.reduce_dtype is None + or self._unsharded_param.grad is None + or self._unsharded_param.grad.dtype == self.reduce_dtype + ): + return + unsharded_grad = self._unsharded_param.grad + self._unsharded_param.grad = None + self.unsharded_accumulated_grad = unsharded_grad.to(self.reduce_dtype) + + def accumulate_unsharded_grad_if_needed(self) -> None: + if ( + self.unsharded_accumulated_grad is not None + and self.unsharded_param.grad is not None + ): + self.unsharded_accumulated_grad += self.unsharded_param.grad + self.unsharded_param.grad = None + + def alloc_all_gather_outputs(self) -> None: + for tensor in self.all_gather_outputs: + alloc_storage(tensor) + + def free_unsharded_param(self) -> None: + if compiled_autograd_enabled(): + """ + Assumptions under compile: + - `self._unsharded_param` is NOT an alias of `self.all_gather_outputs`. + Instead, we resize `self._unsharded_param` storage size to full and then + explicitly *copy* the data from `self.all_gather_outputs` to `self._unsharded_param` + in `init_unsharded_param()`. (For full-graph FSDP2 case, we will then remove + the resize_ and copy_ ops in a compiler graph pass to recover performance.) + - `self.all_gather_outputs` and `self._unsharded_inner_tensors` are NOT + graph inputs. They are created within the graph and is guaranteed to be freed + by the end of the graph. They don't leak outside of the graph. + """ + self._unsharded_param.untyped_storage().resize_(0) + self.all_gather_outputs = [] + self._unsharded_inner_tensors = [] + else: + for tensor in itertools.chain( + self.all_gather_outputs, self._unsharded_inner_tensors + ): + free_storage(tensor) + + @property + def all_gather_inputs(self) -> List[core.Tensor]: # 1D + self._assert_in_states(ShardedState.SHARDED, ShardedState.SHARDED_POST_FORWARD) + if self.sharded_state == ShardedState.SHARDED: + if not compiled_autograd_enabled() and hasattr( + self._sharded_local_tensor, "fsdp_pre_all_gather" + ): + sharded_local_tensor = self._sharded_local_tensor + if self.offload_to_cpu: + sharded_local_tensor = sharded_local_tensor.to( + self.device, non_blocking=True + ) + pre_all_gather_signature = inspect.signature( + sharded_local_tensor.fsdp_pre_all_gather + ) + num_fn_params = len(pre_all_gather_signature.parameters) + # Old signature only passes mesh; keep for BC for now + assert num_fn_params in ( + 1, + 5, + ), ( + f"Invalid fsdp_pre_all_gather: {pre_all_gather_signature}\n" + "Expects fsdp_pre_all_gather(self, mesh: DeviceMesh, " + "module: nn.Module, mp_policy: MixedPrecisionPolicy)" + ) + if num_fn_params == 1: + ( + all_gather_inputs, + self._extensions_data.all_gather_metadata, + ) = sharded_local_tensor.fsdp_pre_all_gather(self.shard_mesh) + else: + ( + all_gather_inputs, + self._extensions_data.all_gather_metadata, + ) = sharded_local_tensor.fsdp_pre_all_gather( + self.shard_mesh, + self._orig_size, + self._contiguous_orig_stride, + self._module_info.module, + self.mp_policy, + ) + if ( + sharded_local_tensor.size() != self.padded_sharded_param_size + and any( + all_gather_input.size() != self.padded_sharded_param_size + for all_gather_input in all_gather_inputs + ) + ): + # NOTE: Since this error can only be raised on the + # ranks that have padding, this can manifest as a NCCL + # watchdog timeout, as the other ranks will not error. + raise AssertionError( + "When a parameter is unevenly sharded by FSDP " + f"(orig size={self._orig_size}, FSDP world size={self.mesh_info.mesh.size()}), " + "fsdp_pre_all_gather must return all-gather inputs with the padded sharded size " + f"{self.padded_sharded_param_size} but got {[t.size() for t in all_gather_inputs]}" + ) + self._extensions_data.all_gather_input_sizes = [ + t.size() for t in all_gather_inputs + ] + return [t.view(-1) for t in all_gather_inputs] + sharded_param_data = self._sharded_param_data + if self.offload_to_cpu: + sharded_param_data = sharded_param_data.to( + self.device, non_blocking=True + ) + return [_to_dtype_if_needed(sharded_param_data, self.param_dtype)] + elif self.sharded_state == ShardedState.SHARDED_POST_FORWARD: + if not compiled_autograd_enabled() and hasattr( + self._sharded_local_tensor, "fsdp_pre_all_gather" + ): + raise NotImplementedError + all_gather_input = _to_dtype_if_needed( + cast(core.Tensor, self._sharded_post_forward_param_data), + self.param_dtype, + ) + return [all_gather_input] + return [core.empty(0)] # mypy + + @property + def unsharded_param(self) -> nn.Parameter: # ND + return self._unsharded_param + + @property + def unsharded_grad_data(self) -> core.Tensor: + grad = self.unsharded_param.grad + assert grad is not None, "Expects unsharded_param.grad to not be None" + return self._get_grad_inner_tensor(grad) + + @property + def unsharded_accumulated_grad_data(self) -> core.Tensor: + grad = self.unsharded_accumulated_grad + assert grad is not None, "Expects unsharded_accumulated_grad to not be None" + return self._get_grad_inner_tensor(grad) + + def _get_grad_inner_tensor(self, grad: core.Tensor) -> core.Tensor: + if self.is_dtensor: + if isinstance(grad, AsyncCollectiveTensor): + grad = grad.wait() + assert isinstance(grad, DTensor), f"{type(grad)}" + placements = self._tp_spec.placements + if placements != grad.placements: + assert len(self._tp_spec.placements) == len( + grad.placements + ), f"{self._tp_spec=} {grad.placements=}" + grad = grad.redistribute(placements=placements) + grad = grad._local_tensor + return grad + + @property + def _sharded_local_tensor(self) -> core.Tensor: + return cast(DTensor, self.sharded_param)._local_tensor + + @property + def shard_mesh(self): + mesh = self.mesh_info.mesh + if mesh.ndim == 1: + return mesh + elif mesh.ndim == 2: + assert mesh.mesh_dim_names is not None + return mesh[mesh.mesh_dim_names[-1]] + raise ValueError(f"Invalid mesh: {mesh}") + + def _assert_in_states(self, *states: ShardedState) -> None: + if self.sharded_state not in states: + _raise_assert_with_print( + f"Expects to be in one of {states}, not {self.sharded_state}" + ) + + def reset_sharded_param(self): + # For ops like `nn.Module._apply` or `load_state_dict(assign=True)` + # that change the sharded parameter tensor, we may need to re-pad the + # sharded local tensor and re-save the reference. + module_info = self._module_info + new_param = getattr(module_info.module, module_info.param_name) + if new_param is not self.sharded_param: + if core.__future__.get_swap_module_params_on_conversion(): + raise AssertionError( + f"Expects swap_tensors to preserve object but got {new_param} " + f"instead of {self.sharded_param}" + ) + self.sharded_param = new_param + local_tensor = new_param._local_tensor + if local_tensor.is_meta: + return + updated_local_tensor = False + padded_sharded_size = self.padded_sharded_param_size + shard_dim = self.fsdp_placement.dim + length = local_tensor.size(shard_dim) if local_tensor.numel() > 0 else 0 + if local_tensor.size() != padded_sharded_size: + assert ( + shard_dim == 0 + ), f"Shard({shard_dim}) requires even sharding: {local_tensor.size()=}" + padded_local_tensor = local_tensor.new_zeros(padded_sharded_size) + padded_local_tensor.narrow(dim=shard_dim, start=0, length=length).copy_( + local_tensor + ) + local_tensor = padded_local_tensor + updated_local_tensor = True + if self.pin_memory and not local_tensor.is_pinned(): + local_tensor = local_tensor.cpu().pin_memory(device=self.device) + updated_local_tensor = True + self._sharded_param_data = local_tensor.view(-1) + assert isinstance(self.sharded_param, DTensor) # mypy + if updated_local_tensor: + # Only change the local tensor object if needed + self.sharded_param._local_tensor = local_tensor.narrow( + dim=shard_dim, start=0, length=length + ) + assert self.sharded_param._local_tensor.is_contiguous() + self._sharding_spec = self.sharded_param._spec + + def __repr__(self): + return f"FSDPParam(fqn={self._param_fqn}, orig_size={self._orig_size})" + + +def alloc_storage(tensor: core.Tensor) -> None: + size = tensor.numel() * tensor.itemsize + if (storage := tensor.untyped_storage()).size() != size: + storage.resize_(size) + + +def free_storage(tensor: core.Tensor) -> None: + if (storage := tensor.untyped_storage()).size() != 0: + storage.resize_(0) + + +# NOTE: These bypass `nn.Module.__setattr__` checks, which incur non-trivial +# CPU overhead, if the module did not override it. For FSDP, we know we do not +# need those checks when transitioning between sharded/unsharded parameters. +def unsafe_setattr_param( + module: nn.Module, param_name: str, param: nn.Parameter +) -> None: + if getattr(module.__setattr__, "__func__", None) is nn.Module.__setattr__: + module._parameters[param_name] = param + else: # slow path + setattr(module, param_name, param) + + +def set_requires_grad_if_needed( + src_tensor: core.Tensor, dst_tensor: core.Tensor +) -> None: + # Only call `requires_grad_` if needed to avoid the Python <> C++ context + # switch overhead + if src_tensor.requires_grad != dst_tensor.requires_grad: + dst_tensor.requires_grad_(src_tensor.requires_grad) diff --git a/mindnlp/core/distributed/_composable/fsdp/_fsdp_param_group.py b/mindnlp/core/distributed/_composable/fsdp/_fsdp_param_group.py new file mode 100644 index 000000000..b8ed558a1 --- /dev/null +++ b/mindnlp/core/distributed/_composable/fsdp/_fsdp_param_group.py @@ -0,0 +1,725 @@ +# mypy: allow-untyped-defs +import contextlib +import logging +from typing import Any, Callable, cast, Dict, List, NamedTuple, Optional, Set, Tuple + +from mindnlp import core +from mindnlp import core.distributed as dist +from mindnlp import core.nn as nn +from core.distributed.device_mesh import _get_device_handle +from core.distributed.fsdp._common_utils import _named_parameters_with_duplicates +from core.distributed.tensor import Shard +from core.profiler import record_function +from core.utils._pytree import tree_flatten, tree_unflatten +from core.utils.hooks import RemovableHandle + +from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy +from ._fsdp_collectives import ( + AllGatherResult, + foreach_all_gather, + foreach_all_gather_copy_out, + foreach_reduce, +) +from ._fsdp_common import ( + compiled_autograd_enabled, + FSDPMeshInfo, + HSDPMeshInfo, + TrainingState, +) +from ._fsdp_param import FSDPParam, ParamModuleInfo, ShardedState + + +logger = logging.getLogger("core.distributed._composable.fsdp") + +_ModuleToHandleDict = Dict[nn.Module, RemovableHandle] # for state dict + + +""" +[Note: Overlapping all-gather copy-in and all-gather] +For implicit forward prefetching, we want to overlap the next copy-in with the +current all-gather. We do so using a separate copy-in stream. However, since +we have the all-gather input as a view into the output, we must make sure to +copy into different memory from the current all-gather's output. Thus, we keep +a reference to the current all-gather's output and have the next FSDP parameter +group free it after its copy-in. Finally, we have the last FSDP state flush the +reference to avoid holding onto memory after forward. +""" + + +class FSDPCommContext: + """This has the communication state shared across FSDP states/parameter groups.""" + + def lazy_init(self, device: core.device): + self.device_handle = _get_device_handle(device.type) + # Setting the all-gather/reduce-scatter streams to be higher priority + # can help avoid some issues where their copies in/out are delayed and + # block computation (this is different from high-pri NCCL streams) + high_priority = -1 + # All-gather state and copy-in stream allow overlapping the next + # copy-in with the current all-gather in forward; copy-in overlaps with + # reduce-scatter in backward without the separate copy-in stream + self.all_gather_copy_in_stream = self.device_handle.Stream( + priority=high_priority + ) + # All-gather stream allows overlapping next all-gather with current + # forward compute + self.all_gather_stream = self.device_handle.Stream(priority=high_priority) + # Reduce-scatter stream gives separate execution "thread" for post- + # backward logic like pre/post-gradient division and reduce-scatter + self.reduce_scatter_stream = self.device_handle.Stream(priority=high_priority) + # Run the HSDP all-reduces concurrently with all-gather/reduce-scatter + # since collectives use different network resources and can overlap + # in the typical intra-node sharding / inter-node replication case + self.all_reduce_stream = self.device_handle.Stream() + # All-gather/reduce-scatter states keep references to collective + # tensors produced in one stream and used in another and accompanying + # CUDA events for synchronization + self.all_gather_state: Optional[AllGatherState] = None + self.reduce_scatter_state: Optional[ReduceScatterState] = None + # Post-forward order for explicit backward prefetching + self.post_forward_order: List[FSDPParamGroup] = [] # will cause ref cycles + + def get_all_gather_streams( + self, async_op: bool, training_state: TrainingState + ) -> Tuple[core.Stream, core.Stream]: + if not async_op and training_state in ( + TrainingState.FORWARD, + TrainingState.PRE_BACKWARD, + ): + # Use separate streams for implicit prefetching + return self.all_gather_copy_in_stream, self.all_gather_stream + current_stream = self.device_handle.current_stream() + return current_stream, current_stream + + +# See [Note: Overlapping all-gather copy-in and all-gather] +class AllGatherState(NamedTuple): + all_gather_result: AllGatherResult + event: core.Event # all-gather copy-out + + +class ReduceScatterState(NamedTuple): + reduce_scatter_input: core.Tensor + event: core.Event # reduce-scatter event + + +class AllReduceState(NamedTuple): + all_reduce_input: core.Tensor + event: core.Event # all-reduce event + + +class FSDPParamGroup: + """This class represents a parameter group to communicate together.""" + + _orig_dtype: core.dtype + _reduce_dtype: Optional[core.dtype] + + def __init__( + self, + params: List[nn.Parameter], + modules: Tuple[nn.Module, ...], + mesh_info: FSDPMeshInfo, + post_forward_mesh_info: Optional[FSDPMeshInfo], + device: core.device, + shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]], + mp_policy: MixedPrecisionPolicy, + offload_policy: OffloadPolicy, + ): + self.modules = modules # permit ref cycle because 1:1 lifetime + param_module_infos = _get_param_module_infos(params, modules) + + self.fsdp_params = [ + FSDPParam( + param, + module_info, + mesh_info, + post_forward_mesh_info, + device, + shard_placement_fn, + mp_policy, + offload_policy, + ) + for param, module_info in zip(params, param_module_infos) + ] + self.mesh_info = mesh_info + self.post_forward_mesh_info = post_forward_mesh_info + self.device = device + self.device_handle = _get_device_handle(device.type) + self.mp_policy = mp_policy + self.offload_policy = offload_policy + self._training_state = TrainingState.IDLE + # Group's sharded state always matches its parameters' sharded states + self._sharded_state = ShardedState.SHARDED + self._module_fqn: Optional[str] = None # prefixed from root module + # Only consider resetting sharded parameters once in lazy init since it + # can incur nontrivial overhead to reset them + self._reset_sharded_params: bool = False + + # - Hook state + self._module_to_pre_save_state_dict_hook_handle: _ModuleToHandleDict = {} + self._module_to_pre_load_state_dict_hook_handle: _ModuleToHandleDict = {} + + # - Communication and communication/computation overlap + self.comm_ctx = FSDPCommContext() + # Group's indices in the shared post-forward order + self._post_forward_indices: List[int] = [] + # Whether to reduce gradients at all (whether for FSDP or HSDP) + self.reduce_grads: bool = True + # Whether to all-reduce gradients for HSDP; only used if + # `self.reduce_grads` is true, in which case setting this to false + # means reduce-scatter but no all-reduce + self.all_reduce_grads: bool = True + # Whether to reshard parameters after backward (only useful for + # gradient accumulation) + self.reshard_after_backward: bool = True + # Optional custom reduce-scatter reduce op (e.g. to divide by a + # factor other than the shard world size) + self.reduce_scatter_reduce_op: Optional[dist.ReduceOp] = None + # `async_op` arg used for pre-forward/pre-backward unshard; can be + # overridden to only do explicit prefetching and avoid inter-stream + # fragmentation from using separate unshard streams + self.unshard_async_op: bool = False + # Whether to unshard in backward: can be overridden by the user if the + # parameters in this group are not needed for backward (e.g. embedding) + self.unshard_in_backward: bool = True + + # - CUDA events for stream synchronization + # Holds the all-gather output buffer, sync objects, and metadata + self._all_gather_result: Optional[AllGatherResult] = None + # Holds the reduce-scatter/all-reduce view-out CUDA event that marks the end of + # the group's post-backward (e.g. reduce-scatter, all-reduce and div), which + # should be waited on at the end of backward + self._post_reduce_event: Optional[core.Event] = None + # Holds the reshard-after-forward CUDA event when resharding to a + # different world size, which should be waited on in the next unshard + self._reshard_after_forward_event: Optional[core.Event] = None + + # Only for HSDP, if accumulating gradients without all-reduce, save the + # partial reduce output (only reduce-scattered but not all-reduced) + self._partial_reduce_output: Optional[core.Tensor] = None + # Holds the all-reduce input and all-reduce event to keep it alive + # until the end of backward (critical when doing bf16 reduction with + # fp32 parameters since the all-reduce input is allocated in the RS + # stream and will have no refs to it after being upcast to fp32) + self._all_reduce_state: Optional[AllReduceState] = None + + # Initialization # + def _init_mp_dtypes(self) -> None: + for fsdp_param in self.fsdp_params: + fsdp_param.init_dtype_attrs(self.mp_policy) + orig_dtypes = {fsdp_param.orig_dtype for fsdp_param in self.fsdp_params} + if len(orig_dtypes) != 1: + # This can be relaxed if we copy-out for the reduce-scatter + raise AssertionError( + f"FSDP expects uniform original parameter dtype but got {orig_dtypes}" + ) + self._orig_dtype = next(iter(orig_dtypes)) + reduce_dtypes = {fsdp_param.reduce_dtype for fsdp_param in self.fsdp_params} + if len(reduce_dtypes) != 1: + # This can be relaxed if we issue one reduce-scatter per reduce + # dtype (but we would need a way for users to specify multiple + # reduce dtypes) + raise AssertionError( + f"FSDP expects uniform reduce dtype but got {reduce_dtypes}" + ) + self._reduce_dtype = next(iter(reduce_dtypes)) + + def lazy_init(self): + # Lazy init should be idempotent + # Users may change or register parameters after construction time. + # For example, DoRA (https://arxiv.org/abs/2402.09353) initializes linear magnitudes based on + # other parameters (e.g. loaded from the state dict). + if not hasattr(self.comm_ctx, "device_handle"): + self.comm_ctx.device_handle = _get_device_handle(self.device.type) + if self.is_sharded and not self._reset_sharded_params: + for fsdp_param in self.fsdp_params: + fsdp_param.reset_sharded_param() + fsdp_param._init_extensions() # allow monkey patch after init + self._reset_sharded_params = True + self._validate_no_meta_params() + self._validate_cpu_offload_params() + # Initialize mixed precision attributes lazily in case the user changes + # the parameter dtypes after construction time but before forward + self._init_mp_dtypes() + self._register_state_dict_hooks() + + # Runtime # + def unshard(self, async_op: bool = False): + if self._all_gather_result is not None: # already called, pending wait + return + if self.is_unsharded: + return # no-op + if ( + not self.unshard_in_backward + and self._training_state == TrainingState.PRE_BACKWARD + ): + return + if self._reshard_after_forward_event is not None: + # Resharded parameter data is allocated in the default stream and + # used in the all-gather streams + self._wait_all_gather_streams_on_event(self._reshard_after_forward_event) + self._reshard_after_forward_event = None + with record_function(self._with_fqn("FSDP::all_gather")): + self._all_gather_result = foreach_all_gather( + self.fsdp_params, + self._all_gather_process_group, + async_op, + *self.comm_ctx.get_all_gather_streams(async_op, self._training_state), + self.device, + ) + + def wait_for_unshard(self): + """ + 1. In forward with implict prefetching, to overlap the current copy-out + with the next all-gather, we save a reference to the current all-gather + result to free after the next copy-out. + 2. Otherwise (explicit prefetching or in backward), we free the + all-gather result immediately after the current copy-out since we can + already overlap the current copy-out with the previous reduce-scatter. + """ + if not self._all_gather_result: + return # no preceding unshard + async_op = self._all_gather_result.all_gather_work is not None + if self._training_state == TrainingState.FORWARD: # implicit prefetch + if prev_all_gather_state := self.comm_ctx.all_gather_state: + self._wait_all_gather_streams_on_event(prev_all_gather_state.event) + self.comm_ctx.all_gather_state = None # free the all-gather result + with record_function(self._with_fqn("FSDP::all_gather_copy_out")): + foreach_all_gather_copy_out( + self._all_gather_result, + self.fsdp_params, + self._all_gather_process_group, + ) + for fsdp_param in self.fsdp_params: + fsdp_param.init_unsharded_param() + self._to_unsharded() + all_gather_copy_out_event = self.device_handle.Event() + all_gather_copy_out_event.record() + if not async_op and self._training_state == TrainingState.FORWARD: + # Defer free to allow for overlap of this copy-out with next + # all-gather collective + self.comm_ctx.all_gather_state = AllGatherState( + self._all_gather_result, all_gather_copy_out_event + ) + else: + self._wait_all_gather_streams_on_event(all_gather_copy_out_event) + self._all_gather_result = None # free unless saved in `all_gather_state` + + def _wait_all_gather_streams_on_event(self, event: core.Event): + # Calling `unshard` before lazy init means streams are not initialized + if hasattr(self.comm_ctx, "all_gather_copy_in_stream"): + self.comm_ctx.all_gather_copy_in_stream.wait_event(event) + if hasattr(self.comm_ctx, "all_gather_stream"): + self.comm_ctx.all_gather_stream.wait_event(event) + + def reshard(self): + if self._training_state == TrainingState.FORWARD: + if not self._reshard_after_forward: + return + if self._use_post_forward_mesh: + self._to_sharded_post_forward() + self._reshard_after_forward_event = self.device_handle.Event() + if self._reshard_after_forward_event is not None: + self._reshard_after_forward_event.record() + return + self._to_sharded() + + def pre_forward( + self, module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any] + ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: + if not compiled_autograd_enabled(): + logger.debug("%s", self._with_fqn("FSDP::pre_forward")) + with record_function(self._with_fqn("FSDP::pre_forward")): + self._training_state = TrainingState.FORWARD + self.unshard(self.unshard_async_op) + self.wait_for_unshard() + args, kwargs = self._register_post_backward_hook(args, kwargs) + return args, kwargs + + def post_forward(self, module: nn.Module, input: Any, output: Any): + if not compiled_autograd_enabled(): + logger.debug("%s", self._with_fqn("FSDP::post_forward")) + with record_function(self._with_fqn("FSDP::post_forward")): + self.reshard() + self._record_post_forward() + self._training_state = TrainingState.IDLE + return output + + def _record_post_forward(self) -> None: + # Since a group has one pre-backward unshard for each forward call + # before the backward, we record each usage (with multiplicity) + post_forward_index = len(self.comm_ctx.post_forward_order) + self.comm_ctx.post_forward_order.append(self) + self._post_forward_indices.append(post_forward_index) + + def pre_backward(self, default_prefetch: bool, *unused: Any): + if ( + compiled_autograd_enabled() + and self._training_state == TrainingState.PRE_BACKWARD + ): + # Traceable FSDP2 cannot trigger the param group's `post_backward` immediately after param usage; + # instead it relies on this to trigger the previously unexecuted `post_backward`. + self.post_backward() + if self._training_state == TrainingState.PRE_BACKWARD: + return + if not compiled_autograd_enabled(): + logger.debug("%s", self._with_fqn("FSDP::pre_backward")) + with record_function(self._with_fqn("FSDP::pre_backward")): + self._training_state = TrainingState.PRE_BACKWARD + self.unshard(self.unshard_async_op) # no-op if prefetched + self.wait_for_unshard() + if default_prefetch and not compiled_autograd_enabled(): + self._backward_prefetch() + + def post_backward(self, *unused: Any): + if not compiled_autograd_enabled(): + logger.debug("%s", self._with_fqn("FSDP::post_backward")) + self._training_state = TrainingState.POST_BACKWARD + with record_function(self._with_fqn("FSDP::post_backward_accumulate")): + for fsdp_param in self.fsdp_params: + fsdp_param.accumulate_unsharded_grad_if_needed() + with record_function(self._with_fqn("FSDP::post_backward_reshard")): + if not self.reduce_grads: + if self.reshard_after_backward: + self.reshard() + for fsdp_param in self.fsdp_params: + fsdp_param.to_accumulated_grad_if_needed() + return + # Save the autograd-computed gradients before resharding to only + # access the unsharded parameters when their data is present + fsdp_params_with_grad: List[FSDPParam] = [] + unsharded_grads: List[core.Tensor] = [] + for fsdp_param in self.fsdp_params: + # May have an accumulated gradient of the reduce dtype if the + # previous backward did not reduce-scatter + if fsdp_param.unsharded_accumulated_grad is not None: + fsdp_params_with_grad.append(fsdp_param) + unsharded_grads.append(fsdp_param.unsharded_accumulated_grad_data) + fsdp_param.unsharded_accumulated_grad = None + elif fsdp_param.unsharded_param.grad is not None: + fsdp_params_with_grad.append(fsdp_param) + unsharded_grads.append(fsdp_param.unsharded_grad_data) + fsdp_param.unsharded_param.grad = None + if self.reshard_after_backward: + self.reshard() + if len(fsdp_params_with_grad) == 0: + return + with record_function(self._with_fqn("FSDP::post_backward_reduce")): + if self.comm_ctx.reduce_scatter_state is not None: + self.device_handle.current_stream().wait_event( + self.comm_ctx.reduce_scatter_state.event + ) + self.comm_ctx.reduce_scatter_state = None + self._wait_for_post_backward() + ( + reduce_scatter_input, + reduce_scatter_event, + self._post_reduce_event, + all_reduce_input, + all_reduce_event, + self._partial_reduce_output, + ) = foreach_reduce( + fsdp_params_with_grad, + unsharded_grads, + self._reduce_scatter_process_group, + self.comm_ctx.reduce_scatter_stream, + self._orig_dtype, + self._reduce_dtype, + self.device, + self.reduce_scatter_reduce_op, + self._all_reduce_process_group if self._is_hsdp else None, + self.comm_ctx.all_reduce_stream, + self.all_reduce_grads, + self._partial_reduce_output, + ) + self.comm_ctx.reduce_scatter_state = ReduceScatterState( + reduce_scatter_input, reduce_scatter_event + ) + if all_reduce_input is not None: + assert all_reduce_event is not None + self._all_reduce_state = AllReduceState( + all_reduce_input, all_reduce_event + ) + + def finalize_backward(self): + self._wait_for_post_backward() + for fsdp_param in self.fsdp_params: + if fsdp_param.grad_offload_event is not None: + fsdp_param.grad_offload_event.synchronize() + fsdp_param.grad_offload_event = None + if self._all_gather_result is not None: + # If there was a mistargeted unshard without a corresponding wait, + # then we wait here and clear the unshard + if (event := self._all_gather_result.all_gather_event) is not None: + core.cuda.current_stream().wait_event(event) + work = self._all_gather_result.all_gather_work + if isinstance(work, dist.distributed_c10d.Work): + work.wait() + self._all_gather_result = None + self._post_forward_indices.clear() + + def _wait_for_post_backward(self): + if self._post_reduce_event is not None: + self.device_handle.current_stream().wait_event(self._post_reduce_event) + self._post_reduce_event = None + if self._all_reduce_state is not None: + self.device_handle.current_stream().wait_event(self._all_reduce_state.event) + self._all_reduce_state = None + + def _backward_prefetch(self) -> None: + if self._training_state == TrainingState.PRE_BACKWARD: + if not self._post_forward_indices: + # Can be cleared if running multiple `backward`s + return + curr_index = self._post_forward_indices.pop() + if (target_index := curr_index - 1) < 0: + return + # Prefetch naively using the reverse post-forward order, which may + # have mistargeted prefetches if not all modules used in forward + # are used in this backward + target_fsdp_param_group = self.comm_ctx.post_forward_order[target_index] + self._prefetch_unshard(target_fsdp_param_group, "backward") + + @staticmethod + def _prefetch_unshard( + target_fsdp_param_group: "FSDPParamGroup", pass_type: str + ) -> None: + if pass_type == "backward": + training_state = TrainingState.PRE_BACKWARD + elif pass_type == "forward": + training_state = TrainingState.FORWARD + else: + raise ValueError(f"Unknown pass type: {pass_type}") + target_fqn = target_fsdp_param_group._module_fqn + with record_function( + f"FSDP::{pass_type}_prefetch for {target_fqn}" + ), target_fsdp_param_group.use_training_state(training_state): + async_op = target_fsdp_param_group.unshard_async_op + target_fsdp_param_group.unshard(async_op) + + # Utilities # + def _to_sharded(self): + if not self.is_sharded: + for fsdp_param in self.fsdp_params: + fsdp_param.to_sharded() + self._sharded_state = ShardedState.SHARDED + + def _to_sharded_post_forward(self): + if not self.is_sharded_post_forward: + for fsdp_param in self.fsdp_params: + fsdp_param.to_sharded_post_forward() + self._sharded_state = ShardedState.SHARDED_POST_FORWARD + + def _to_unsharded(self): + if not self.is_unsharded: + for fsdp_param in self.fsdp_params: + fsdp_param.to_unsharded() + self._sharded_state = ShardedState.UNSHARDED + + @property + def is_sharded(self) -> bool: + return self._sharded_state == ShardedState.SHARDED + + @property + def is_sharded_post_forward(self) -> bool: + return self._sharded_state == ShardedState.SHARDED_POST_FORWARD + + @property + def is_unsharded(self) -> bool: + return self._sharded_state == ShardedState.UNSHARDED + + @contextlib.contextmanager + def use_training_state(self, training_state: TrainingState): + old_training_state = self._training_state + self._training_state = training_state + try: + yield + finally: + self._training_state = old_training_state + + # Hook Registration # + def _register_post_backward_hook( + self, args: Tuple[Any, ...], kwargs: Dict[str, Any] + ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: + # Traceable FSDP2 relies on `root_post_backward_callback` to call each + # `FSDPParamGroup.post_backward` + if (not core._dynamo.config.skip_fsdp_hooks) or compiled_autograd_enabled(): + return args, kwargs + if not core.is_grad_enabled(): + return args, kwargs + args_list, args_spec = tree_flatten(args) + kwargs_list, kwargs_spec = tree_flatten(kwargs) + args_kwargs_list = list(args_list) + list(kwargs_list) + inp_tensor_indices: List[int] = [] + inp_tensors: List[core.Tensor] = [] + for i, obj in enumerate(args_kwargs_list): + if core.is_tensor(obj) and obj.requires_grad: + inp_tensor_indices.append(i) + inp_tensors.append(obj) + if len(inp_tensors) == 0: + return args, kwargs # no tensors that require gradients + inp_tensors = RegisterPostBackwardFunction.apply(self, *inp_tensors) + for inp_tensor_idx, inp_tensor in zip(inp_tensor_indices, inp_tensors): + args_kwargs_list[inp_tensor_idx] = inp_tensor + args_list = args_kwargs_list[: len(args_list)] + kwargs_list = args_kwargs_list[len(args_list) :] + args = tree_unflatten(args_list, args_spec) + kwargs = tree_unflatten(kwargs_list, kwargs_spec) + return args, kwargs + + def _register_state_dict_hooks(self) -> None: + num_pre_save_hooks = len(self._module_to_pre_save_state_dict_hook_handle) + num_pre_load_hooks = len(self._module_to_pre_load_state_dict_hook_handle) + assert ( + num_pre_save_hooks == num_pre_load_hooks + ), f"Pre-save: {num_pre_save_hooks} pre-load: {num_pre_load_hooks}" + if num_pre_save_hooks > 0: + return # already registered + modules_with_fsdp_params: Set[nn.Module] = { + fsdp_param._module_info.module for fsdp_param in self.fsdp_params + } + + def to_sharded_hook(*args: Any, **kwargs: Any) -> None: + self._to_sharded() + + for module in modules_with_fsdp_params: + self._module_to_pre_save_state_dict_hook_handle[ + module + ] = module.register_state_dict_pre_hook(to_sharded_hook) + self._module_to_pre_load_state_dict_hook_handle[ + module + ] = module._register_load_state_dict_pre_hook(to_sharded_hook) + + # Properties # + @property + def _reshard_after_forward(self) -> bool: + return self.post_forward_mesh_info is not None + + @property + def _use_post_forward_mesh(self) -> bool: + return ( + self._reshard_after_forward + and self.mesh_info != self.post_forward_mesh_info + ) + + @property + def _is_hsdp(self) -> bool: + return isinstance(self.mesh_info, HSDPMeshInfo) + + @property + def _all_gather_process_group(self) -> dist.ProcessGroup: + mesh_info = ( + cast(FSDPMeshInfo, self.post_forward_mesh_info) + if self.is_sharded_post_forward + else self.mesh_info + ) + assert isinstance(mesh_info, FSDPMeshInfo) + return mesh_info.shard_process_group + + @property + def _reduce_scatter_process_group(self) -> dist.ProcessGroup: + assert isinstance(self.mesh_info, FSDPMeshInfo) + return self.mesh_info.shard_process_group + + @property + def _all_reduce_process_group(self) -> dist.ProcessGroup: + assert isinstance(self.mesh_info, HSDPMeshInfo) + return self.mesh_info.replicate_process_group + + def _with_fqn(self, label: str) -> str: + if self._module_fqn: + return f"{label} ({self._module_fqn})" + return label + + def __repr__(self): + return f"FSDPParamGroup(fqn={self._module_fqn})" + + def _validate_no_meta_params(self): + param_names_on_meta = [ + fsdp_param._param_fqn + for fsdp_param in self.fsdp_params + if fsdp_param.sharded_param.device.type == "meta" + ] + if param_names_on_meta: + raise RuntimeError( + "FSDP parameters should be materialized from meta device before training, " + f"but the following were still on meta device: {param_names_on_meta}\n" + "For example, call module.to_empty(device) to materialize to device and " + "call module.reset_parameters() on each module to initialize values." + ) + + def _validate_cpu_offload_params(self): + if not isinstance(self.offload_policy, CPUOffloadPolicy): + return + fsdp_params_not_on_cpu = [ + fsdp_param + for fsdp_param in self.fsdp_params + if fsdp_param.sharded_param.device.type != "cpu" + ] + if fsdp_params_not_on_cpu: + raise RuntimeError( + "FSDP parameters should be materialized on CPU when enabling CPU offloading. " + 'For example, load a CPU state dict or call module.to_empty(device="cpu"). ' + "Found following parameters on non-CPU device: " + f"{[(fsdp_param._param_fqn, fsdp_param.sharded_param.device) for fsdp_param in fsdp_params_not_on_cpu]}\n" + ) + + +def _get_param_module_infos( + params: List[nn.Parameter], modules: Tuple[nn.Module, ...] +) -> List[ParamModuleInfo]: + """ + Shared parameter: lin1.weight = lin2.weight + Shared module: mlp.lin1 = mlp.lin2 + We do not remove duplicates when traversing both modules and parameters to + find shared modules' parameters and shared parameters within a module. + """ + params_set = set(params) + param_to_module_info: Dict[nn.Parameter, ParamModuleInfo] = {} + for module in modules: + for _, submodule in module.named_modules(remove_duplicate=False): + for param_name, param in _named_parameters_with_duplicates( + submodule, recurse=False + ): + if param in params_set: + if param not in param_to_module_info: + param_to_module_info[param] = ParamModuleInfo( + submodule, param_name + ) + else: + param_to_module_info[param].shared_modules.append(submodule) + param_to_module_info[param].shared_param_names.append( + param_name + ) + if len(param_to_module_info) != len(params): + raise AssertionError(f"Some parameters are not in the module tree of {module}") + return [param_to_module_info[param] for param in params] + + +class RegisterPostBackwardFunction(core.autograd.Function): + @staticmethod + def _assert_not_tracing_fsdp(): + if compiled_autograd_enabled(): + # TODO: Find a way to print the offending FSDP2 module. + msg = """\ +When Traceable FSDP2 is enabled, we should not be calling into `RegisterPostBackwardFunction`. +Instead, we rely on the param group's next `pre_backward` hook to trigger its previously unexecuted +`post_backward`, and we rely on FSDPState's `root_post_backward_callback` to trigger the resharding +of any leftover unsharded param groups. +If you are here, it means the forward part of this FSDP2 instance is not compiled, and you must also +compile the forward part if you want to use Traceable FSDP2.""" + core._dynamo.comptime.comptime.print(msg) + raise RuntimeError(msg) + + @staticmethod + def forward(ctx, param_group: FSDPParamGroup, *inputs: core.Tensor): + # All tensors in `inputs` should require gradient + RegisterPostBackwardFunction._assert_not_tracing_fsdp() + ctx.param_group = param_group + return inputs + + @staticmethod + def backward(ctx, *grads: core.Tensor): + RegisterPostBackwardFunction._assert_not_tracing_fsdp() + ctx.param_group.post_backward() + return (None,) + grads diff --git a/mindnlp/core/distributed/_composable/fsdp/_fsdp_state.py b/mindnlp/core/distributed/_composable/fsdp/_fsdp_state.py new file mode 100644 index 000000000..4c2bf9109 --- /dev/null +++ b/mindnlp/core/distributed/_composable/fsdp/_fsdp_state.py @@ -0,0 +1,394 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import functools +import logging +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Set, + Tuple, + TYPE_CHECKING, +) + +from mindnlp import core +from mindnlp import core.nn as nn +from core._logging import warning_once +from core.autograd import Variable +from core.autograd.graph import _MultiHandle +from core.distributed._composable_state import ( + _get_module_state, + _insert_module_state, + _State, +) +from core.distributed.device_mesh import _get_device_handle +from core.distributed.utils import _to_kwargs +from core.utils._pytree import tree_flatten, tree_map + +from ._fsdp_api import MixedPrecisionPolicy +from ._fsdp_common import ( + _cast_fp_tensor, + compiled_autograd_enabled, + detect_compiled_autograd, + TrainingState, +) +from ._fsdp_param_group import FSDPCommContext, FSDPParamGroup + + +if TYPE_CHECKING: + from ._fsdp_param import FSDPParam + + +logger = logging.getLogger("core.distributed._composable.fsdp") + + +class FSDPStateContext: + """This has state shared across FSDP states.""" + + def __init__(self) -> None: + # All FSDP states in the root state's module tree + self.all_states: List[FSDPState] = [] + # Iteration's forward root runs the once-per-forward logic; this root + # may not be the overall root set by lazy initialization in cases where + # only a submodule runs forward (e.g. encoder-only for eval) + self.iter_forward_root: Optional[FSDPState] = None + # Final callback should only be queued once per backward + self.post_backward_final_callback_queued: bool = False + # Whether to finalize backward in this backward's final callback + self.is_last_backward: bool = True + # Optional user-provided event recorded after optimizer for the + # all-gather streams to wait on in the root pre-forward + self.post_optim_event: Optional[core.Event] = None + + +def disable_if_config_true(func): + @functools.wraps(func) + def fsdp_hook_wrapper(*args, **kwargs): + if core._dynamo.config.skip_fsdp_hooks: + return core._dynamo.disable(func, recursive=True)(*args, **kwargs) + else: + return func(*args, **kwargs) + + return fsdp_hook_wrapper + + +class FSDPState(_State): + def __init__(self) -> None: + super().__init__() + self._fsdp_param_group: Optional[FSDPParamGroup] = None + self._is_root: Optional[bool] = None # root set during lazy init + self._state_ctx = FSDPStateContext() + self._comm_ctx = FSDPCommContext() + self._training_state: TrainingState = TrainingState.IDLE + self._states_to_forward_prefetch: List[FSDPState] = [] + self._states_to_backward_prefetch: List[FSDPState] = [] + self._modules_to_run_forward: Set[nn.Module] = set() + + # Define a separate init since `__init__` is called in the contract + def init( + self, + modules: Tuple[nn.Module, ...], + device: core.device, + mp_policy: MixedPrecisionPolicy, + ) -> None: + for module in modules: + _insert_module_state(module, self) + self._modules = modules + self._device = device + self._device_handle = _get_device_handle(device.type) + self._mp_policy = mp_policy + if len(modules) == 1: + self._pre_forward_hook_handle = modules[0].register_forward_pre_hook( + self._pre_forward, prepend=True, with_kwargs=True + ) + self._post_forward_hook_handle = modules[0].register_forward_hook( + self._post_forward, prepend=False + ) + else: + hook_handle = _register_group_forward_hooks( + modules, + self._pre_forward, + self._post_forward, + self._modules_to_run_forward, + ) + self._pre_forward_hook_handle = hook_handle + self._post_forward_hook_handle = hook_handle + + def _root_pre_forward( + self, module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any] + ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: + self._lazy_init() + if self._state_ctx.iter_forward_root is not None: + return args, kwargs + if not compiled_autograd_enabled(): + logger.debug("FSDP::root_pre_forward") + self._state_ctx.iter_forward_root = self + with core.profiler.record_function("FSDP::root_pre_forward"): + # Wait for optimizer before implicitly prefetched all-gathers + if (event := self._state_ctx.post_optim_event) is not None: + self._comm_ctx.all_gather_copy_in_stream.wait_event(event) + self._comm_ctx.all_gather_stream.wait_event(event) + self._state_ctx.post_optim_event = None + else: + current_stream = self._device_handle.current_stream() + self._comm_ctx.all_gather_copy_in_stream.wait_stream(current_stream) + self._comm_ctx.all_gather_stream.wait_stream(current_stream) + if self._device.type in ["cuda", "hpu"]: + with core.profiler.record_function("FSDP::inputs_to_device"): + args_tuple, kwargs_tuple = _to_kwargs( + args, kwargs, self._device, False + ) # same as DDP + args, kwargs = args_tuple[0], kwargs_tuple[0] + return args, kwargs + + def _lazy_init(self) -> None: + """ + Lazy initialization represents when all modules' parallelisms have + finalized (e.g. FSDP has been applied to all desired modules). This + means that we can determine which state is the root, and we do so by + the 1st state to run forward. + """ + if self._is_root is not None: + return # no-op: already initialized + self._is_root = True + if len(self._modules) > 1: + raise RuntimeError( + f"FSDP requires a single root module but got {self._modules}" + ) + detect_compiled_autograd() + root_module = self._modules[0] + visited_states: Set[FSDPState] = set() + for module_name, module in root_module.named_modules(): + if (state := _get_module_fsdp_state(module)) is None: + continue + if module is not root_module: + if state not in visited_states and state._is_root is not None: + raise RuntimeError( + "FSDP state has already been lazily initialized for " + f"{module_name}\nFSDP requires running forward through " + "the root module first" + ) + state._is_root = False + self._state_ctx.all_states.append(state) + visited_states.add(state) + if self._fsdp_param_group: + # For the root, do not reshard after forward since for training, + # the parameters would be freed and all-gathered immediately + self._fsdp_param_group.post_forward_mesh_info = None + self._init_fqns() + self._init_shared_state() + # Run parameter group lazy inits after initializing FQNs for improved + # error messages + for state in self._state_ctx.all_states: + if state._fsdp_param_group: + state._fsdp_param_group.lazy_init() + + def _init_shared_state(self) -> None: + self._comm_ctx.lazy_init(self._device) + for state in self._state_ctx.all_states: + state._state_ctx = self._state_ctx + state._comm_ctx = self._comm_ctx + if fsdp_param_group := state._fsdp_param_group: + fsdp_param_group.comm_ctx = self._comm_ctx + + def _init_fqns(self) -> None: + """Sets module and parameter FQN attributes for debugging.""" + assert self._is_root + root_module = self._modules[0] + param_to_fsdp_param: Dict[nn.Parameter, FSDPParam] = {} + module_to_fsdp_param_group: Dict[nn.Module, FSDPParamGroup] = {} + for state in self._state_ctx.all_states: + if fsdp_param_group := state._fsdp_param_group: + for fsdp_param in fsdp_param_group.fsdp_params: + param_to_fsdp_param[fsdp_param.sharded_param] = fsdp_param + for module in fsdp_param_group.modules: + module_to_fsdp_param_group[module] = fsdp_param_group + for param_name, param in root_module.named_parameters(): + if param in param_to_fsdp_param: + param_to_fsdp_param[param]._param_fqn = param_name + for module_name, module in root_module.named_modules(): + if module in module_to_fsdp_param_group: + module_fqn = module_to_fsdp_param_group[module]._module_fqn + if module_fqn is None: + module_to_fsdp_param_group[module]._module_fqn = module_name + else: + assert isinstance(module_fqn, str), f"{module_fqn}" + module_fqn += f", {module_name}" + module_to_fsdp_param_group[module]._module_fqn = module_fqn + + @disable_if_config_true + def _pre_forward( + self, module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any] + ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: + # When composing with module-hook-based activation checkpointing, the + # the pre-backward hook is responsible for the unshard + if self._training_state == TrainingState.PRE_BACKWARD: + return args, kwargs + self._training_state = TrainingState.FORWARD + args, kwargs = self._root_pre_forward(module, args, kwargs) + if self._mp_policy.cast_forward_inputs and self._mp_policy.param_dtype: + with core.profiler.record_function("FSDP::cast_forward_inputs"): + cast_fn = functools.partial( + _cast_fp_tensor, self._mp_policy.param_dtype + ) + args, kwargs = tree_map(cast_fn, args), tree_map(cast_fn, kwargs) + if self._fsdp_param_group: + args, kwargs = self._fsdp_param_group.pre_forward(module, args, kwargs) + for fsdp_state in self._states_to_forward_prefetch: + if (target_param_group := fsdp_state._fsdp_param_group) is not None: + FSDPParamGroup._prefetch_unshard(target_param_group, "forward") + return args, kwargs + + @disable_if_config_true + def _post_forward(self, module: nn.Module, input: Any, output: Any) -> Any: + # When composing with module-hook-based activation checkpointing, the + # post-backward hook is responsible for the reshard + if self._training_state == TrainingState.PRE_BACKWARD: + return output + if self._fsdp_param_group: + output = self._fsdp_param_group.post_forward(module, input, output) + output = self._register_pre_backward_hook(output) + self._training_state = TrainingState.IDLE + if self._state_ctx.iter_forward_root is self: + if all_gather_state := self._comm_ctx.all_gather_state: + # Free the last all-gather result if needed; refer to + # [Note: Overlapping all-gather copy-in and all-gather] + self._comm_ctx.all_gather_copy_in_stream.wait_event( + all_gather_state.event + ) + self._comm_ctx.all_gather_stream.wait_event(all_gather_state.event) + self._comm_ctx.all_gather_state = None # free the all-gather result + self._state_ctx.iter_forward_root = None + if self._mp_policy.output_dtype is not None: + with core.profiler.record_function("FSDP::cast_forward_outputs"): + output = tree_map( + functools.partial(_cast_fp_tensor, self._mp_policy.output_dtype), + output, + ) + return output + + def _pre_backward(self, grad: core.Tensor) -> core.Tensor: + self._training_state = TrainingState.PRE_BACKWARD + self._register_root_post_backward_final_callback() + if self._fsdp_param_group: + default_prefetch = len(self._states_to_backward_prefetch) == 0 + self._fsdp_param_group.pre_backward(default_prefetch) + for fsdp_state in self._states_to_backward_prefetch: + if (target_param_group := fsdp_state._fsdp_param_group) is not None: + FSDPParamGroup._prefetch_unshard(target_param_group, "backward") + return grad + + def _root_post_backward_final_callback(self) -> None: + if not compiled_autograd_enabled(): + logger.debug("FSDP::root_post_backward") + with core.profiler.record_function("FSDP::root_post_backward_callback"): + for state in self._state_ctx.all_states: + fsdp_param_group = state._fsdp_param_group + if fsdp_param_group and ( + fsdp_param_group.is_unsharded + or not fsdp_param_group.unshard_in_backward + ): + # Run post-backward in case forward inputs did not require + # gradient so the autograd backward did not run + fsdp_param_group.post_backward() + state._training_state = TrainingState.IDLE + if fsdp_param_group: + fsdp_param_group._training_state = TrainingState.IDLE + if self._state_ctx.is_last_backward: + state._finalize_backward() + if self._state_ctx.is_last_backward: + self._comm_ctx.post_forward_order.clear() + if self._comm_ctx.reduce_scatter_state is not None: + self._device_handle.current_stream().wait_event( + self._comm_ctx.reduce_scatter_state.event + ) + self._comm_ctx.reduce_scatter_state = None + self._state_ctx.post_backward_final_callback_queued = False + + def _finalize_backward(self) -> None: + if self._modules_to_run_forward: + msg = ( + f"{len(self._modules_to_run_forward)} of the {len(self._modules)} " + f"modules passed to fully_shard did not run forward before backward, " + "which is error-prone since FSDP post-forward/pre-backward logic " + "will not run for these modules. We recommend passing only modules " + "that run forward together. Modules that did not run forward: " + f"{list(self._modules_to_run_forward)}" + ) + warning_once(logger, msg, stacklevel=2) + # Clear since we want the next forward to run + self._modules_to_run_forward.clear() + if self._fsdp_param_group: + self._fsdp_param_group.finalize_backward() + + def _register_pre_backward_hook(self, output: Any) -> Any: + if not core.is_grad_enabled(): + return output + flat_outputs, _ = tree_flatten(output) + for t in flat_outputs: + if core.is_tensor(t) and t.requires_grad: + t.register_hook(self._pre_backward) + return output + + def _register_root_post_backward_final_callback(self): + if self._state_ctx.post_backward_final_callback_queued: + return + self._state_ctx.post_backward_final_callback_queued = True + Variable._execution_engine.queue_callback( + self._root_post_backward_final_callback + ) + + +def _get_module_fsdp_state(module: nn.Module) -> Optional[FSDPState]: + state = _get_module_state(module) + if isinstance(state, FSDPState): + return state + return None + + +def _register_group_forward_hooks( + modules: Sequence[nn.Module], + pre_hook: Callable, + post_hook: Callable, + modules_to_run: Set[nn.Module], +): + """ + Registers group forward pre and post-hooks. The pre-hook runs upon the + first module pre-forward, and the post-hook runs upon the last. If at least + one module does not run forward, then the post-hook does not run. + """ + modules_set = set(modules) + + @disable_if_config_true + @functools.wraps(pre_hook) + def wrapped_pre_hook(*args: Any, **kwargs: Any): + if len(modules_to_run) == 0: # first to run + modules_to_run.update(modules_set) + return pre_hook(*args, **kwargs) + + @disable_if_config_true + def get_wrapped_post_hook(module: nn.Module): + @functools.wraps(post_hook) + def wrapped_post_hook(*args: Any, **kwargs: Any): + modules_to_run.discard(module) + if len(modules_to_run) == 0: + return post_hook(*args, **kwargs) + + return wrapped_post_hook + + pre_handles = [ + module.register_forward_pre_hook( + wrapped_pre_hook, prepend=True, with_kwargs=True + ) + for module in modules + ] + post_handles = [ + module.register_forward_hook( + get_wrapped_post_hook(module), prepend=False, always_call=True + ) + for module in modules + ] + return _MultiHandle(tuple(pre_handles + post_handles)) diff --git a/mindnlp/core/distributed/_composable/fsdp/fully_shard.py b/mindnlp/core/distributed/_composable/fsdp/fully_shard.py new file mode 100644 index 000000000..b896cf802 --- /dev/null +++ b/mindnlp/core/distributed/_composable/fsdp/fully_shard.py @@ -0,0 +1,501 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import functools +from typing import ( + Any, + Callable, + cast, + Dict, + Iterable, + List, + NoReturn, + Optional, + Type, + Union, +) + +from mindnlp import core +from mindnlp import core.nn as nn +from core.distributed._composable import contract +from core.distributed.tensor import DeviceMesh, Shard +from core.distributed.utils import _get_root_modules + +from ._fsdp_api import MixedPrecisionPolicy, OffloadPolicy +from ._fsdp_common import FSDPMeshInfo, HSDPMeshInfo +from ._fsdp_init import ( + _get_device_from_mesh, + _get_managed_modules, + _get_managed_states, + _get_post_forward_mesh_info, + _init_default_fully_shard_mesh, + _move_states_to_device, +) +from ._fsdp_param_group import FSDPParamGroup +from ._fsdp_state import _get_module_fsdp_state, FSDPState + + +cls_to_fsdp_cls: Dict[Type, Type] = {} + + +# The decorator adds a state object to `module` that can be accessed via +# `fully_shard.state(module)`. The state object and module are 1:1. +@contract(state_cls=FSDPState) # type: ignore[operator] +def fully_shard( + module: Union[nn.Module, List[nn.Module]], + *, + mesh: Optional[DeviceMesh] = None, + reshard_after_forward: Union[bool, int] = True, + shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]] = None, + mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(), + offload_policy: OffloadPolicy = OffloadPolicy(), +): + """ + Shard module parameters across data parallel workers. + + This function applies fully sharded data parallelism (FSDP) or a variant to + ``module``, a technique for memory savings at the cost of communication. + Parameters are sharded across ``mesh``, and in turn, so are their gradients + and optimizer states. + + The sharded parameters are all-gathered to construct the unsharded + parameters for forward or backward computation. The unsharded parameters + are freed after computation to save memory. The gradients are reduced + across the mesh and divided by the mesh size for data parallelism. The + optimizer step runs on the sharded parameters. + + Each call to ``fully_shard`` constructs one communication group that + includes the parameters in ``module.parameters()`` except those already + assigned to a group from a nested call. Each group's parameters and its + gradients are communicated together in one collective, respectively. + Constructing multiple groups across the model (e.g. "layer by layer") + allows for peak memory savings and communication/computation overlap. + + Implementation-wise, the sharded parameters are represented as + :class:`DTensor` s, sharded on dim-0, and the unsharded parameters are + represented as :class:`Tensor` s. A module forward pre-hook all-gathers the + parameters, and a module forward hook frees them. Similar backward hooks + gather parameters and later free parameters/reduce gradients. + + Args: + module (Union[nn.Module, List[nn.Module]): The module or modules to + shard with FSDP and group together for communication. + mesh (Optional[DeviceMesh]): This data parallel mesh defines the + sharding and device. If 1D, then parameters are fully sharded + across the 1D mesh (FSDP). If 2D, then parameters are sharded + across the 0th dim and replicated across the 1st dim (HSDP). The + mesh's device type gives the device type used for communication; + if a CUDA or CUDA-like device type, then we use the current device. + reshard_after_forward (Union[bool, int]): This controls the parameter + behavior after forward and can trade off memory and communication: + - If ``True``, then this reshards parameters after forward and + all-gathers in backward. + - If ``False``, then this keeps the unsharded parameters in memory + after forward and avoids the all-gather in backward. + - If an ``int``, then this represents the world size to reshard to + after forward. It should be a non-trivial divisor of the ``mesh`` + shard dim size (i.e. excluding 1 and the dim size itself). A choice + may be the intra-node size (e.g. ``core.cuda.device_count()``). + This allows the all-gather in backward to be over a smaller world + size at the cost of higher memory usage than setting to ``True``. + - The root FSDP state has its value specially set to ``False`` as a + heuristic since its parameters would typically be immediately + all-gathered for backward. + - After forward, the parameters registered to the module depend on + to this: The registered parameters are the sharded parameters if + ``True``; unsharded parameters if ``False``; and the paramters + resharded to the smaller mesh otherwise. To modify the parameters + between forward and backward, the registered parameters must be the + sharded parameters. For ``False`` or an ``int``, this can be done + by manually resharding via :meth:`reshard`. + shard_placement_fn (Optional[Callable[[nn.Parameter], Optional[Shard]]]): + This callable can be used to override the sharding placement for a + parameter to shard a parameter on a dimension other than dim-0. If + this callable returns a ``Shard`` placement (not ``None``), then + FSDP will shard according to that placement (e.g. ``Shard(1)``). + If sharding on a nonzero dim, we currently require even sharding, + i.e. the tensor dim size on that dim must be divisible by the FSDP + shard mesh size. + mp_policy (MixedPrecisionPolicy): This controls the mixed precision + policy, which offers parameter/reduction mixed precision for this + module. See :class:`MixedPrecisionPolicy` for details. + offload_policy (OffloadPolicy): This controls the offloading policy, + which offers parameter/gradient/optimizer state offloading. See + :class:`OffloadPolicy` and its subclasses for details. + """ + if isinstance(module, (nn.ModuleList, nn.ModuleDict)): + raise ValueError( + f"fully_shard does not support containers that do not implement forward: {module}" + ) + mesh = mesh or _init_default_fully_shard_mesh() + if mesh.ndim not in (1, 2): + raise ValueError(f"fully_shard expects a 1D or 2D DeviceMesh but got {mesh}") + elif mesh.ndim == 1: + mesh_info = FSDPMeshInfo(mesh, shard_mesh_dim=0) + else: + if mesh.mesh_dim_names is None: + raise AssertionError( + "Please init the 2D mesh for HSDP with mesh_dim_names specified" + ) + mesh_info = HSDPMeshInfo(mesh, shard_mesh_dim=1, replicate_mesh_dim=0) + device = _get_device_from_mesh(mesh) + post_forward_mesh_info = _get_post_forward_mesh_info( + reshard_after_forward, mesh_info + ) + + arg_module = module + modules = ( + (module,) if isinstance(module, nn.Module) else tuple(_get_root_modules(module)) + ) + state = fully_shard.state(modules[0]) + state.init(modules, device, mp_policy) + + managed_modules = _get_managed_modules(modules) + params, buffers = _get_managed_states(managed_modules) + _move_states_to_device(params, buffers, device) + if params: + state._fsdp_param_group = FSDPParamGroup( + params, + modules, + mesh_info, + post_forward_mesh_info, + device, + shard_placement_fn, + mp_policy, + offload_policy, + ) + + # For Dynamo + for managed_module in managed_modules: + managed_module._is_fsdp_managed_module = True # type: ignore[assignment] + managed_module._fsdp_use_orig_params = True # type: ignore[assignment] + + # Place FSDP leftmost for highest priority in the method resolution order + for module in modules: + cls = module.__class__ + new_cls = cls_to_fsdp_cls.get(cls, None) + if not new_cls: + dct = {"__deepcopy__": unimplemented_deepcopy} + new_cls = type(f"FSDP{cls.__name__}", (FSDPModule, cls), dct) + cls_to_fsdp_cls[cls] = new_cls + module.__class__ = new_cls + return arg_module + + +def unimplemented_deepcopy(*args: Any, **kwargs: Any) -> NoReturn: + raise AssertionError( + "FSDP does not support deepcopy. Please use state dict for serialization." + ) + + +class FSDPModule: + def __new__(cls, *args, **kwargs): + """ + Override ``__new__`` to remove the FSDP class and directly construct + the original class for cases like indexing into a container module. + """ + # Use index 2 since 0 is the dynamically constructed `FSDP<...>` class + # and index 1 is the `FSDPModule` class itself + orig_cls = cls.__mro__[2] + self = orig_cls.__new__(orig_cls, *args, **kwargs) + self.__init__(*args, **kwargs) + return self + + def reshard(self) -> None: + """ + Reshards the module's parameters, registering the sharded parameters + to the module and freeing the unsharded parameters if needed. This + method is *not* recursive. + """ + state = self._get_fsdp_state() + if fsdp_param_group := state._fsdp_param_group: + fsdp_param_group.reshard() + + def unshard(self, async_op: bool = False) -> Optional["UnshardHandle"]: + """ + Unshards the module's parameters by allocating memory and all-gathering + the parameters. This method is *not* recursive. + + Args: + async_op (bool): If ``True``, then returns a :class:`UnshardHandle` + that has a :meth:`wait` method to wait on the unshard op. If + ``False``, then returns ``None`` and waits on the handle inside + this function. + + .. warning:: This method is experimental and subject to change. + + .. note:: If ``async_op=True``, then the user does not have to call + :meth:`wait` on the returned handle if waiting on the unshard op + in the module's pre-forward is tolerable. FSDP will wait on the + pending unshard op in the pre-forward automatically. + """ + state = self._get_fsdp_state() + fsdp_param_group = state._fsdp_param_group + if fsdp_param_group is not None: + fsdp_param_group.lazy_init() + fsdp_param_group.unshard(async_op=async_op) + handle = UnshardHandle(fsdp_param_group) + if async_op: + return handle + handle.wait() + return None + + def set_is_last_backward(self, is_last_backward: bool) -> None: + """ + Sets whether the next backward is the last one, meaning that FSDP + should wait for gradient reduction to finish and clear internal data + structures used for explicit prefetching. + """ + state = self._get_fsdp_state() + state._state_ctx.is_last_backward = is_last_backward + + def set_requires_gradient_sync( + self, requires_gradient_sync: bool, *, recurse: bool = True + ) -> None: + """ + Sets if the module should sync gradients. This can be used to implement + gradient accumulation without communication. For HSDP, this controls + both reduce-scatter and all-reduce together. + + Args: + requires_gradient_sync (bool): Whether to reduce gradients for the + module's parameters. + recurse (bool): Whether to set for all submodules or just the + passed-in module. + """ + self_module = cast(nn.Module, self) + modules = list(self_module.modules()) if recurse else [self_module] + for module in modules: + if isinstance(module, FSDPModule): + state = module._get_fsdp_state() + if fsdp_param_group := state._fsdp_param_group: + fsdp_param_group.reduce_grads = requires_gradient_sync + fsdp_param_group.all_reduce_grads = requires_gradient_sync + + def set_requires_all_reduce( + self, requires_all_reduce: bool, *, recurse: bool = True + ) -> None: + """ + Sets if the module should all-reduce gradients. This can be used to + implement gradient accumulation with only reduce-scatter but not + all-reduce for HSDP. + """ + self_module = cast(nn.Module, self) + modules = list(self_module.modules()) if recurse else [self_module] + for module in modules: + if isinstance(module, FSDPModule): + state = module._get_fsdp_state() + if fsdp_param_group := state._fsdp_param_group: + fsdp_param_group.all_reduce_grads = requires_all_reduce + + def set_reshard_after_backward( + self, reshard_after_backward: bool, *, recurse: bool = True + ) -> None: + """ + Sets if the module should reshard parameters after backward. This can + be used during gradient accumulation to trade off higher memory for + reduced communication. + + Args: + reshard_after_backward (bool): Whether to reshard parameters after + backward. + recurse (bool): Whether to set for all submodules or just the + passed-in module. + """ + self_module = cast(nn.Module, self) + modules = list(self_module.modules()) if recurse else [self_module] + for module in modules: + if isinstance(module, FSDPModule): + state = module._get_fsdp_state() + if fsdp_param_group := state._fsdp_param_group: + fsdp_param_group.reshard_after_backward = reshard_after_backward + + def set_modules_to_forward_prefetch(self, modules: List["FSDPModule"]) -> None: + """ + Sets the FSDP modules for which this FSDP module should explicitly + prefetch all-gathers in forward. The prefetching runs after this + module's all-gather copy-out. + + Passing a singleton list containing the next FSDP module gives the same + all-gather overlap behavior as the default overlap behavior, except the + prefetched all-gather is issued earlier from the CPU. Passing a list + with at least length two is required for more aggressive overlap and + will use more reserved memory. + + Args: + modules (List[FSDPModule]): FSDP modules to prefetch. + """ + _assert_all_fsdp_modules(modules) + self._get_fsdp_state()._states_to_forward_prefetch = [ + module._get_fsdp_state() for module in modules + ] + + def set_modules_to_backward_prefetch(self, modules: List["FSDPModule"]) -> None: + """ + Sets the FSDP modules for which this FSDP module should explicitly + prefetch all-gathers in backward. This overrides the default backward + pretching implementation that prefetches the next FSDP module based on + the reverse post-forward order. + + Passing a singleton list containing the previous FSDP module gives the + same all-gather overlap behavior as the default overlap behavior. + Passing a list with at least length two is required for more aggressive + overlap and will use more reserved memory. + + Args: + modules (List[FSDPModule]): FSDP modules to prefetch. + """ + _assert_all_fsdp_modules(modules) + self._get_fsdp_state()._states_to_backward_prefetch = [ + module._get_fsdp_state() for module in modules + ] + + def set_post_optim_event(self, event: core.Event) -> None: + """ + Sets a post-optimizer-step event for the root FSDP module to wait the + all-gather streams on. + + By default, the root FSDP module waits the all-gather streams on the + current stream to ensure that the optimizer step has finished before + all-gathering. However, this may introduce false dependencies if + there is unrelated computation after the optimizer step. This API + allows the user to provide their own event to wait on. After the root + waits on the event, the event is discarded, so this API should be + called with a new event each iteration. + + Args: + event (core.Event): Event recorded after the optimizer step + to wait all-gather streams on. + """ + self._get_fsdp_state()._state_ctx.post_optim_event = event + + def set_reduce_scatter_divide_factor(self, factor: float) -> None: + """ + Sets a custom divide factor for the reduce-scatter. This becomes a + custom reduce op using NCCL's PreMulSum, which allows multiplying by + the factor before reduction. + + Args: + factor (float): Custom divide factor. + """ + state = self._get_fsdp_state() + if (fsdp_param_group := state._fsdp_param_group) is not None: + mul_factor = 1.0 / float(factor) + reduce_op = core.distributed._make_nccl_premul_sum(mul_factor) + fsdp_param_group.reduce_scatter_reduce_op = reduce_op + + def set_unshard_in_backward(self, unshard_in_backward: bool) -> None: + """ + Sets whether the FSDP module's parameters need to be unsharded in + backward. This can be used in expert cases when the user knows that all + parameters in this FSDP module's parameter group are not needed for + backward computation (e.g. embedding). + """ + state = self._get_fsdp_state() + if (fsdp_param_group := state._fsdp_param_group) is not None: + fsdp_param_group.unshard_in_backward = unshard_in_backward + + def _set_unshard_async_op(self, async_op: bool): + """ + Sets whether to use ``async_op=True`` or ``False`` for the pre-forward + and pre-backward unshard op. This defaults to ``False`` but can be set + to ``True`` with this method. + + Setting this to ``True`` allows the all-gather allocations to happen in + the default stream, avoiding inter-stream memory fragmentation. + However, you must use explicit prefetching (e.g. via :meth:`unshard`) + in forward to still get overlap, and the pre-all-gather ops like dtype + casting and copy-in will not overlap with compute. + """ + self_module = cast(nn.Module, self) + for module in self_module.modules(): + if isinstance(module, FSDPModule): + state = module._get_fsdp_state() + if fsdp_param_group := state._fsdp_param_group: + fsdp_param_group.unshard_async_op = async_op + + def _get_fsdp_state(self) -> FSDPState: + if (state := _get_module_fsdp_state(cast(nn.Module, self))) is None: + raise AssertionError(f"No FSDP state found on {self}") + return state + + def _apply(self, *args: Any, **kwargs: Any) -> Any: + # Reshard to ensure that sharded parameters are registered + self.reshard() + ret = super()._apply(*args, **kwargs) # type: ignore[misc] + state = self._get_fsdp_state() + if not (fsdp_param_group := state._fsdp_param_group): + return ret + # TODO: Remove this padding logic once DTensor pads the local tensor: + # https://github.com/pytorch/pytorch/issues/113045 + with core.no_grad(): + for fsdp_param in fsdp_param_group.fsdp_params: + fsdp_param.reset_sharded_param() + return ret + + +class UnshardHandle: + """ + A handle to wait on the unshard op. + + Args: + fsdp_param_group (FSDPParamGroup, optional): FSDP parameter group to + unshard. This should be ``None`` iff the FSDP module does not + manage any parameters, meaning the unshard is a no-op. + """ + + def __init__(self, fsdp_param_group: Optional[FSDPParamGroup]): + self._fsdp_param_group = fsdp_param_group + + def wait(self): + """ + Waits on the unshard op. + + This ensures that the current stream can use the unsharded parameters, + which are now registered to the module. + """ + if self._fsdp_param_group is not None: + self._fsdp_param_group.wait_for_unshard() + # Avoid keeping a reference + self._fsdp_param_group = None + + +def register_fsdp_forward_method(module: nn.Module, method_name: str) -> None: + """ + Registers a method on ``module`` to be a forward method for FSDP. + + FSDP only knows to run its pre-forward and post-forward hooks on the + default :meth:`nn.Module.forward` method. This function patches a user + specified method to run the pre/post-forward hooks before/after the method, + respectively. If ``module`` is not an :class:`FSDPModule`, then this is a + no-op. + + Args: + module (nn.Module): Module to register the forward method on. + method_name (str): Name of the forward method. + """ + if not isinstance(module, FSDPModule): + # Make no-op to allow including both when using/not using FSDP + return + if not hasattr(module, method_name): + raise ValueError(f"{type(module)} does not have a method {method_name}") + orig_method = getattr(module, method_name) + + @functools.wraps(orig_method) + def wrapped_method(self, *args, **kwargs): + fsdp_state = self._get_fsdp_state() + args, kwargs = fsdp_state._pre_forward(self, args, kwargs) + out = orig_method(*args, **kwargs) + return fsdp_state._post_forward(self, args, out) + + # Use `__get__` to make `wrapped_method` an instance method + setattr( + module, + method_name, + wrapped_method.__get__(module, type(module)), # type:ignore[attr-defined] + ) + + +def _assert_all_fsdp_modules(modules: Iterable[Any]) -> None: + for module in modules: + if not isinstance(module, FSDPModule): + raise ValueError(f"Expects FSDPModule but got {type(module)}: {module}") diff --git a/mindnlp/core/distributed/_composable/fully_shard.py b/mindnlp/core/distributed/_composable/fully_shard.py new file mode 100644 index 000000000..bd780d261 --- /dev/null +++ b/mindnlp/core/distributed/_composable/fully_shard.py @@ -0,0 +1,132 @@ +# mypy: allow-untyped-decorators +from typing import Callable, Iterable, Optional, Union +from typing_extensions import deprecated + +from mindnlp import core +from mindnlp import core.distributed as dist +from mindnlp import core.nn as nn +from core.distributed._composable.contract import contract +from core.distributed._composable_state import _get_module_state, _insert_module_state +from core.distributed.fsdp._common_utils import _FSDPState +from core.distributed.fsdp._dynamo_utils import _annotate_modules_for_dynamo +from core.distributed.fsdp._init_utils import ( + _init_buffer_state, + _init_core_state, + _init_device_handle, + _init_ignored_module_states, + _init_param_handle_from_module, + _init_prefetching_state, + _init_process_group_state, + _init_runtime_state, + _init_state_dict_state, + HYBRID_SHARDING_STRATEGIES, +) +from core.distributed.fsdp._runtime_utils import ( + _register_post_forward_hook, + _register_pre_forward_hook, + _register_root_pre_forward_hook, +) +from core.distributed.fsdp._state_dict_utils import _register_all_state_dict_hooks +from core.distributed.fsdp._wrap_utils import _auto_wrap +from core.distributed.fsdp.api import ( + BackwardPrefetch, + CPUOffload, + MixedPrecision, + ShardingStrategy, +) +from core.distributed.fsdp.wrap import _Policy + + +@contract(state_cls=_FSDPState) +@deprecated( + "`core.distributed._composable.fully_shard` is being deprecated. " + "You can continue to use the wrapper based FSDP. " + "See usage in: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/fully_sharded_data_parallel.py. " + "`core.distributed._composable.fully_shard` will be removed after PyTorch 2.5. " + "If you are looking for FSDP2, please see `core.distributed._composable.fsdp.fully_shard.`", + category=FutureWarning, +) +def fully_shard( + module: nn.Module, + *, + process_group: Optional[dist.ProcessGroup] = None, + policy: Optional[_Policy] = None, + strategy: Optional[ShardingStrategy] = None, + mixed_precision: Optional[MixedPrecision] = None, + cpu_offload: Optional[CPUOffload] = None, + ignored_modules: Optional[Iterable[core.nn.Module]] = None, + device_id: Optional[Union[int, core.device]] = None, + param_init_fn: Optional[Callable[[nn.Module], None]] = None, + sync_module_states: bool = False, + forward_prefetch: bool = False, + ignored_states: Union[ + Optional[Iterable[core.nn.Parameter]], Optional[Iterable[core.nn.Module]] + ] = None, +) -> nn.Module: + """Applies ``FullyShardedDataParallel`` (FSDP) semantics to ``module``.""" + core._C._log_api_usage_once("core.distributed.fully_shard") + # Enforce the new auto wrap policy + if policy is not None and not isinstance(policy, _Policy): + raise ValueError(f"Expects a `_Policy` but got {policy}") + state = fully_shard.state(module) + state = _init_ignored_module_states(state, module, ignored_modules, ignored_states) + state = _init_device_handle(state, module, state._ignored_params, device_id) + _annotate_modules_for_dynamo(module, state._ignored_modules, True) + state = _init_process_group_state(state, process_group, strategy, policy) + if policy is not None: + root_kwargs = { + "process_group": process_group, + "strategy": strategy, + "mixed_precision": mixed_precision, + "cpu_offload": cpu_offload, + "ignored_modules": ignored_modules, + "device_id": device_id, + "param_init_fn": param_init_fn, + "sync_module_states": sync_module_states, + "forward_prefetch": forward_prefetch, + "ignored_states": ignored_states, + } + if strategy in HYBRID_SHARDING_STRATEGIES: + root_kwargs["process_group"] = (state.process_group, state._inter_node_pg) + _auto_wrap( + module, + policy, + state._ignored_modules, + state._ignored_params, + root_kwargs, + fully_shard, + ) + state = _init_core_state( + state, + strategy or ShardingStrategy.FULL_SHARD, + mixed_precision, + cpu_offload, + limit_all_gathers=True, + use_orig_params=True, + backward_prefetch_limit=1, + forward_prefetch_limit=1, + ) + state = _init_runtime_state(state) + state = _init_prefetching_state( + state, BackwardPrefetch.BACKWARD_PRE, forward_prefetch=forward_prefetch + ) + state = _init_buffer_state(state, module) + state = _init_param_handle_from_module( + state, module, device_id, param_init_fn, sync_module_states + ) + state = _init_state_dict_state(state) + _register_all_state_dict_hooks(state) + _register_pre_forward_hook(state, module) + _register_post_forward_hook(state, module) + _register_root_pre_forward_hook(state, module) # prepend last + # Always insert the state for the passed-in module even if it has no + # managed parameters, in which case it has no handles and does not appear + # in `_fully_sharded_module_to_handles` + _insert_module_state(module, state) + for submodule in module.modules(): + if ( + submodule in state._fully_sharded_module_to_handle + and _get_module_state(submodule) is None + ): + _insert_module_state(submodule, state) + return module diff --git a/mindnlp/core/distributed/_composable/replicate.py b/mindnlp/core/distributed/_composable/replicate.py new file mode 100644 index 000000000..013e15316 --- /dev/null +++ b/mindnlp/core/distributed/_composable/replicate.py @@ -0,0 +1,254 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import weakref +from typing import Any, cast, Dict, Iterable, List, NoReturn, Optional, Set, Tuple + +from mindnlp import core +from mindnlp import core.nn as nn +from core.distributed._composable_state import _State +from core.nn.parallel import DistributedDataParallel + +from .contract import _get_registry, contract + + +_ROOT_MODULE_PREFIX = "" + + +class _ReplicateState(_State): + def __init__(self) -> None: + super().__init__() + self.module: nn.Module = nn.ParameterList() + self.has_initialized: bool = False + self._param_list: nn.ParameterList = nn.ParameterList() + # TODO(@fegin): this variable is originally create for testing, we + # should remove this if possible. + self._orig_module = self.module + self._param_names: List[str] = [] + self._no_sync: bool = False + self._init_args: Optional[Tuple[Any, ...]] = None + self._init_kwargs: Dict[str, Any] = {} + self._comm_hook_args: List[Any] = [] + + def _collect_params( + self, + module: nn.Module, + ignored_modules: Set[nn.Module], + ignored_params: Set[nn.Parameter], + prefix: str = _ROOT_MODULE_PREFIX, + ) -> None: + # skip if managed by fully_sharded API + if _is_fully_sharded(module): + return + + # if a module is ignored, all descendants of the module are ignored. + if module in ignored_modules: + return + + recurse_prefix = ( + f"{prefix}." if prefix != _ROOT_MODULE_PREFIX else _ROOT_MODULE_PREFIX + ) + + for n, p in module.named_parameters(recurse=False): + if p not in ignored_params: + self._param_list.append(p) + self._param_names.append(f"{recurse_prefix}{n}") + + for name, child_module in module.named_children(): + self._collect_params( + child_module, + ignored_modules, + ignored_params, + prefix=f"{recurse_prefix}{name}", + ) + + def lazy_init(self) -> None: + @core._disable_dynamo(recursive=True) + def _lazy_init(): + assert self._init_args is not None + self.init(*self._init_args, **self._init_kwargs) + self.register_comm_hook() + self._init_args = () + self._init_kwargs = {} + + _lazy_init() + + def init( + self, + module: nn.Module, + ignored_modules: Set[nn.Module], + **kwargs, + ) -> None: + if self.has_initialized: + return + + self.has_initialized = True + self.module = module + ignored_params = {p for m in ignored_modules for p in m.parameters()} + for submodule in module.modules(): + if _is_fully_sharded(submodule): + ignored_params.update(submodule.parameters()) + from core.distributed.tensor.parallel.ddp import _localize_dtensor + + _localize_dtensor(module, ignored_params=ignored_params) + self._collect_params(module, ignored_modules, ignored_params) + + if "device_id" in kwargs: + # replicate() supports a small usability enhancement where + # user can pass in device_id as a Union[int, core.device] even for + # CPU devices so users don't have to change code for CPU/GPU runs. + # We derive the right device_ids to feed into DDP to support this. + if kwargs["device_id"] is not None: + device_id = kwargs["device_id"] + # Convert to device_ids that DDP expects. + if isinstance(device_id, core.device) and device_id.type == "cpu": + # CPU modules receive device_ids None + kwargs["device_ids"] = None + else: + # GPU modules expect device_ids=[cuda_device] + kwargs["device_ids"] = [device_id] + else: + kwargs["device_ids"] = None + kwargs.pop("device_id") + + self._ddp = DistributedDataParallel(self._param_list, **kwargs) + # Weakref to the DDP instance is currently only used for testing. + replicate.state(self.module)._ddp_weakref = weakref.ref(self._ddp) + + def register_comm_hook(self) -> None: + for comm_args, comm_kwargs in self._comm_hook_args: + self._ddp.register_comm_hook(*comm_args, **comm_kwargs) + self._comm_hook_args.clear() + + def record_init_args(self, *args, **kwargs) -> None: + self._init_args = args + self._init_kwargs = kwargs + + def forward_pre_hook( + self, module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any] + ) -> Any: + if self._init_args or self._init_kwargs: + self.lazy_init() + self._ddp.require_backward_grad_sync = not self._no_sync + return self._ddp._pre_forward(*args, **kwargs) + + def forward_post_hook( + self, + module: nn.Module, + input: Tuple[core.Tensor], + output: core.Tensor, + ) -> core.Tensor: + return self._ddp._post_forward(output) + + +def unimplemented_deepcopy(*args: Any, **kwargs: Any) -> NoReturn: + raise AssertionError( + "DDP does not support deepcopy. Please use state dict for serialization." + ) + + +# Follow the same pattern as FSDP/fully_shard +class DDP: + def __new__(cls, *args, **kwargs): + """ + Override ``__new__`` to remove the DDP class and directly construct + the original class for cases like indexing into a container module. + """ + # Use index 2 since 0 is the dynamically constructed `DDP<...>` class + # and index 1 is the `DDP` class itself + orig_cls = cls.__mro__[2] + return orig_cls.__new__(orig_cls, *args, **kwargs) + + def set_requires_gradient_sync(self, requires_gradient_sync: bool) -> None: + """ + Sets if the module should sync gradients. This can be used to implement + gradient accumulation without communication. + + Args: + requires_gradient_sync (bool): Whether to reduce gradients for the + module's parameters. + """ + replicate.state(self)._no_sync = not requires_gradient_sync + + def register_comm_hook(self, *args, **kwargs) -> None: + replicate.state(self)._comm_hook_args.append((args, kwargs)) + + +@contract(state_cls=_ReplicateState) +def replicate( + module: nn.Module, + ignored_modules: Optional[Iterable[core.nn.Module]] = None, + **kwargs, +) -> nn.Module: + r"""Replicates a module + + Args: + module (core.nn.Module): module to replicate + + Example:: + >>> # xdoctest: +REQUIRES(module:core._C._distributed_c10d) + >>> module = nn.Linear(3, 3) + >>> replicate(module) + """ + core._C._log_api_usage_once("core.distributed.replicate") + + # TODO(fegin): using kwargs is not a good idea if we would like to make + # replicate a formal API to replace DDP. + if "device_id" in kwargs: + if not isinstance(kwargs["device_id"], (int, core.device)): + raise RuntimeError( + "Expected device_id to be int or core.device, " + f"but got {type(kwargs['device_id'])}" + ) + + if _is_fully_sharded(module): + raise RuntimeError( + "Cannot apply `replicate()` on a Module already managed by `fully_shard`" + ) + + if ignored_modules is None: + ignored_modules = {} + else: + ignored_modules = set(ignored_modules) + + state = cast(_ReplicateState, replicate.state(module)) + module.register_forward_pre_hook(state.forward_pre_hook, with_kwargs=True) + device_mesh = kwargs.get("device_mesh", None) + if device_mesh is not None: + from core.distributed.device_mesh import _mesh_resources + + root_mesh = _mesh_resources.get_root_mesh(device_mesh) + # if a root mesh is not the same as device_mesh, + # meaning the device_mesh is sliced out from the root mesh. + if root_mesh != device_mesh: + # TODO: This is a temporary work around to enable DDP + TP. + # We should do the logic in DDP so that the 2D implementation is + # sound and the state_dict works out of the box. + # + # This won't conflict with what is done in DDP class as the module + # replicate is going to pass is NOT the original module. + from core.distributed.tensor.parallel.ddp import ( + _localize_dtensor, + _reconstruct_dtensor, + ) + + module.register_forward_pre_hook(_reconstruct_dtensor) + module.register_forward_hook(_localize_dtensor) + + module.register_forward_hook(state.forward_post_hook) # type: ignore[arg-type] + + state.record_init_args(module, ignored_modules, **kwargs) + + # Place DDP leftmost for highest priority in the method resolution order + cls = module.__class__ + dct = {"__deepcopy__": unimplemented_deepcopy} + new_cls = type(f"DDP{cls.__name__}", (DDP, cls), dct) + module.__class__ = new_cls + return module + + +def _is_fully_sharded(module: nn.Module) -> bool: + r"""Check if module is marked with fully_shard.""" + registry = _get_registry(module) + if registry is None: + return False + return "fully_shard" in registry diff --git a/mindnlp/core/distributed/_composable_state.py b/mindnlp/core/distributed/_composable_state.py new file mode 100644 index 000000000..ad8f4fc5b --- /dev/null +++ b/mindnlp/core/distributed/_composable_state.py @@ -0,0 +1,44 @@ +import weakref +from typing import cast, Optional + +from mindnlp import core.nn as nn + + +class _State: + pass + + +_module_state_mapping: weakref.WeakKeyDictionary[ + nn.Module, weakref.ReferenceType[_State] +] = weakref.WeakKeyDictionary() + + +def _insert_module_state(module: nn.Module, state: _State) -> None: + global _module_state_mapping + assert module not in _module_state_mapping, f"Inserting {module} more than once." + _module_state_mapping[module] = weakref.ref(state) + + +def _get_module_state(module: nn.Module) -> Optional[_State]: + """ + Return the ``_State`` in ``model``. + + Given a ``module``, this API finds out if the module is also a ``_State`` + instance or if the module is managed by a composable API. If the module + is also a ``_State``, ``module`` will be casted to ``_State` and returned. + If it is managed by a composable API, the corresponding ``_State`` will + be returned. + """ + global _module_state_mapping + if isinstance(module, _State): + return cast(_State, module) + else: + # https://github.com/pytorch/pytorch/issues/107054 + if module in _module_state_mapping: + state_ref = _module_state_mapping[module] + state = state_ref() + if state is None: + raise AssertionError("State has already been garbage collected") + return state + else: + return None diff --git a/mindnlp/core/distributed/_functional_collectives.py b/mindnlp/core/distributed/_functional_collectives.py new file mode 100644 index 000000000..7282b1b32 --- /dev/null +++ b/mindnlp/core/distributed/_functional_collectives.py @@ -0,0 +1,1196 @@ +# mypy: allow-untyped-defs +import contextlib +import sys +import warnings +from typing import Any, cast, List, Optional, Tuple, Type, TYPE_CHECKING, Union + +from mindnlp import core +from mindnlp import core.distributed as dist +from mindnlp import core.distributed.distributed_c10d as c10d +from core.distributed.device_mesh import DeviceMesh +# from core.fx.experimental.proxy_tensor import get_proxy_mode + +from . import _functional_collectives_impl as fun_col_impl + + +# try: +# from core.utils._cxx_pytree import tree_map_only +# except ImportError: +# from core.utils._pytree import tree_map_only # type: ignore[no-redef] + + +# if core._running_with_deploy(): + +# def is_torchdynamo_compiling(): +# """Can't from mindnlp import coredynamo in torchdeploy builds currently.""" +# return False + +# else: +# try: +# from core.compiler import is_dynamo_compiling as is_torchdynamo_compiling +# except Exception: +# warnings.warn( +# "Unable to from mindnlp import coredynamo util `is_torchdynamo_compiling`, so won't support torchdynamo correctly" +# ) + +# def is_torchdynamo_compiling(): +# return False + + +""" +New traceable, functional collectives. +RFC: https://github.com/pytorch/pytorch/issues/93173 + + compiler: trace these ops with plain-old-data schemas, then choose how to lower them. + eager: execute these 'functional' ops which in eager return AsyncCollectiveTensor subclasses, + automatically calling .wait() on underlying/hidden async 'work' obj only when fed to + a downstream op. + +Issues: +* Where should these ops live? Couldn't `from mindnlp import core` if putting these ops in existing core.distributed files +* Proper support for eager requires inplace ops. We should explore having it as an option for the API. +""" + +""" +Functional collectives are asynchronous only and we perform implicit stream synchronization +on behalf of the user. + +We use AsyncCollectiveTensor to wrap the result tensor of a collective and it lets us witness +first usage of the tensor and insert cross stream sync at the right place. + +The above are the easy bits, the hard one is how we match the Work object returned by +c10d and the tensor AsyncCollectiveTensor wraps. We alloc the tensor inside the collective +op implementation (see ``clone()`` call in ``_all_reduce``) and then it's handled by the +dispatcher which might call other implementations that are allowed to change the returned +tensor - even return a tensor with a different shape (see ``core.vmap``). + +This means the caller of our ops receives a Tensor that is not guaranteed to be the same +allocated by our implementations and that makes pairing The AsyncTensor to the original +tensor a lot harder. This pairing is needed so we can lookup the Work object to use. + +Originally, we tried WeakKeyDictionary to map from Tensor to Work, but because Tensor's +identity is not stable across dispatch, the op caller would end up with a different Tensor +instance that would not match any in the dictionary. + +With Tensor identity out of the question, we decided use the tensor data pointer, which +should be stable across all the Tensor changes done during dispatch. + +We have a dictionary of tensor::data_ptr -> Work that we insert right after we call into c10d. + +We use this dictionary when AsyncCollectiveTensor is used to invoke Work::wait() + +Finally, we setup a finalizer against the tensor wrapper to observe it getting collected so we +can clean up stale entries in the dictionary. + +To eliminate the possibility of races we have a global version counter that is used by the finalizer. + +As a wise man said once: Don't cross the streams (https://www.youtube.com/watch?v=wyKQe_i9yyo) + +""" + +""" +Functional collectives can accept any of these types to describe the ranks participating in collectives. + +The different types will be desugared to a canonical format +""" +RANK_TYPES = Union[ + List[int], + List[List[int]], + dist.ProcessGroup, + DeviceMesh, + Tuple["dist.tensor.DeviceMesh", int], + str, +] + + +""" +User facing APIs for functional collectives +------------------------------------------- + +These apis are called by user code and expected to work both in eager execution and compilation, +but there are significant differences to how the two modes are implemented underneath. + +Eager execution is 'optimized' using a tensor subclass that schedules the synchronization (via wait_tensor() op) +just before the tensor is first used. Compiled tracing currently relies on the compiler to perform this optimization, +and cannot yet correctly trace the AsyncTensor wrapper class. In the future, these paths may be unified +if sufficient subclass support is added in dynamo. + +Example: all_reduce is an entrypoint API, and other collectives follow a similar pattern. + +Here's how it works under core.compile/dynamo: +all_reduce(...) + |--> _expand_group(...) - desugars processgroup into canonical/traceable format + |--> c10d_functional.all_reduce(...) - dynamo captures this op call, doesn't trace deeper + |--> _maybe_wrap_tensor(...) - wait_tensor() op is immediately called, no AsyncTensor subclass needed + +And under eager execution: +all_reduce(...) + |--> _expand_group(...) - same as above, but less critical for eager + |--> c10d_functional.all_reduce(...) - dispatches to real kernel OR records op in trace + |--> _maybe_wrap_tensor(...) - AsyncTensor wrapper applied to returned tensor, + which issues wait_tensor() at the time of first use +""" + + +def wait_tensor(tensor): + """ + Wait on a tensor returned by the collectives ops. + + Waiting follows device semantics, which means blocking on CPU and synchronizing streams on CUDA. + """ + return core.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] + + +def broadcast(self: core.Tensor, src: int, group: RANK_TYPES, tag: str = ""): + """ + Broadcasts the tensor to all processes in the given process group. + + Args: + src (int): Source rank + group (ProcessGroup or List[int]): The process group to work on. + tag (str, optional): A unique identifier for the collective. Default: empty string + """ + group_name = _resolve_group_name(group, tag) + tensor = core.ops._c10d_functional.broadcast(self, src, group_name) + return _maybe_wrap_tensor(tensor) + + +def all_reduce(self: core.Tensor, reduceOp: str, group: RANK_TYPES, tag: str = ""): + """ + Reduces the tensor data across all machines in such a way that all get + the final result. + + The input tensor is left unmodified. + + Group can be one of: + List[int]: ranks participating in the collective. + List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD. + ProcessGroup: Will perform a collective using the ranks and tag of the PG. + DeviceMesh: Do a SPMD collective over all ranks of the mesh + (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh + + :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover + that information and perform collective algebraic optimization. Use other forms of input for that. + """ + group_name = _resolve_group_name(group, tag) + tensor = core.ops._c10d_functional.all_reduce(self, reduceOp.lower(), group_name) + return _maybe_wrap_tensor(tensor) + + +def all_gather_tensor( + self: core.Tensor, + gather_dim: int, + group: RANK_TYPES, + tag: str = "", +): + """ + Gather tensor data across from all machines and concatenate over ``gather_dim``. + + Note that it currently only supports gather_dim = 0. + + The input tensor is left unmodified. + Group can be one of: + List[int]: ranks participating in the collective. + List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD. + ProcessGroup: Will perform a collective using the ranks and tag of the PG. + DeviceMesh: Do a SPMD collective over all ranks of the mesh + (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh + + :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover + that information and perform collective algebraic optimization. Use other forms of input for that. + """ + assert self.is_contiguous() + group_name = _resolve_group_name(group, tag) + group_size = c10d._get_group_size_by_name(group_name) + tensor = core.ops._c10d_functional.all_gather_into_tensor( + self, group_size, group_name + ) + res = _maybe_wrap_tensor(tensor) + # TODO this should be done inside AsyncCollectiveTensor to delay the wait() call + if gather_dim != 0: + # core.cat access the data so we already need to wait here, first do wait + # and then chunk + cat avoid us going through ACT dispatching logic again + if isinstance(res, AsyncCollectiveTensor): + res = res.wait() # type: ignore[attr-defined] + res = core.cat(core.chunk(res, group_size, dim=0), dim=gather_dim) + return res + + +def all_gather_tensor_autograd( + self: core.Tensor, + gather_dim: int, + group: RANK_TYPES, + tag: str = "", +): + """ + Gather tensor data across from all machines and concatenate over ``gather_dim``. + + Note that it currently only supports gather_dim = 0. + + This function is the same as all_gather_tensor but will propagate the + backwards gradient across workers. + + See all_gather_tensor for more details on usage. + """ + group_name = _resolve_group_name(group, tag) + group_size = c10d._get_group_size_by_name(group_name) + + tensor = core.ops._c10d_functional_autograd.all_gather_into_tensor( + self, group_size, group_name + ) + res = _FromTorchTensor.apply(tensor) + # TODO this should be done inside AsyncCollectiveTensor to delay the wait() call + if gather_dim != 0: + # core.cat access the data so we already need to wait here, first do wait + # and then chunk + cat avoid us going through ACT dispatching logic again + if isinstance(res, AsyncCollectiveTensor): + res = res.wait() # type: ignore[attr-defined] + res = core.cat(core.chunk(res, group_size, dim=0), dim=gather_dim) + return res + + +def reduce_scatter_tensor( + self: core.Tensor, + reduceOp: str, + scatter_dim: int, + group: RANK_TYPES, + tag: str = "", +): + """ + Reduces the tensor data across all machines in such a way that all get + the final result, then scatter the results to corresponding ranks. + + + The input tensor is left unmodified. + Group can be one of: + List[int]: ranks participating in the collective. + List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD. + ProcessGroup: Will perform a collective using the ranks and tag of the PG. + DeviceMesh: Do a SPMD collective over all ranks of the mesh + (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh + :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover + that information and perform collective algebraic optimization. Use other forms of input for that. + """ + group_name = _resolve_group_name(group, tag) + group_size = c10d._get_group_size_by_name(group_name) + + assert ( + self.size(scatter_dim) % group_size == 0 + ), f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}" + if scatter_dim != 0: + tensor_list = core.chunk(self, group_size, dim=scatter_dim) + self = core.cat(tensor_list) + + tensor = core.ops._c10d_functional.reduce_scatter_tensor( + self, + reduceOp.lower(), + group_size, + group_name, # type: ignore[possibly-undefined] + ) + res = _maybe_wrap_tensor(tensor) + return res + + +def reduce_scatter_tensor_autograd( + self: core.Tensor, + reduceOp: str, + scatter_dim: int, + group: RANK_TYPES, + tag: str = "", +): + """ + Reduces the tensor data across all machines in such a way that all get + the final result, then scatter the results to corresponding ranks. + + This function is the same as reduce_scatter_tensor but will propagate the + backwards gradient across workers. + + Currently only the "sum" reduceOp is supported. + + See reduce_scatter_tensor for more details on usage. + """ + + group_name = _resolve_group_name(group, tag) + group_size = c10d._get_group_size_by_name(group_name) + + assert ( + self.size(scatter_dim) % group_size == 0 + ), f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}" + if scatter_dim != 0: + tensor_list = core.chunk(self, group_size, dim=scatter_dim) + self = core.cat(tensor_list) + + tensor = core.ops._c10d_functional_autograd.reduce_scatter_tensor( + self, + reduceOp.lower(), + group_size, + group_name, # type: ignore[possibly-undefined] + ) + res = _FromTorchTensor.apply(tensor) + return res + + +def all_reduce_coalesced( + self: List[core.Tensor], reduceOp: str, group: RANK_TYPES, tag: str = "" +) -> List[core.Tensor]: + """ + Reduces a list of tensors across all machines in such a way that all get + the final result. + + The all tensors in the input list are left unmodified. + + Group can be one of: + List[int]: ranks participating in the collective. + List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD. + ProcessGroup: Will perform a collective using the ranks and tag of the PG. + DeviceMesh: Do a SPMD collective over all ranks of the mesh + (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh + + :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover + that information and perform collective algebraic optimization. Use other forms of input for that. + """ + group_name = _resolve_group_name(group, tag) + tensor_list = core.ops._c10d_functional.all_reduce_coalesced( # type: ignore[attr-defined] + self, + reduceOp.lower(), + group_name, + ) + return list(map(_maybe_wrap_tensor, tensor_list)) + + +def all_gather_into_tensor_coalesced( + self: List[core.Tensor], group: RANK_TYPES, tag: str = "" +) -> List[core.Tensor]: + """ + Gather a list of tensors across from all machines. + + Note that it currently only supports gather_dim = 0. + + The input tensor is left unmodified. + Group can be one of: + List[int]: ranks participating in the collective. + List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD. + ProcessGroup: Will perform a collective using the ranks and tag of the PG. + DeviceMesh: Do a SPMD collective over all ranks of the mesh + (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh + + :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover + that information and perform collective algebraic optimization. Use other forms of input for that. + """ + group_name = _resolve_group_name(group, tag) + group_size = c10d._get_group_size_by_name(group_name) + tensor_list = core.ops._c10d_functional.all_gather_into_tensor_coalesced( # type: ignore[attr-defined] + self, + group_size, + group_name, + ) + return list(map(_maybe_wrap_tensor, tensor_list)) + + +def reduce_scatter_tensor_coalesced( + inputs: List[core.Tensor], + reduceOp: str, + scatter_dim: List[int], + group: RANK_TYPES, + tag: str = "", +) -> List[core.Tensor]: + """ + Reduces a list of tensors across all machines in such a way that all get + the final result, then scatter the results to corresponding ranks. + + The input tensors are left unmodified. + Group can be one of: + List[int]: ranks participating in the collective. + List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD. + ProcessGroup: Will perform a collective using the ranks and tag of the PG. + DeviceMesh: Do a SPMD collective over all ranks of the mesh + (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh + + :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover + that information and perform collective algebraic optimization. Use other forms of input for that. + """ + group_name = _resolve_group_name(group, tag) + group_size = c10d._get_group_size_by_name(group_name) + + assert len(scatter_dim) == len(inputs) + for idx, (dim, tensor) in enumerate(zip(scatter_dim, inputs)): + assert ( + tensor.size(dim) % group_size == 0 + ), f"input dimension {dim} ({tensor.size(dim)} must be a multiple of group_size {group_size} for tensor at index {idx}" + if dim != 0: + tensor_list = core.chunk(tensor, group_size, dim=dim) + inputs[idx] = core.cat(tensor_list) + + tensor_list = core.ops._c10d_functional.reduce_scatter_tensor_coalesced( # type: ignore[attr-defined] + inputs, + reduceOp.lower(), + group_size, + group_name, # type: ignore[possibly-undefined] + ) + + return list(map(_maybe_wrap_tensor, tensor_list)) + + +# This is a bit unsafe: it checks if the first argument in the schema reports as a non-mutable alias. +# Today, this maps 1:1 with "aten ops that are views". +def _is_view_op(tgt): + assert isinstance(tgt, core._ops.OpOverload) + schema = tgt._schema + if len(schema.arguments) > 0: + first_arg = schema.arguments[0] + # check if op is a view + return first_arg.alias_info is not None and not first_arg.alias_info.is_write + + +def all_to_all_single( + self: core.Tensor, + output_split_sizes: Optional[List[int]], + input_split_sizes: Optional[List[int]], + group: RANK_TYPES, + tag: str = "", +) -> core.Tensor: + """ + Each process splits input tensor and then scatters the split list + to all processes in a group. Then concatenate the received tensors from all + the processes in the group and return single output tensor. + + Group can be one of: + List[int]: ranks participating in the collective. + List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD. + ProcessGroup: Will perform a collective using the ranks and tag of the PG. + DeviceMesh: Do a SPMD collective over all ranks of the mesh + (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh + + :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover + that information and perform collective algebraic optimization. Use other forms of input for that. + """ + if output_split_sizes is not None: + assert all( + isinstance(size, (int, core.SymInt)) for size in output_split_sizes + ), output_split_sizes + if input_split_sizes is not None: + assert all( + isinstance(size, (int, core.SymInt)) for size in input_split_sizes + ), input_split_sizes + group_name = _resolve_group_name(group, tag) + group_size = c10d._get_group_size_by_name(group_name) + if output_split_sizes is None or input_split_sizes is None: + assert output_split_sizes is None and input_split_sizes is None, ( + "output_split_sizes and input_split_sizes must either be " + "specified together or both set to None" + ) + output_split_sizes = [self.shape[0] // group_size] * group_size + input_split_sizes = output_split_sizes + tensor = core.ops._c10d_functional.all_to_all_single( # type: ignore[attr-defined] + self, + output_split_sizes, + input_split_sizes, + group_name, + ) + return _maybe_wrap_tensor(tensor) + + +def all_to_all_single_autograd( + self: core.Tensor, + output_split_sizes: Optional[List[int]], + input_split_sizes: Optional[List[int]], + group: RANK_TYPES, + tag: str = "", +) -> core.Tensor: + """ + Same as all_to_all_single but supports autograd. + """ + if output_split_sizes is not None: + assert all( + isinstance(size, (int, core.SymInt)) for size in output_split_sizes + ), output_split_sizes + if input_split_sizes is not None: + assert all( + isinstance(size, (int, core.SymInt)) for size in input_split_sizes + ), input_split_sizes + + group_name = _resolve_group_name(group, tag) + group_size = c10d._get_group_size_by_name(group_name) + if output_split_sizes is None or input_split_sizes is None: + assert output_split_sizes is None and input_split_sizes is None, ( + "output_split_sizes and input_split_sizes must either be " + "specified together or both set to None" + ) + output_split_sizes = [self.shape[0] // group_size] * group_size + input_split_sizes = output_split_sizes + tensor = core.ops._c10d_functional_autograd.all_to_all_single( # type: ignore[attr-defined] + self, + output_split_sizes, + input_split_sizes, + group_name, + ) + return _FromTorchTensor.apply(tensor) + + +def permute_tensor( + self: core.Tensor, + src_dst: List[int], + group: RANK_TYPES, + tag: str = "", +) -> core.Tensor: + """ + Permutes the elements of the tensor according to the given source/destination pairs. `src_dst` should + be defined such that src_dst[m] == n means m sends to n. + + Group can be one of: + List[int]: ranks participating in the collective. + List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD. + ProcessGroup: Will perform a collective using the ranks and tag of the PG. + DeviceMesh: Do a SPMD collective over all ranks of the mesh + (DeviceMesh, int): Do a MPMD collective over one + """ + t, rankset, group_size = _expand_group(group, tag) + local_pg = c10d._find_or_create_pg_by_ranks_and_tag(t, rankset, group_size) + + output_split_sizes = [0] * group_size + input_split_sizes = [0] * group_size + for src, dst in enumerate(src_dst): + if src == dist.get_rank(local_pg): + input_split_sizes[dst] = self.numel() + if dst == dist.get_rank(local_pg): + output_split_sizes[src] = self.numel() + + return all_to_all_single(self, output_split_sizes, input_split_sizes, group, tag) + + +class AsyncCollectiveTensor(core.Tensor): + r""" + A Tensor wrapper subclass that is used to trigger a call to wait + prior to first use of the underlying tensor. + Use it inside functional collective pytorch wrappers like the following: + def functional_collective(self, group, tag): + tag, rankset, group_size = _expand_group(group, tag) + tensor = core.ops.c10d_functional.{collective}(self, tag, rankset, group_size) + return _maybe_wrap_tensor(tensor) + """ + elem: core.Tensor + completed: bool + + __slots__ = ["elem", "completed"] + + @staticmethod + def __new__(cls, elem: core.Tensor): + r = core.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] + cls, + elem.size(), + strides=elem.stride(), + storage_offset=elem.storage_offset(), + dtype=elem.dtype, + layout=elem.layout, + device=elem.device, + requires_grad=elem.requires_grad, + ) + r.elem = elem + r.completed = False + return r + + def __tensor_flatten__(self): + return ["elem"], None + + def tolist(self): + return self.trigger_wait().tolist() + + @staticmethod + def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride): + assert meta is None + elem = inner_tensors["elem"] + return AsyncCollectiveTensor(elem) + + def __coerce_same_metadata_as_tangent__( + self, expected_metadata: Any, expected_type: Optional[Type] = None + ): + if expected_type is not core.Tensor: + return None + + return self.trigger_wait() + + def __repr__(self) -> str: # type: ignore[override] + return f"AsyncCollectiveTensor({self.trigger_wait()})" + + def trigger_wait(self): + if not self.completed: + out = wait_tensor(self.elem) + self.completed = True + return out + else: + return self.elem + + def wait(self) -> core.Tensor: + return wait_tensor(self.elem) + + def _get_acs_underlying_tensor(self): + """This method enables _functional_collectives_impl to test if a tensor is an ACS""" + return self.elem + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + if func == core.ops.aten.view.default: + # Fast handle aten.view as a lot of view related op goes to aten.view + # eventually, this avoids pytree slowdown + res = func(args[0].elem, args[1]) + wrapper_res = AsyncCollectiveTensor(res) + return wrapper_res + + is_view_op = _is_view_op(func) + + def unwrap(e: AsyncCollectiveTensor): + # wait_tensor is idepotent and will do stream sync only once + if not is_view_op: + return e.trigger_wait() + return e.elem + + def wrap(e: core.Tensor): + # wait_tensor is idepotent and will do stream sync only once + assert not isinstance(e, AsyncCollectiveTensor) + res = AsyncCollectiveTensor(e) + return res + + unwrapped_args = tree_map_only(AsyncCollectiveTensor, unwrap, args) + unwrapped_kwargs = tree_map_only(AsyncCollectiveTensor, unwrap, kwargs) + + # we don't wrap the result as it doesn't need to be waited on. + out = func(*unwrapped_args, **unwrapped_kwargs) + + # View ops dont require a sync, so we should re-wrap the outputs. + if is_view_op: + out = tree_map_only(core.Tensor, wrap, out) + + return out + + def numpy(self): # type: ignore[override] + return self.wait().numpy() + + +""" +Utils and infrastructure for tracing support +""" + + +def _expand_group(group: RANK_TYPES, tag: str = "") -> Tuple[str, List[int], int]: + """ + _expand_group desugars the different RANK_TYPES types into a canonical format that is traceable. + + By having this be part of the explicit eager codepath, we avoid having to specialize behavior inside + torchdynamo and can still interoperate with processgroup objects or other untraceable forms. + """ + # had to define this hack _inside_ expand_group to avoid + # graph_break [('core.* op returned non-Tensor int + # caused by 'cast_*` functions being treated as 'core.*' ops (iiuc) + if TYPE_CHECKING: + + def cast_listlistint(x): + return cast(List[List[int]], x) + + def cast_listint(x): + return cast(List[int], x) + + else: + # fake cast op for use at runtime since dynamo doesn't support real cast + # also, dynamo didn't like encountering 'typing' objects () + # NotImplementedError: argument of type: + def cast_listlistint(x): + return x + + def cast_listint(x): + return x + + rankset: List[int] + if isinstance(group, list): + if isinstance(group[0], list): + nested_list = cast_listlistint(group) + rankset = [] + group_size = -1 + for rs in nested_list: + rankset.extend(rs) + if group_size != -1 and group_size != len(rs): + raise ValueError( + f"group sizes must be identical found {group_size} and {len(rs)}" + ) + group_size = len(rs) + else: + rankset = cast_listint(group) + group_size = len(rankset) + elif isinstance(group, dist.ProcessGroup): + rankset = dist.get_process_group_ranks(group) + group_size = len(rankset) + tag = tag or c10d._get_group_tag(group) + elif isinstance(group, DeviceMesh): + assert ( + group.ndim == 1 + ), "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D" + # TODO: it should run collective in the whole mesh instead of dim 0 + tag, rankset, _ = group._dim_group_infos[0] + group_size = len(rankset) + elif isinstance(group, tuple): + if ( + len(group) == 2 + and isinstance(group[0], DeviceMesh) + and isinstance(group[1], int) + ): + dmesh = group[0] + dim = group[1] + tag, rankset, _ = dmesh._dim_group_infos[dim] + group_size = len(rankset) + else: + raise ValueError("Invalid tuple for group must be (DeviceMesh, int)") + else: + raise ValueError( + "Invalid type for group, must be one of List, Processgroup, DeviceMesh or (DeviceMesh, int)." + ) + + return (tag, rankset, group_size) + + +def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> str: + """ + Given group in RANK_TYPES, return the group name. + """ + # `tag` will be deprecated. See details in: + # https://github.com/pytorch/pytorch/issues/93173#issuecomment-1907095208 + if isinstance(group, dist.ProcessGroup): + return group.group_name + elif isinstance(group, str): + return group + elif isinstance(group, DeviceMesh): + assert ( + group.ndim == 1 + ), "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D" + return group._dim_group_infos[0][2] + elif isinstance(group, tuple): + if ( + len(group) == 2 + and isinstance(group[0], DeviceMesh) + and isinstance(group[1], int) + ): + dmesh = group[0] + dim = group[1] + return dmesh._dim_group_infos[dim][2] + else: + raise ValueError("Invalid tuple for group must be (DeviceMesh, int)") + elif isinstance(group, list): + if not is_torchdynamo_compiling(): + warnings.warn( + "The combination of ranks + tag as process group " + "identifier has been deprecated. Please switch to " + "using ProcessGroup, DeviceMesh, or group name instead.", + FutureWarning, + stacklevel=3, + ) + return c10d._resolve_group_name_by_ranks_and_tag(cast(List[int], group), tag) + else: + raise ValueError(f"Unsupported group type: {type(group)}, {group}") + + +class _FromTorchTensor(core.autograd.Function): + """ + _FromTorchTensor allows autograd to propagate from a normal Tensor to an + AsyncCollectiveTensor. + """ + + @staticmethod + def forward( # type: ignore[override] + ctx, # pyre-ignore[2]: Parameter must be annotated. + input: core.Tensor, + ) -> core.Tensor: + return _maybe_wrap_tensor(input) + + @staticmethod + def backward(ctx, grad_output: core.Tensor) -> core.Tensor: # type: ignore[override] + return grad_output + + +def _are_we_tracing() -> bool: + if is_torchdynamo_compiling(): + return True + # If functionalization is turned on, we are almost definitely compiling/tracing. + # (In particular, AOTAutograd traces a model once with functionalization on + # but proxy tracing turned of, so this is how we detect it). + if ( + core._C._get_dispatch_mode(core._C._TorchDispatchModeKey.FUNCTIONAL) + is not None + ): + return True + return get_proxy_mode() is not None + + +def _maybe_wrap_tensor(self) -> core.Tensor: + if _are_we_tracing(): + return wait_tensor(self) + res = AsyncCollectiveTensor(self) + return cast(core.Tensor, res) + + +@contextlib.contextmanager +def allow_inflight_collective_as_graph_input_ctx(value: bool = True): + """ + Context manager to temporarily set whether inflight collectives are allowed as core.compile graph inputs. + Common use case is when the collective is issued in eager (with `async_op=True`) but waited in compiled region: + ``` + def all_reduce_eager(x): + y = x * x + req = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True) + return y + + @core.compile(fullgraph=True) + def all_reduce_wait_compiled(y): + core.ops.c10d_functional.wait_tensor(y) + return y * y + + x = core.ones(1280, 1280, device="cuda") + self.rank + # the context manager ensures that `wait_tensor(y)` will wait on the correct work object + with allow_inflight_collective_as_graph_input_ctx(): + y = all_reduce_eager(x) + z = all_reduce_wait_compiled(y) + ``` + With this context manager, when a collective is called, under the hood the work object of the collective + will be registered in the work registry, and the wait_tensor() in compiled region called on + the output tensor of the collective will wait on the correct work object. + """ + previous = core._C._distributed_c10d._allow_inflight_collective_as_graph_input() + + try: + core._C._distributed_c10d._set_allow_inflight_collective_as_graph_input(value) + yield + finally: + core._C._distributed_c10d._set_allow_inflight_collective_as_graph_input( + previous + ) + + +def _all_gather_into_tensor_coalesced_meta(self, tag, rankset, group_size): + def mk_out_tensor(shard): + out_size = list(shard.size()) + out_size[0] *= group_size + out_tensor = shard.new_empty(out_size) + return out_tensor + + return [mk_out_tensor(t) for t in self] + + +# We now register meta kernels to deal with tracing +def _broadcast_meta(self, *args): + return core.empty_like(self) + + +def _all_reduce_meta(self, *args): + return core.empty_like(self) + + +def _wait_tensor_meta(self, *args): + return core.empty_like(self) + + +def _all_gather_into_tensor_meta(shard, tag, rankset, group_size): + out_size = list(shard.size()) + out_size[0] *= group_size + return shard.new_empty(out_size) + + +def _reduce_scatter_tensor_meta(input, reduce_op, tag, rankset, group_size): + out_size = list(input.size()) + out_size[0] //= group_size + return input.new_empty(out_size) + + +def _all_reduce_coalesced_meta(self, *args): + return [core.empty_like(t) for t in self] + + +def _all_reduce__meta(inp, *args): + return inp + + +def _broadcast__meta(inp, *args): + return inp + + +def _all_reduce_coalesced__meta(inputs, *args): + return inputs + + +def _reduce_scatter_tensor_coalesced_meta(inputs, reduceOp, tag, rankset, group_size): + def mk_out_tensor(input): + out_size = list(input.size()) + out_size[0] //= group_size + out_tensor = input.new_empty(out_size) + return out_tensor + + return [mk_out_tensor(t) for t in inputs] + + +# NB: We often say all_to_all has dynamic output size, but this is not +# technically true: instead, what typically happens is you manually +# communicate the output_split_sizes ahead of time (which is dynamic), +# but then you pass those sizes explicitly, and the all to all itself +# isn't dynamic, it just follows the specified output splits +def _all_to_all_single_meta( + input, output_split_sizes, input_split_sizes, *args, **kwargs +): + if output_split_sizes is None: + return input.new_empty(input.size()) + else: + for s in output_split_sizes: + core._check_is_size(s) + out_size = list(input.size()) + out_size[0] = sum(output_split_sizes) + return input.new_empty(out_size) + + +def _all_gather_into_tensor_out_native_meta(input, group_size, group_name, *, out): + shape = list(input.size()) + shape[0] *= group_size + return input.new_empty(shape) + + +def _all_gather_into_tensor_native_meta(input, group_size, group_name): + shape = list(input.size()) + shape[0] *= group_size + return input.new_empty(shape) + + +def _all_gather_into_tensor_coalesced_native_meta(inputs, group_size, group_name): + return [ + _all_gather_into_tensor_native_meta(input, group_size, group_name) + for input in inputs + ] + + +def _reduce_scatter_tensor_native_meta(inp, reduce_op, group_size, group_name): + shape = list(inp.size()) + shape[0] //= group_size + return inp.new_empty(shape) + + +def _reduce_scatter_tensor_coalesced_native_meta( + inputs, reduce_op, group_size, group_name +): + return [ + _reduce_scatter_tensor_native_meta(inp, reduce_op, group_size, group_name) + for inp in inputs + ] + + +# if not core._running_with_deploy(): +# # Library MUST be defined at module scope or it doesn't work +# # Creating a "DEF" Library always crashes torch::deploy so we create our +# # Library instances here guarded against running inside it +# lib_impl = core.library.Library("_c10d_functional", "IMPL") +# lib_impl.impl("all_reduce", _all_reduce_meta, "Meta") +# lib_impl.impl("all_reduce_", _all_reduce__meta, "Meta") +# lib_impl.impl("all_reduce_coalesced", _all_reduce_coalesced_meta, "Meta") +# lib_impl.impl("all_reduce_coalesced_", _all_reduce_coalesced__meta, "Meta") +# lib_impl.impl("wait_tensor", _wait_tensor_meta, "Meta") +# lib_impl.impl( +# "all_gather_into_tensor_out", _all_gather_into_tensor_out_native_meta, "Meta" +# ) +# lib_impl.impl("all_gather_into_tensor", _all_gather_into_tensor_native_meta, "Meta") +# lib_impl.impl( +# "all_gather_into_tensor_coalesced", +# _all_gather_into_tensor_coalesced_native_meta, +# "Meta", +# ) +# lib_impl.impl("reduce_scatter_tensor", _reduce_scatter_tensor_native_meta, "Meta") +# lib_impl.impl( +# "reduce_scatter_tensor_coalesced", +# _reduce_scatter_tensor_coalesced_native_meta, +# "Meta", +# ) +# lib_impl.impl("all_to_all_single", _all_to_all_single_meta, "Meta") +# lib_impl.impl("broadcast", _broadcast_meta, "Meta") +# lib_impl.impl("broadcast_", _broadcast__meta, "Meta") + +# # mark these ops has side effect so that they won't be removed by DCE +# core.fx.node.has_side_effect(core.ops._c10d_functional.wait_tensor.default) +# core.fx.node.has_side_effect(core.ops._c10d_functional.wait_tensor) + +# # Register legacy ops for backward compatibility +# # TODO(yifu): remove these in functional collective beta release +# legacy_lib = core.library.Library("c10d_functional", "DEF") +# legacy_lib_impl = core.library.Library("c10d_functional", "IMPL") +# ops_defs = [ +# "broadcast(Tensor self, int src, str tag, int[] ranks, int group_size) -> Tensor", +# "all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor", +# "all_reduce_coalesced(Tensor[] self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]", +# "wait_tensor(Tensor self) -> Tensor", +# "all_gather_into_tensor(Tensor shard, str tag, int[] ranks, int group_size) -> Tensor", +# "all_gather_into_tensor_coalesced(Tensor[] input, str tag, int[] ranks, int group_size) -> Tensor[]", +# "reduce_scatter_tensor(Tensor input, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor", +# "reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]", +# "all_to_all_single(Tensor input, SymInt[]? output_split_sizes, SymInt[]? input_split_sizes, str tag, int[] ranks, int group_size) -> Tensor", # noqa: B950 +# ] + +# my_module = sys.modules[__name__] +# for op_def in ops_defs: +# op_name = op_def[0 : op_def.index("(")] +# backend_impl = getattr(fun_col_impl, f"_{op_name}") +# legacy_lib.define(op_def, tags=core.Tag.pt2_compliant_tag) +# legacy_lib_impl.impl(op_name, backend_impl, "CompositeImplicitAutograd") + +# else: +# warnings.warn( +# "PyTorch Distributed functional collectives do not work with torch::deploy." +# ) + + +""" +Dynamo Remappings allow seamless translation from non-functional collectives of supportable form into +functional collective calls followed by inplace copy ops, allowing them to be traced into a functional graph. + +We implement this by writing a decomposition and teaching dynamo how to associate it to a corresponding op via +the mapping dict below. + +These schemas intentionally match core.distributed.distributed_c10d.* ops that we are trying to remap from +""" + + +def all_gather_tensor_inplace( + output_tensor: core.Tensor, + input_tensor: core.Tensor, + group, # TODO add a type, + async_op: bool = False, + tag: str = "", + gather_dim: int = 0, +): + assert ( + not async_op + ), "Can't remap async version of inplace op to functional collective" + + group = group or dist.group.WORLD + assert group is not None + + return output_tensor.copy_(all_gather_tensor(input_tensor, gather_dim, group, tag)) + + +def reduce_scatter_tensor_inplace( + output: core.Tensor, + input: core.Tensor, + op: str = "sum", # TODO type is actually c10d ReduceOp. is this ok? + group=None, # TODO add a type + async_op: bool = False, + scatter_dim: int = 0, + tag: str = "", +): + assert ( + not async_op + ), "Can't remap async version of inplace op to functional collective" + + group = group or dist.group.WORLD + assert group is not None + + return output.copy_(reduce_scatter_tensor(input, op, scatter_dim, group, tag)) + + +REDUCE_OP_TO_STR = { + dist.ReduceOp.SUM: "sum", + dist.ReduceOp.AVG: "avg", + dist.ReduceOp.PRODUCT: "product", + dist.ReduceOp.MIN: "min", + dist.ReduceOp.MAX: "max", + dist.ReduceOp.BAND: "band", + dist.ReduceOp.BOR: "bor", + dist.ReduceOp.BXOR: "bxor", +} + + +def all_reduce_inplace( + tensor: core.Tensor, + op: str = "sum", + group=None, + async_op: bool = False, + tag: str = "", +): + assert ( + not async_op + ), "Can't remap async version of inplace op to functional collective" + + group = group or dist.group.WORLD + assert group is not None + + return tensor.copy_(all_reduce(tensor, op, group, tag)) + + +def all_to_all_inplace( + output: core.Tensor, + input: core.Tensor, + output_split_sizes=None, + input_split_sizes=None, + group=None, + async_op=False, + tag: str = "", +): + assert ( + not async_op + ), "Can't remap async version of inplace op to functional collective" + + group = group or dist.group.WORLD + assert group is not None + + return output.copy_( + all_to_all_single( + input, + output_split_sizes, + input_split_sizes, + group, + tag, + ) + ) + + +def all_gather_inplace( + tensor_list: List[core.Tensor], + tensor: core.Tensor, + group=None, + async_op=False, + tag: str = "", +): + assert ( + not async_op + ), "Can't remap async version of inplace op to functional collective" + assert all( + t.size(0) == tensor.size(0) for t in tensor_list + ), "Remapping variable size all_gather is not yet supported" + + group = group or dist.group.WORLD + assert group is not None + + output = all_gather_tensor(tensor, 0, group, tag) + + # Use aten.slice instead of aten.split because the latter causes + # tensor.shape(0) to be unnecessarily baked in when it's a SymInt. + output_splits = [] + offset = 0 + for t in tensor_list: + output_splits.append(output[offset : offset + t.size(0)]) + offset += t.size(0) + for dst, src in zip(tensor_list, output_splits): + dst.copy_(src) + return tensor_list + + +from core.distributed.distributed_c10d import ( + _all_gather_base as legacy_all_gather_base, + _reduce_scatter_base as legacy_reduce_scatter_base, + all_gather as legacy_all_gather, + all_gather_into_tensor as legacy_allgather, + all_reduce as legacy_allreduce, + all_to_all_single as legacy_all_to_all_single, + reduce_scatter_tensor as legacy_reducescatter, +) + + +# This dict should contain sets of functions that dynamo is allowed to remap. +# Functions in this set should accept the same args/kwargs 1:1 as their mapping. +traceable_collective_remaps = { + legacy_allgather: all_gather_tensor_inplace, + legacy_reducescatter: reduce_scatter_tensor_inplace, + legacy_allreduce: all_reduce_inplace, + legacy_all_to_all_single: all_to_all_inplace, + legacy_all_gather: all_gather_inplace, + legacy_reduce_scatter_base: reduce_scatter_tensor_inplace, + legacy_all_gather_base: all_gather_tensor_inplace, +} diff --git a/mindnlp/core/distributed/_functional_collectives_impl.py b/mindnlp/core/distributed/_functional_collectives_impl.py new file mode 100644 index 000000000..83f3ab722 --- /dev/null +++ b/mindnlp/core/distributed/_functional_collectives_impl.py @@ -0,0 +1,117 @@ +# mypy: allow-untyped-defs +from typing import List, Optional + +from mindnlp import core +from mindnlp import core.distributed.distributed_c10d as c10d + + +""" +This file contains the op impls for the legacy (c10d_functional) functional collectives. +These impls simply call into the native (_c10d_functional) functional collectives. +""" + + +def _broadcast(input, src, tag, ranks, group_size): + group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) + return core.ops._c10d_functional.broadcast( + input, + src, + group_name, + ) + + +def _all_reduce(input, reduce_op, tag, ranks, group_size): + group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) + return core.ops._c10d_functional.all_reduce( + input, + reduce_op, + group_name, + ) + + +def _all_reduce_coalesced(inputs, reduce_op, tag, ranks, group_size): + group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) + return core.ops._c10d_functional.all_reduce_coalesced( + inputs, + reduce_op, + group_name, + ) + + +def _all_gather_into_tensor(input, tag, ranks, group_size): + group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) + return core.ops._c10d_functional.all_gather_into_tensor( + input, + group_size, + group_name, + ) + + +def _all_gather_into_tensor_coalesced(input, tag, ranks, group_size): + group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) + return core.ops._c10d_functional.all_gather_into_tensor_coalesced( + input, + group_size, + group_name, + ) + + +def _reduce_scatter_tensor( + input: core.Tensor, + reduce_op: str, + tag: str, + ranks: List[int], + group_size: int, +): + group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) + return core.ops._c10d_functional.reduce_scatter_tensor( + input, + reduce_op, + group_size, + group_name, + ) + + +def _reduce_scatter_tensor_coalesced( + inputs: List[core.Tensor], + reduce_op: str, + tag: str, + ranks: List[int], + group_size: int, +): + group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) + return core.ops._c10d_functional.reduce_scatter_tensor_coalesced( + inputs, + reduce_op, + group_size, + group_name, + ) + + +def _all_to_all_single( + input: core.Tensor, + output_split_sizes: Optional[List[int]], + input_split_sizes: Optional[List[int]], + tag: str, + ranks: List[int], + group_size: int, +): + if output_split_sizes is None or input_split_sizes is None: + assert output_split_sizes is None and input_split_sizes is None, ( + "output_split_sizes and input_split_sizes must either be " + "specified together or both set to None" + ) + output_split_sizes = [input.shape[0] // group_size] * group_size + input_split_sizes = output_split_sizes + + group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) + return core.ops._c10d_functional.all_to_all_single( + input, + output_split_sizes, + input_split_sizes, + group_name, + ) + + +def _wait_tensor(tensor: core.Tensor) -> core.Tensor: + return core.ops._c10d_functional.wait_tensor(tensor) diff --git a/mindnlp/core/distributed/_shard/__init__.py b/mindnlp/core/distributed/_shard/__init__.py new file mode 100644 index 000000000..23542f81c --- /dev/null +++ b/mindnlp/core/distributed/_shard/__init__.py @@ -0,0 +1 @@ +from .api import _shard_tensor, load_with_process_group, shard_module, shard_parameter diff --git a/mindnlp/core/distributed/_shard/_utils.py b/mindnlp/core/distributed/_shard/_utils.py new file mode 100644 index 000000000..2d304e6c6 --- /dev/null +++ b/mindnlp/core/distributed/_shard/_utils.py @@ -0,0 +1,32 @@ +from typing import Sequence + +from mindnlp import core +from core.distributed._shard.metadata import ShardMetadata + + +DEPRECATE_MSG = "Please use DTensor instead and we are deprecating ShardedTensor." + + +def narrow_tensor_by_index( + tensor: core.Tensor, + offsets: Sequence[int], + sizes: Sequence[int], +) -> core.Tensor: + """ + Narrow the tensor according to ``offsets`` and ``sizes``. + """ + narrowed_tensor = tensor + for idx, (offset, size) in enumerate(zip(offsets, sizes)): + if size < tensor.size(idx): + # Reshape to get shard for this rank and we don't want autograd + # recording here for the narrow op and 'local_shard' should be a + # leaf variable in the autograd graph. + narrowed_tensor = narrowed_tensor.narrow(idx, offset, size) + return narrowed_tensor + + +def narrow_tensor(tensor: core.Tensor, metadata: ShardMetadata) -> core.Tensor: + """ + Narrow the tensor according to the metadata + """ + return narrow_tensor_by_index(tensor, metadata.shard_offsets, metadata.shard_sizes) diff --git a/mindnlp/core/distributed/_shard/api.py b/mindnlp/core/distributed/_shard/api.py new file mode 100644 index 000000000..08f9951e0 --- /dev/null +++ b/mindnlp/core/distributed/_shard/api.py @@ -0,0 +1,306 @@ +# mypy: allow-untyped-defs +from contextlib import contextmanager +from typing import Optional + +from mindnlp import core +from mindnlp import core.distributed as dist +from mindnlp import core.nn as nn +from core.distributed import distributed_c10d +from core.distributed._shard.sharded_tensor import ShardedTensor + +from .sharder import Sharder +from .sharding_plan import ShardingPlan +from .sharding_spec import ChunkShardingSpec, ShardingSpec + + +def _shard_tensor( + tensor: core.Tensor, sharding_spec: ShardingSpec, src_rank=0, process_group=None +) -> ShardedTensor: + """ + Given a :class:`core.Tensor`, it shards that tensor according to the provided + ``sharding_spec``. ``src_rank`` denotes the source rank which would be + used as the ground truth of the data which would be scattered as shards + across the rest of the ranks. + + Args: + tensor (:class:`core.Tensor`): Tensor needs to be sharded. + sharding_spec (:class:`core.distributed._shard.sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + + Keyword args: + src_rank (int, optional): The source rank which is used as the ground truth of + the data for the parameter that would be sharded and scattered + across the rest of the ranks. + Default: 0. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + + Returns: + A :class:`ShardedTensor` sharded from the given tensor. + + .. warning:: + Only :class:`core.distributed._shard.sharding_spec.ChunkShardingSpec` is + currently supported as the ``sharding_spec``. + """ + if not tensor.is_contiguous(): + raise ValueError("input tensor is not a contiguous Tensor") + + pg = ( + process_group + if process_group is not None + else distributed_c10d._get_default_group() + ) + world_size = dist.get_world_size(pg) + current_rank = dist.get_rank(pg) + + # Validate src_rank and sharding_spec are same across all ranks. + gathered_list = [None] * world_size + dist.all_gather_object(gathered_list, (src_rank, sharding_spec), group=pg) + + for idx, entry in enumerate(gathered_list): + if src_rank != entry[0]: # type: ignore[index] + raise ValueError( + f"src_rank={src_rank} on rank: {current_rank} does not " # type: ignore[index] + f"match with src_rank={entry[0]} on rank: {idx}" # type: ignore[index] + ) + if sharding_spec != entry[1]: # type: ignore[index] + raise ValueError( + f"sharding_spec={sharding_spec} on rank: {current_rank} does not " # type: ignore[index] + f"match with sharding_spec={entry[1]} on rank: {idx}" # type: ignore[index] + ) + + st = sharding_spec.shard(tensor, src_rank=src_rank, process_group=pg) + + return st + + +def shard_parameter( + module: core.nn.Module, + param_name: str, + sharding_spec: ShardingSpec, + src_rank=0, + process_group=None, +): + """ + Given a :class:`core.nn.Module`, a ``param_name`` for a parameter in that + module, it shards that parameter according to the provided + ``sharding_spec``. ``src_rank`` denotes the source rank which would be + used as the ground truth of the data which would be scattered as shards + across the rest of the ranks. + + This method replaces ``module.param_name`` with a + :class:`core.distributed._sharded_tensor.ShardedTensor` + + Args: + module (:class:`core.nn.Module`): Module whose parameter needs to be sharded. + param_name (str): Name of the parameter of ``module`` that needs to be sharded. + sharding_spec (:class:`core.distributed._shard.sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + + Keyword args: + src_rank (int, optional): The source rank which is used as the ground truth of + the data for the parameter that would be sharded and scattered + across the rest of the ranks. + Default: 0. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + + .. warning:: + Only :class:`core.distributed._shard.sharding_spec.ChunkShardingSpec` is + currently supported as the ``sharding_spec``. + """ + # Perform some validation first. + if not hasattr(module, param_name): + raise AttributeError(f"{module._get_name()} has no attribute `{param_name}`") + + tensor = getattr(module, param_name) + if not isinstance(tensor, core.Tensor): + raise ValueError( + f"Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}" + ) + + if not tensor.is_contiguous(): + raise ValueError(f"param: {param_name} is not a contiguous Tensor") + + st = _shard_tensor(tensor, sharding_spec, src_rank, process_group) + + # Replace param with ShardedTensor. + module.register_parameter(param_name, nn.Parameter(st)) + + +# Tracks the current process group in the load context manager. +_CURRENT_PROCESS_GROUP: Optional[dist.ProcessGroup] = None + + +@contextmanager +def load_with_process_group(process_group): + """ + Context manager to set the process group with which to load a ShardedTensor. + """ + global _CURRENT_PROCESS_GROUP + if _CURRENT_PROCESS_GROUP is not None: + raise RuntimeError( + 'ProcessGroup already set by previous "load_with_process_group" ' + "context manager" + ) + _CURRENT_PROCESS_GROUP = process_group + try: + yield process_group + finally: + _CURRENT_PROCESS_GROUP = None + + +def _get_current_process_group(): + """ + Retrieves the current process group set by ``load_with_process_group``. + If not set, it just returns the default group. + """ + global _CURRENT_PROCESS_GROUP + if _CURRENT_PROCESS_GROUP is None: + return distributed_c10d._get_default_group() + else: + return _CURRENT_PROCESS_GROUP + + +def _reshard_output( + module: core.nn.Module, resharding_spec: ShardingSpec +) -> core.nn.Module: + """ + Hook a module with output resharding in the forward pass according + to the given ``resharding_spec``. + + Args: + module (:class:`core.nn.Module`): Module whose output needs to be resharded. + resharding_spec (:class:`core.distributed._shard.sharding_spec.ShardingSpec`): + The specification describing how the output of the module will be resharded. + + Returns: + A :class:`core.nn.Module` object with reshard API hooked. + """ + + def hook_func(_module, _input, output): + if isinstance(output, ShardedTensor): + return output.reshard(resharding_spec) + return output + + module.register_forward_hook(hook_func) + return module + + +def _collect_local_shard(module: core.nn.Module) -> core.nn.Module: + """ + Hook a module with local shards collection in the forward pass. + + This API is typically used to convert a sharded representation back to data parallel + representation. In particular, it returns the local tensor for this Shard. If the + size along the sharding dimension for the local tensor is 1, this dimension is removed + from the final result. For example a [4, 16] ShardedTensor across 4 ranks is typically + a local Tensor of size [16] across each rank and not [1, 16] across each rank. + + Args: + module (:class:`core.nn.Module`): Module whose output is ShardedTensor and the + local tensor value needs to be returned. + + Returns: + A :class:`core.nn.Module` object with collection API hooked. + """ + + def hook_func(_module, _input, output): + if isinstance(output, ShardedTensor): + local_tensor = output.local_tensor() + # Squeeze the # of dimensions manually, only applicable to ChunkShardingSpec + sharding_spec = output._sharding_spec + if ( + isinstance(sharding_spec, ChunkShardingSpec) + and local_tensor.size(sharding_spec.dim) == 1 # type: ignore[attr-defined, arg-type] + ): + local_tensor = local_tensor.squeeze( + output._sharding_spec.dim # type: ignore[attr-defined] + ) + return local_tensor + + module.register_forward_hook(hook_func) + return module + + +def shard_module(module: nn.Module, plan: ShardingPlan, src_rank=0, process_group=None): + """ + Shards a given module according to the provided sharding `plan`. This method + first shards all the parameters according to the given sharding `plan`. Then if + `output_plan` and `return_local_tensor` are specified in the sharding `plan`, it + will tag the output of modules according `output_plan`, convert the module's + output back to data parallel according to `return_local_tensor`. + + Needs to be called on all ranks in an SPMD fashion. + + Args: + module (:class:`core.nn.Module`): The module to apply sharding to + plan (:class:`core.distributed._shard.sharding_plan.ShardingPlan`): + The ShardingPlan which specified param name to ShardingSpec to apply to + each parameter. + + Keyword args: + src_rank (int, optional): The source rank which is used as the ground truth of + the data for the module that would be sharded and scattered across the rest + of the ranks. + Default: 0. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + """ + # record Sharder paths for sanity check on the plan to ensure items in the plan + # does not conflict with the submodule tree that the Sharder is working with + sharder_paths = [] + for name, spec in plan.plan.items(): + if isinstance(spec, Sharder): + sharder_paths.append(name) + + # shard the parameter according to the ShardingPlan + for name, spec in plan.plan.items(): + if isinstance(spec, ShardingSpec): + # if found a sharding spec, try to shard the parameter + module_path, _, param_name = name.rpartition(".") + + for sharder_path in sharder_paths: + if module_path.startswith(sharder_path): + raise RuntimeError( + f"ShardingPlan is in-valid, trying to shard a parameter: {name}," + f" but there's already a Sharder entry for module {sharder_path}," + f" parameter sharding should not conflict with the submodule tree" + f" that a Sharder is working with!" + ) + + mod = module.get_submodule(module_path) + shard_parameter( + mod, param_name, spec, src_rank=src_rank, process_group=process_group + ) + elif isinstance(spec, Sharder): + parent_mod_path, _, _mod_name = name.rpartition(".") + if name == "": + raise KeyError("Module path must not be empty for custom sharder!") + mod = module.get_submodule(name) + parent_mod = module.get_submodule(parent_mod_path) + sharded_mod = spec.shard(mod) + # swap this submodule with the sharded module + parent_mod.mod_name = sharded_mod + else: + raise TypeError( + f"Only `ShardingSpec` and `Sharder` are supported to shard '{name}'" + ) + + # reshard output if there's an entry in `reshard_output` for this module + if plan.output_plan is not None: + for module_path, output_spec in plan.output_plan.items(): + if isinstance(output_spec, ShardingSpec): + mod = module.get_submodule(module_path) + _reshard_output(mod, output_spec) + else: + raise TypeError( + f"Only `ShardingSpec` is supported as output_plan for '{module_path}'" + ) + # convert the output back to data parallel for the modules appears in + # `return_local_tensor` of the plan, we will call `_collect_local_shard` + # to collect the local tensor for output of modules + if plan.return_local_tensor is not None: + for module_path in plan.return_local_tensor: + mod = module.get_submodule(module_path) + _collect_local_shard(mod) diff --git a/mindnlp/core/distributed/_shard/checkpoint/__init__.py b/mindnlp/core/distributed/_shard/checkpoint/__init__.py new file mode 100644 index 000000000..03afeffc6 --- /dev/null +++ b/mindnlp/core/distributed/_shard/checkpoint/__init__.py @@ -0,0 +1,19 @@ +# Keep old package for BC purposes, this file should be removed once +# everything moves to the `core.distributed.checkpoint` package. +import sys +import warnings + +from mindnlp import core +from core.distributed.checkpoint import * # noqa: F403 + + +with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "`core.distributed._shard.checkpoint` will be deprecated, " + "use `core.distributed.checkpoint` instead", + DeprecationWarning, + stacklevel=2, + ) + +sys.modules["core.distributed._shard.checkpoint"] = core.distributed.checkpoint diff --git a/mindnlp/core/distributed/_shard/common_op_utils.py b/mindnlp/core/distributed/_shard/common_op_utils.py new file mode 100644 index 000000000..8f333d4b1 --- /dev/null +++ b/mindnlp/core/distributed/_shard/common_op_utils.py @@ -0,0 +1,65 @@ +# mypy: allow-untyped-defs +from typing import Optional + +from mindnlp import core +from core.utils import _pytree as pytree + + +def _basic_validation(op, args=(), kwargs=None): + """ + Common validation across all ops go in here. + """ + from core.distributed._shard.sharded_tensor import ShardedTensor + + if len(args) == 0 and (kwargs is None or len(kwargs) == 0): + raise ValueError(f" No input for '{op.__name__}'!") + + # Validate types + has_distributed_tensor = False + + def is_distributed_tensor(e): + nonlocal has_distributed_tensor + if isinstance(e, ShardedTensor): + has_distributed_tensor = True + + pytree.tree_map_(is_distributed_tensor, args) + pytree.tree_map_(is_distributed_tensor, kwargs) + + if not has_distributed_tensor: + raise TypeError( + f"torch function '{op.__name__}', with args: {args} and " + f"kwargs: {kwargs} are called without any distributed tensor!" + ) + + # Validate all distributed tensors use the same PG. + cur_pg: Optional[core.distributed.ProcessGroup] = None + + def validate_pg(e): + nonlocal cur_pg + if isinstance(e, ShardedTensor): + if cur_pg is not None and e._process_group is not cur_pg: + raise RuntimeError( + "All distributed tensors should use the " + "same ProcessGroup if used together in an op." + ) + cur_pg = e._process_group + + pytree.tree_map_(validate_pg, args) + pytree.tree_map_(validate_pg, kwargs) + + +def _register_default_op(op, decorator): + @decorator(op) + def tensor_default_op(types, args=(), kwargs=None, pg=None): + """ + Handles ``__torch_function__`` dispatch for the default tensor ops that + behave the same as ``core.Tensor`` such as ``core.Tensor.shape`` or + ``core.Tensor.dtype``. We simply lower to the real op call with + DisableTorchFunctionSubclass context like ``core.Tensor.__torch_function__`` + to avoid recursions. + """ + if kwargs is None: + kwargs = {} + + with core._C.DisableTorchFunctionSubclass(): + return op(*args, **kwargs) diff --git a/mindnlp/core/distributed/_shard/metadata.py b/mindnlp/core/distributed/_shard/metadata.py new file mode 100644 index 000000000..5f9672b27 --- /dev/null +++ b/mindnlp/core/distributed/_shard/metadata.py @@ -0,0 +1,64 @@ +# mypy: allow-untyped-defs +from dataclasses import dataclass +from functools import reduce +from typing import List, Optional, Union + +from core.distributed.remote_device import _remote_device + + +@dataclass +class ShardMetadata: + """ + Represents a shard of the overall Tensor including its + offsets, lengths and device placement. + + Args: + shard_offsets(List[int]): Offsets in the original tensor indicating + the start offsets for this shard. Should have the same rank as + the original tensor. + shard_sizes(List[int]): Integers indicating the size of each + dimension for this shard. Should have the same rank as the + original tensor. + placement(:class:`core.distributed._remote_device`): + Specifies the placement of this shard. + """ + + __slots__ = ["shard_offsets", "shard_sizes", "placement"] + + shard_offsets: List[int] + shard_sizes: List[int] + placement: Optional[_remote_device] + + def __init__( + self, + shard_offsets: List[int], + shard_sizes: List[int], + placement: Optional[Union[str, _remote_device]] = None, + ): + self.shard_offsets = shard_offsets + self.shard_sizes = shard_sizes + if isinstance(placement, str): + self.placement = _remote_device(placement) + else: + self.placement = placement + if len(self.shard_offsets) != len(self.shard_sizes): + raise ValueError( + f"shard_offsets and shard_sizes should have " + f"the same number of elements, found {len(self.shard_offsets)} " + f"and {self.shard_sizes} respectively" + ) + + for i in range(len(self.shard_offsets)): + if self.shard_offsets[i] < 0: + raise ValueError("shard_offsets should be >=0") + if self.shard_sizes[i] < 0: + raise ValueError("shard_sizes should be >= 0") + + def __hash__(self): + def _hash_reduce(a, b): + return (a << 8) + hash(b) + + res = reduce(_hash_reduce, self.shard_offsets, 37) + res = reduce(_hash_reduce, self.shard_sizes, res) + res = _hash_reduce(res, self.placement) + return res diff --git a/mindnlp/core/distributed/_shard/op_registry_utils.py b/mindnlp/core/distributed/_shard/op_registry_utils.py new file mode 100644 index 000000000..669af1cbb --- /dev/null +++ b/mindnlp/core/distributed/_shard/op_registry_utils.py @@ -0,0 +1,41 @@ +# mypy: allow-untyped-defs +import functools +from inspect import signature + +from .common_op_utils import _basic_validation + + +""" +Common utilities to register ops on ShardedTensor +and PartialTensor. +""" + + +def _register_op(op, func, op_table): + """ + Performs basic validation and registers the provided op in the given + op_table. + """ + if len(signature(func).parameters) != 4: + raise TypeError( + f"Custom sharded op function expects signature: " + f"(types, args, kwargs, process_group), but received " + f"signature: {signature(func)}" + ) + + op_table[op] = func + + +def _decorator_func(wrapped_func, op, op_table): + """ + Decorator function to register the given ``op`` in the provided + ``op_table`` + """ + + @functools.wraps(wrapped_func) + def wrapper(types, args, kwargs, process_group): + _basic_validation(op, args, kwargs) + return wrapped_func(types, args, kwargs, process_group) + + _register_op(op, wrapper, op_table) + return wrapper diff --git a/mindnlp/core/distributed/_shard/sharded_optim/__init__.py b/mindnlp/core/distributed/_shard/sharded_optim/__init__.py new file mode 100644 index 000000000..85310cc6f --- /dev/null +++ b/mindnlp/core/distributed/_shard/sharded_optim/__init__.py @@ -0,0 +1,52 @@ +from typing import Iterator, Tuple, Union + +from mindnlp import core.nn as nn +from core.distributed._shard.sharded_tensor import ShardedTensor + +from .api import ShardedOptimizer + + +def named_params_with_sharded_tensor( + module: nn.Module, + prefix: str = "", + recurse: bool = True, +) -> Iterator[Tuple[str, Union[nn.Parameter, ShardedTensor]]]: + r"""Returns an iterator over module parameters (together with the + ShardedTensor parameters), yielding both the name of the parameter + as well as the parameter itself. This is typically passed to a + :class:core.distributed._shard.sharded_optim.ShardedOptimizer + + Args: + prefix (str): prefix to prepend to all parameter names. + recurse (bool): if True, then yields parameters of this module + and all submodules. Otherwise, yields only parameters that + are direct members of this module. + + Yields: + (str, Union[Tensor, ShardedTensor]): Tuple containing + the name and parameter (or ShardedTensor parameter) + + Example:: + + >>> # xdoctest: +SKIP + >>> model = core.nn.Linear(*linear_size) + >>> shard_parameter(model, "weight", spec) + >>> for name, param in named_params_with_sharded_tensor(model): + >>> if name in ['weight']: + >>> print(param.size()) + + """ + modules = module.named_modules(prefix=prefix) if recurse else [(prefix, module)] + + memo = set() + for mod_prefix, mod in modules: + # find all sharded tensor params + for name, val in vars(mod).items(): + if isinstance(val, ShardedTensor) and val not in memo: + memo.add(val) + name = mod_prefix + ("." if mod_prefix else "") + name + yield name, val + + # find all nn.Parameters + for name, val in module.named_parameters(): + yield name, val diff --git a/mindnlp/core/distributed/_shard/sharded_optim/api.py b/mindnlp/core/distributed/_shard/sharded_optim/api.py new file mode 100644 index 000000000..6c82eb0c3 --- /dev/null +++ b/mindnlp/core/distributed/_shard/sharded_optim/api.py @@ -0,0 +1,101 @@ +# mypy: allow-untyped-defs +from typing import Any, Dict, List, Mapping, Union + +from mindnlp import core.optim as optim +from mindnlp.core import Tensor +from core.distributed._shard.sharded_tensor import ShardedTensor + + +class ShardedOptimizer(optim.Optimizer): + def __init__( + self, + named_params: Mapping[str, Union[Tensor, ShardedTensor]], + optimizer_class, + *optimizer_args, + **optimizer_kwargs, + ): + """ + ShardedOptimizer collects all tensors and local shard tensors of + ShardedTensor, then use these tensors as ``params`` for optimizers + + Args: + named_params (Dict[str, Union[Tensor, ShardedTensor]]) : a Dict + of parameters, where key is the parameter key, value is either + Tensor or ShardedTensor parameter. + optimizer_class (core.optim.Optimizer): the Optimizer to use + locally, i.e. core.optim.SGD, core.optim.Adagrad, etc. + *optimizer_args: the arguments to initialize the optimizer. + **optimizer_kwargs: the key-word arguments to initialize the optimizer. + + """ + tensors: List[Tensor] = [] + for value in named_params.values(): + if isinstance(value, ShardedTensor): + tensors.extend( + local_shard.tensor for local_shard in value.local_shards() + ) + else: + tensors.append(value) + + self.named_params = named_params + self._optim = optimizer_class(tensors, *optimizer_args, **optimizer_kwargs) + self.param_groups = self._optim.param_groups + self.state = self._optim.state + + def zero_grad(self, set_to_none: bool = True): # type: ignore[override] + r"""Resets the gradients of all optimized :class:`core.Tensor` s. + + Args: + set_to_none (bool): instead of setting to zero, set the grads to None. + This will in general have lower memory footprint, and can modestly improve performance. + However, it changes certain behaviors. For example: + 1. When the user tries to access a gradient and perform manual ops on it, + a None attribute or a Tensor full of 0s will behave differently. + 2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s + are guaranteed to be None for params that did not receive a gradient. + 3. ``core.optim`` optimizers have a different behavior if the gradient is 0 or None + (in one case it does the step with a gradient of 0 and in the other it skips + the step altogether). + """ + self._optim.zero_grad(set_to_none) + + def step(self, closure=None): + r"""Performs a single optimization step (parameter update). + + Args: + closure (Callable): A closure that reevaluates the model and + returns the loss. Optional for most optimizers. + + .. note:: + Unless otherwise specified, this function should not modify the + ``.grad`` field of the parameters. + """ + self._optim.step(closure) + + def state_dict(self) -> Dict[str, Any]: + """ + Returned state and param_groups will contain parameter keys + instead of parameter indices like core.optim.Optimizer. + This allows for advanced functionality like optimizer re-sharding to be implemented. + """ + # TODO: implement state_dict + raise NotImplementedError("ShardedOptimizer state_dict not implemented yet!") + + def load_state_dict(self, state_dict: Mapping[str, Any]): + r"""Loads the ShardedOptimizer state. + + Args: + state_dict (dict): ShardedOptimizer state. Should be an object returned + from a call to :meth:`state_dict`. + """ + # TODO: implement load_state_dict + raise NotImplementedError( + "ShardedOptimizer load_state_dict not implemented yet!" + ) + + def add_param_group(self, param_group: Any): + r"""Add a new param group""" + # TODO: implement add_param_group + raise NotImplementedError( + "ShardedOptimizer add_param_group not implemented yet!" + ) diff --git a/mindnlp/core/distributed/_shard/sharded_tensor/__init__.py b/mindnlp/core/distributed/_shard/sharded_tensor/__init__.py new file mode 100644 index 000000000..2135061d7 --- /dev/null +++ b/mindnlp/core/distributed/_shard/sharded_tensor/__init__.py @@ -0,0 +1,490 @@ +# mypy: allow-untyped-defs +import functools +from typing import List, TYPE_CHECKING + +from mindnlp import core +from core.distributed._shard.op_registry_utils import _decorator_func + +from .api import ( + _CUSTOM_SHARDED_OPS, + _SHARDED_OPS, + Shard, + ShardedTensor, + ShardedTensorBase, + ShardedTensorMetadata, + TensorProperties, +) +from .metadata import ShardMetadata # noqa: F401 + + +if TYPE_CHECKING: + from core.distributed._shard.sharding_spec import ShardingSpec +else: + ShardingSpec = "ShardingSpec" + + +def empty( + sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=core.strided, + requires_grad=False, + pin_memory=False, + memory_format=core.contiguous_format, + process_group=None, + init_rrefs=False, +) -> ShardedTensor: + """ + Returns a :class:`ShardedTensor` filled with uninitialized data. + Needs to be called on all ranks in an SPMD fashion. + + Args: + sharding_spec (:class:`core.distributed._shard.sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + size (int...): a sequence of integers defining the shape of the output + tensor. Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + dtype (:class:`core.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`core.set_default_dtype`). + layout (:class:`core.layout`, optional): the desired layout of returned Tensor. + Default: ``core.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + memory_format (:class:`core.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``core.contiguous_format``. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + init_rrefs (bool, optional): Whether or not to initialize + :class:`core.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + Returns: + A :class:`ShardedTensor` object on each rank + """ + return ShardedTensor( + sharding_spec, + *size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + pin_memory=pin_memory, + memory_format=memory_format, + process_group=process_group, + init_rrefs=init_rrefs, + ) + + +def ones( + sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=core.strided, + requires_grad=False, + pin_memory=False, + memory_format=core.contiguous_format, + process_group=None, + init_rrefs=False, +) -> ShardedTensor: + """ + Returns a :class:`ShardedTensor` with the scalar value 1. + Needs to be called on all ranks in an SPMD fashion. + + Args: + sharding_spec (:class:`core.distributed._shard.sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + size (int...): a sequence of integers defining the shape of the output + tensor. Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + dtype (:class:`core.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`core.set_default_dtype`). + layout (:class:`core.layout`, optional): the desired layout of returned Tensor. + Default: ``core.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + init_rrefs (bool, optional): Whether or not to initialize + :class:`core.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + Returns: + A :class:`ShardedTensor` object on each rank + """ + return full( + sharding_spec, + size, + fill_value=1, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + pin_memory=pin_memory, + memory_format=memory_format, + process_group=process_group, + init_rrefs=init_rrefs, + ) + + +def zeros( + sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=core.strided, + requires_grad=False, + pin_memory=False, + memory_format=core.contiguous_format, + process_group=None, + init_rrefs=False, +) -> ShardedTensor: + """ + Returns a :class:`ShardedTensor` filled with the scalar value 0. + Needs to be called on all ranks in an SPMD fashion. + + Args: + sharding_spec (:class:`core.distributed._shard.sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + size (int...): a sequence of integers defining the shape of the output + tensor. Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + dtype (:class:`core.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`core.set_default_dtype`). + layout (:class:`core.layout`, optional): the desired layout of returned Tensor. + Default: ``core.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + init_rrefs (bool, optional): Whether or not to initialize + :class:`core.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + Returns: + A :class:`ShardedTensor` object on each rank + """ + return full( + sharding_spec, + size, + fill_value=0, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + pin_memory=pin_memory, + memory_format=memory_format, + process_group=process_group, + init_rrefs=init_rrefs, + ) + + +def full( + sharding_spec: ShardingSpec, + size, + fill_value, + *, + dtype=None, + layout=core.strided, + requires_grad=False, + pin_memory=False, + memory_format=core.contiguous_format, + process_group=None, + init_rrefs=False, +) -> ShardedTensor: + """ + Creates a :class:`ShardedTensor` filled with fill_value. The tensor's dtype + is inferred from fill_value. If dtype is specified, it will override the + inferred type from fill_value. Needs to be called on all ranks in an SPMD fashion. + Args: + sharding_spec (:class:`core.distributed._sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + size (int...): a list, tuple, or `core.Size` of integers defining the shape of the + output tensor. + fill_value (Scalar) - the value to fill the output tensor with. + Keyword args: + dtype (:class:`core.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`core.set_default_dtype`). + layout (:class:`core.layout`, optional): the desired layout of returned Tensor. + Default: ``core.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + init_rrefs (bool, optional): Whether or not to initialize + :class:`core.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + Returns: + A :class:`ShardedTensor` object on each rank + """ + sharded_tensor = ShardedTensor( + sharding_spec, + *size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + pin_memory=pin_memory, + memory_format=memory_format, + process_group=process_group, + init_rrefs=init_rrefs, + ) + core.nn.init.constant_(sharded_tensor, fill_value) # type: ignore[arg-type] + return sharded_tensor + + +def rand( + sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=core.strided, + requires_grad=False, + pin_memory=False, + memory_format=core.contiguous_format, + process_group=None, + init_rrefs=False, +) -> ShardedTensor: + """ + Creates a :class:`ShardedTensor` filled with random numbers from a uniform distribution + on the interval :math:`[0, 1)`. The shape of the tensor is defined by the + variable argument `size`. Needs to be called on all ranks in an SPMD fashion. + + Args: + sharding_spec (:class:`core.distributed._shard.sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + size (int...): a list, tuple, or `core.Size` of integers defining the shape of the + output tensor. + + Keyword args: + dtype (:class:`core.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`core.set_default_dtype`). + layout (:class:`core.layout`, optional): the desired layout of returned Tensor. + Default: ``core.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + init_rrefs (bool, optional): Whether or not to initialize + :class:`core.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + Returns: + A :class:`ShardedTensor` object on each rank + """ + sharded_tensor = ShardedTensor( + sharding_spec, + *size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + pin_memory=pin_memory, + memory_format=memory_format, + process_group=process_group, + init_rrefs=init_rrefs, + ) + core.nn.init.uniform_(sharded_tensor, 0, 1) # type: ignore[arg-type] + return sharded_tensor + + +def randn( + sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=core.strided, + requires_grad=False, + pin_memory=False, + memory_format=core.contiguous_format, + process_group=None, + init_rrefs=False, +) -> ShardedTensor: + """ + Creates a :class:`ShardedTensor` filled with random numbers from a uniform distribution + with mean `0` and variance `1` (also called standard normal distribution). The shape + of the tensor is defined by the variable argument `size`. Needs to be called on all ranks + in an SPMD fashion. + + Args: + sharding_spec (:class:`core.distributed._shard.sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + size (int...): a list, tuple, or `core.Size` of integers defining the shape of the + output tensor. + + Keyword args: + dtype (:class:`core.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`core.set_default_dtype`). + layout (:class:`core.layout`, optional): the desired layout of returned Tensor. + Default: ``core.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + init_rrefs (bool, optional): Whether or not to initialize + :class:`core.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + Returns: + A :class:`ShardedTensor` object on each rank + """ + sharded_tensor = ShardedTensor( + sharding_spec, + *size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + pin_memory=pin_memory, + memory_format=memory_format, + process_group=process_group, + init_rrefs=init_rrefs, + ) + core.nn.init.normal_(sharded_tensor, 0, 1) # type: ignore[arg-type] + return sharded_tensor + + +def init_from_local_shards( + local_shards: List[Shard], *global_size, process_group=None, init_rrefs=False +) -> ShardedTensor: + """ + Creates an :class:`ShardedTensor` from local shards and the global metadata. + Needs to be called on all ranks in an SPMD fashion. + + Args: + local_shards (List[:class `core.distributed._shard.sharded_tensor.Shard`]): A list + of shards that represent the local shards on this rank. + global_size (int...): a list, tuple, or `core.Size` of integers defining the + shape of the overall sharded tensor. + + Keyword args: + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + init_rrefs (bool, optional): Whether or not to initialize + :class:`core.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + Returns: + A :class:`ShardedTensor` object handle on this rank + + + Examples: + Suppose we want construct a sharded tensor on two ranks, global size = (10, 5), + each shard have a (5, 5) local tensor, we can do it like below: + + on rank 0: + >>> # xdoctest: +SKIP("not distributed") + >>> local_shard_metadata = ShardMetadata( + >>> shard_offsets=[0, 0], + >>> shard_lengths=[5, 5], + >>> placement="rank:0/cuda:0" + >>> ) + >>> local_shards = [Shard(core.randn(5, 5), local_shard_metadata)] + >>> sharded_tensor = init_from_local_shards(local_shards, [10, 5]) + + on rank 1: + >>> # xdoctest: +SKIP("not distributed") + >>> local_shard_metadata = ShardMetadata( + >>> shard_offsets=[5, 0], + >>> shard_lengths=[5, 5], + >>> placement="rank:1/cuda:1" + >>> ) + >>> local_shards = [Shard(core.randn(5, 5), local_shard_metadata)] + >>> sharded_tensor = init_from_local_shards(local_shards, [10, 5]) + """ + return ShardedTensor._init_from_local_shards( + local_shards, *global_size, process_group=process_group, init_rrefs=init_rrefs + ) + + +def state_dict_hook(module, destination, prefix, local_metadata): + """ + Hook to add ShardedTensor to Module's ``state_dict``. Needs to be + registered to the Module using + :meth:`core.nn.Module._register_state_dict_hook`. + """ + for submodule_name, submodule in module.named_modules(): + for attr_name, attr in submodule.__dict__.items(): + if isinstance(attr, ShardedTensor): + mod_prefix = prefix + submodule_name + key = mod_prefix + ("." if mod_prefix else "") + attr_name + destination[key] = attr + + +def pre_load_state_dict_hook( + module, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, +): + """ + Pre-load state dict hook to add ShardedTensor to the module. + """ + for submodule_name, submodule in module.named_modules(): + for attr_name in submodule.__dict__.keys(): + mod_prefix = prefix + submodule_name + key = mod_prefix + ("." if mod_prefix else "") + attr_name + if key in state_dict: + if isinstance(state_dict[key], ShardedTensor): + setattr(submodule, attr_name, state_dict[key]) + + +def custom_sharded_op_impl(func): + """ + Provides a way for users to write their own custom sharded operator. This + can be used to override existing ShardedTensor operators or write a new + one not supported by ShardedTensor. If the operator in question is covered + by ``__torch_function__`` dispatch and has a ShardedTensor as any of its + parameters, the function provided will be invoked for that operator. + + Example:: + >>> # xdoctest: +SKIP + >>> @custom_sharded_op_impl(core.nn.functional.linear) + >>> def my_custom_sharded_linear(types, args, kwargs, process_group): + >>> ... + >>> # xdoctest: +SKIP("Undefined variables") + >>> input = core.rand(10, 32) + >>> weight = sharded_tensor.rand(32, 16) + >>> bias = core.rand(16) + >>> # This will call 'my_custom_sharded_linear' + >>> core.nn.functional.linear(input, weight, bias) + + The types, args and kwargs parameters are the same parameters that are + passed to ``__torch_function__`` dispatch API + (https://pycore.org/docs/stable/notes/extending.html#extending-torch). + There is an additional ``process_group`` parameter which is the + process_group used for the ShardedTensor and can be used by + implementations for communications within a sharded implementation. + + Args: + func(Callable): Torch function for which we want to provide a sharded + implementation (ex: core.nn.functional.linear) + """ + return functools.partial(_decorator_func, op=func, op_table=_CUSTOM_SHARDED_OPS) + + +def _sharded_op_impl(func): + """ + Decorator to register a default sharded op. + """ + return functools.partial(_decorator_func, op=func, op_table=_SHARDED_OPS) + + +# Import all builtin sharded ops +from ._ops import * # noqa: F403 diff --git a/mindnlp/core/distributed/_shard/sharded_tensor/_ops/__init__.py b/mindnlp/core/distributed/_shard/sharded_tensor/_ops/__init__.py new file mode 100644 index 000000000..b34a3de1d --- /dev/null +++ b/mindnlp/core/distributed/_shard/sharded_tensor/_ops/__init__.py @@ -0,0 +1,13 @@ +from mindnlp import core.distributed._shard.sharded_tensor._ops.misc_ops +from mindnlp import core.distributed._shard.sharded_tensor._ops.tensor_ops + +# Import all ChunkShardingSpec ops +from core.distributed._shard.sharding_spec.chunk_sharding_spec_ops.embedding import ( + sharded_embedding, +) +from core.distributed._shard.sharding_spec.chunk_sharding_spec_ops.embedding_bag import ( + sharded_embedding_bag, +) + +from .binary_cmp import allclose, equal +from .init import constant_, kaiming_uniform_, normal_, uniform_ diff --git a/mindnlp/core/distributed/_shard/sharded_tensor/_ops/_common.py b/mindnlp/core/distributed/_shard/sharded_tensor/_ops/_common.py new file mode 100644 index 000000000..ee29a4393 --- /dev/null +++ b/mindnlp/core/distributed/_shard/sharded_tensor/_ops/_common.py @@ -0,0 +1,113 @@ +# mypy: allow-untyped-defs +import functools + +from core.distributed._shard.common_op_utils import _basic_validation +from core.distributed._shard.sharded_tensor import ( + _sharded_op_impl, + Shard, + ShardedTensor, +) + + +def _sharded_op_common(op, early_stop_func, extra_check): + """ + Inject sharded tensor op registration with common logics executed before + different behaviors are done on either local shards or a local tensor. + + Example:: + >>> # xdoctest: +SKIP("Undefined variables") + >>> op = core.transpose + >>> @_sharded_op_impl(op) + >>> @_sharded_op_common(op, early_stop_func, extra_check) + >>> def sharded_tensor_op(types, args, kwargs, process_group): + >>> ... + >>> + >>> st = sharded_tensor.rand(32, 16) + >>> st.transpose(1, 2) + >>> # This will call '_sharded_op_common' + + Args: + op: The op to be registered and applied to all shards of the st. + early_stop_func (Callable, optional): the func for early stop. + Default: if ``None``, no early stop. + extra_check (Callable, optional): the func for extra condition check. + Default: if ``None``, no extra check. + + Return: + func (Callable): Torch function for which we want to provide a sharded + implementation (ex: core.transpose) + """ + + def decorator_sharded_func(wrapped_func): + @functools.wraps(wrapped_func) + def wrapper(types, args=(), kwargs=None, pg=None): + _basic_validation(op, args, kwargs) + + st = args[0] + if kwargs is None: + kwargs = {} + if extra_check: + extra_check(*args, **kwargs) + if early_stop_func: + early_stop = early_stop_func(*args, **kwargs) + if early_stop: + return st + return wrapped_func(types, args, kwargs, pg) + + return wrapper + + return decorator_sharded_func + + +def _register_sharded_op_on_local_shards( + op, early_stop_func=None, extra_check=None, customized_func=None +): + """ + Handles ``__torch_function__`` dispatch for ops which are performed on + each shard of the sharded tensor such as elementwise op like + ``core.nn.functional.gelu`` or ``core.nn.functional.relu``. + + For more complicated ops, a customized func can be used to generate + the new shards and sharded tensor size. + + This function expects that the original ShardingSpec for the ShardedTensor + is preserved irrespective of whether or not a customized function is used. + + Args: + op: The op to be registered and applied to all shards of the st. + early_stop_func (Callable, optional): the func for early stop. + Default: if ``None``, no early stop. + extra_check (Callable, optional): the func for extra condition check. + Default: if ``None``, no extra check. + customized_func (Callable, optional): the func for customized logic + to generate new shards and sharded tensor size. + Default: if ``None``, we simply lower to the real op call with + all local shards of the st. + + Return: + func (Callable): registered implementation for sharded op for + ``__torch_function__`` dispatch. + """ + + @_sharded_op_impl(op) + @_sharded_op_common(op, early_stop_func, extra_check) + def sharded_tensor_op_on_local_shards(types, args=(), kwargs=None, pg=None): + st = args[0] + st_metadata = st.metadata() + local_shards = st.local_shards() + local_shards_new = [] + if customized_func: + local_shards_new, st_metadata = customized_func(args, kwargs, pg) + else: + for local_shard in local_shards: + args = (local_shard.tensor, *args[1:]) + local_shards_new.append( + Shard(op(*args, **kwargs), local_shard.metadata) + ) + return ShardedTensor._init_from_local_shards_and_global_metadata( + local_shards_new, + st_metadata, + process_group=pg, + init_rrefs=st._init_rrefs, + sharding_spec=st.sharding_spec(), + ) diff --git a/mindnlp/core/distributed/_shard/sharded_tensor/_ops/binary_cmp.py b/mindnlp/core/distributed/_shard/sharded_tensor/_ops/binary_cmp.py new file mode 100644 index 000000000..06787804d --- /dev/null +++ b/mindnlp/core/distributed/_shard/sharded_tensor/_ops/binary_cmp.py @@ -0,0 +1,78 @@ +# mypy: allow-untyped-defs +from mindnlp import core +from mindnlp import core.distributed as dist +from mindnlp import core.distributed.distributed_c10d as distributed_c10d +from core.distributed._shard.sharded_tensor import _sharded_op_impl, ShardedTensor + + +def _communicate_result(result, pg): + # Gather results from all ranks. + if result: + result_tensor = core.ones(1, device=core.device(core.cuda.current_device())) + else: + result_tensor = core.zeros(1, device=core.device(core.cuda.current_device())) + + dist.all_reduce(result_tensor, group=pg) + + expected_result = core.ones( + 1, device=core.device(core.cuda.current_device()) + ) * dist.get_world_size(pg) + + return core.equal(result_tensor, expected_result) + + +def binary_cmp(cmp_fun, types, args, kwargs=None, process_group=None): + if len(args) != 2: + raise ValueError(f"Expected two arguments for core.{cmp_fun.__name__}") + + st1 = args[0] + st2 = args[1] + if not (isinstance(st1, ShardedTensor) and isinstance(st2, ShardedTensor)): + raise TypeError( + f"Both arguments to core.{cmp_fun.__name__} need to be of type ShardedTensor" + ) + + # Verify same PG + if st1._process_group != st2._process_group: + return False + + if distributed_c10d._rank_not_in_group( + st1._process_group + ) or distributed_c10d._rank_not_in_group(st2._process_group): + return distributed_c10d._rank_not_in_group( + st1._process_group + ) == distributed_c10d._rank_not_in_group(st2._process_group) + + # Verify metadata + if st1.metadata() != st2.metadata(): + return _communicate_result(False, st1._process_group) + + # Verify number of local shards + st1_local_shards = st1.local_shards() + st2_local_shards = st2.local_shards() + if len(st1_local_shards) != len(st2_local_shards): + return _communicate_result(False, st1._process_group) + + # kwargs must be dict-like + if kwargs is None: + kwargs = {} + # Verify each local shard + for idx in range(len(st1_local_shards)): + if st1_local_shards[idx].metadata != st2_local_shards[idx].metadata: + return _communicate_result(False, st1._process_group) + if not cmp_fun( + st1_local_shards[idx].tensor, st2_local_shards[idx].tensor, **kwargs + ): + return _communicate_result(False, st1._process_group) + + return _communicate_result(True, st1._process_group) + + +@_sharded_op_impl(core.equal) +def equal(types, args, kwargs, process_group): + return binary_cmp(core.equal, types, args, kwargs, process_group) + + +@_sharded_op_impl(core.allclose) +def allclose(types, args, kwargs, process_group): + return binary_cmp(core.allclose, types, args, kwargs, process_group) diff --git a/mindnlp/core/distributed/_shard/sharded_tensor/_ops/init.py b/mindnlp/core/distributed/_shard/sharded_tensor/_ops/init.py new file mode 100644 index 000000000..9e2c16bc3 --- /dev/null +++ b/mindnlp/core/distributed/_shard/sharded_tensor/_ops/init.py @@ -0,0 +1,151 @@ +# mypy: allow-untyped-defs +from mindnlp import core +from mindnlp import core.distributed._shard.sharded_tensor as sharded_tensor +from core.distributed._shard.sharded_tensor import _sharded_op_impl + + +def validate_param(param, param_name): + if param is None: + raise ValueError(f"param: {param_name} shouldn't be None!") + + +@_sharded_op_impl(core.nn.init.uniform_) +def uniform_(types, args=(), kwargs=None, pg=None): + r""" + Fills the Tensor in tensor.local_shards with values drawn from the uniform + distribution :math:`\mathcal{U}(a, b)`. + Args: + tensor: tensor sharded across devices + a: the lower bound of the uniform distribution + b: the upper bound of the uniform distribution + """ + validate_param(kwargs, "kwargs") + sharded_tensor = kwargs["tensor"] + validate_param(sharded_tensor, "tensor") + a = kwargs["a"] + validate_param(a, "a") + b = kwargs["b"] + validate_param(b, "b") + + for shard in sharded_tensor.local_shards(): + core.nn.init.uniform_(shard.tensor, a=a, b=b) + return sharded_tensor + + +@_sharded_op_impl(core.nn.init.normal_) +def normal_(types, args=(), kwargs=None, pg=None): + r""" + Fills the Tensors in tensor.local_shards with values drawn from the normal + distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`. + Args: + tensor: tensor sharded across devices + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + """ + validate_param(kwargs, "kwargs") + sharded_tensor = kwargs["tensor"] + validate_param(sharded_tensor, "tensor") + mean = kwargs["mean"] + validate_param(mean, "mean") + std = kwargs["std"] + validate_param(std, "std") + + for shard in sharded_tensor.local_shards(): + core.nn.init.normal_(shard.tensor, mean=mean, std=std) + return sharded_tensor + + +@_sharded_op_impl(core.nn.init.kaiming_uniform_) +def kaiming_uniform_(types, args=(), kwargs=None, pg=None): + r""" + Fills the Tensors in tensor.local_shards with values according to the method + described in `Delving deep into rectifiers: Surpassing human-level + performance on ImageNet classification` - He, K. et al. (2015), using a + uniform distribution. The resulting tensor will have values sampled from + :math:`\mathcal{U}(-\text{bound}, \text{bound})` where + .. math:: + \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}} + Also known as He initialization. + Args: + tensor: tensor sharded across devices + a: the negative slope of the rectifier used after this layer (only + used with ``'leaky_relu'``) + mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` + preserves the magnitude of the variance of the weights in the + forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the + backwards pass. + nonlinearity: the non-linear function (`nn.functional` name), + recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default). + """ + validate_param(kwargs, "kwargs") + sharded_tensor = kwargs["tensor"] + validate_param(sharded_tensor, "tensor") + a = kwargs["a"] + validate_param(a, "a") + mode = kwargs["mode"] + validate_param(mode, "mode") + nonlinearity = kwargs["nonlinearity"] + validate_param(nonlinearity, "nonlinearity") + + for shard in sharded_tensor.local_shards(): + core.nn.init.kaiming_uniform_( + shard.tensor, a=a, mode=mode, nonlinearity=nonlinearity + ) + return sharded_tensor + + +@_sharded_op_impl(core.nn.init.constant_) +def constant_(types, args=(), kwargs=None, pg=None): + r""" + Fills the input ShardedTensor with the value \text{val}val. + Args: + tensor: tensor sharded across devices + val: the value to fill the tensor with + """ + validate_param(kwargs, "kwargs") + sharded_tensor = kwargs["tensor"] + validate_param(sharded_tensor, "tensor") + val = kwargs["val"] + validate_param(val, "val") + for shard in sharded_tensor.local_shards(): + core.nn.init.constant_(shard.tensor, val=val) + return sharded_tensor + + +tensor_like_creation_op_map = { + core.full_like: sharded_tensor.full, + core.empty_like: sharded_tensor.empty, + core.zeros_like: sharded_tensor.zeros, + core.ones_like: sharded_tensor.ones, + core.rand_like: sharded_tensor.rand, + core.randn_like: sharded_tensor.randn, +} + + +# tensor ops that behave the same as the default tensor +def register_tensor_creation_op(op): + @_sharded_op_impl(op) + def tensor_creation_op(types, args=(), kwargs=None, pg=None): + """ + Handles ``__torch_function__`` dispatch for tensor creation ops that + takes a ShardedTensor as argument, such as ``core.zeros_like`` or + ``core.full_like``. + """ + creation_op = tensor_like_creation_op_map.get(op, None) + if creation_op is None: + raise RuntimeError(f"Tensor creation {op} not supported!") + if kwargs is None: + kwargs = {} + + st = args[0] + + new_st = creation_op(st.sharding_spec(), st.size(), *args[1:], **kwargs) # type: ignore[operator] + return new_st + + +register_tensor_creation_op(core.full_like) +register_tensor_creation_op(core.empty_like) +register_tensor_creation_op(core.zeros_like) +register_tensor_creation_op(core.ones_like) +register_tensor_creation_op(core.rand_like) +register_tensor_creation_op(core.randn_like) diff --git a/mindnlp/core/distributed/_shard/sharded_tensor/_ops/misc_ops.py b/mindnlp/core/distributed/_shard/sharded_tensor/_ops/misc_ops.py new file mode 100644 index 000000000..3f8c9fb31 --- /dev/null +++ b/mindnlp/core/distributed/_shard/sharded_tensor/_ops/misc_ops.py @@ -0,0 +1,12 @@ +# mypy: allow-untyped-defs +from mindnlp import core +from core.distributed._shard.sharded_tensor import _sharded_op_impl + + +# This is used by `_apply()` within module.py to set new +# parameters after apply a certain method, we should follow +# the future behavior of overwriting the existing tensor +# instead of doing in-place change using `.data = `. +# @_sharded_op_impl(core._has_compatible_shallow_copy_type) +def tensor_has_compatible_shallow_copy_type(types, args=(), kwargs=None, pg=None): + return False diff --git a/mindnlp/core/distributed/_shard/sharded_tensor/_ops/tensor_ops.py b/mindnlp/core/distributed/_shard/sharded_tensor/_ops/tensor_ops.py new file mode 100644 index 000000000..3e53205df --- /dev/null +++ b/mindnlp/core/distributed/_shard/sharded_tensor/_ops/tensor_ops.py @@ -0,0 +1,219 @@ +# mypy: allow-untyped-defs +import copy + +from mindnlp import core +from core.distributed._shard.common_op_utils import _register_default_op +from core.distributed._shard.sharded_tensor import ( + _sharded_op_impl, + Shard, + ShardedTensor, +) + +from ._common import _register_sharded_op_on_local_shards + + +# Tensor properties access +_register_default_op(core.Tensor.shape.__get__, _sharded_op_impl) # type: ignore[attr-defined] +_register_default_op(core.Tensor.dtype.__get__, _sharded_op_impl) # type: ignore[attr-defined] +# _register_default_op(core.Tensor.layout.__get__, _sharded_op_impl) # type: ignore[attr-defined] +_register_default_op(core.Tensor.size, _sharded_op_impl) +_register_default_op(core.Tensor.dim, _sharded_op_impl) +_register_default_op(core.Tensor.ndim.__get__, _sharded_op_impl) # type: ignore[attr-defined] +_register_default_op(core.Tensor.is_contiguous, _sharded_op_impl) +_register_default_op(core.Tensor.contiguous, _sharded_op_impl) +_register_default_op(core.Tensor.is_floating_point, _sharded_op_impl) + +# __reduce_ex__ to dispatch to get_state/set_state +_register_default_op(core.Tensor.__reduce_ex__, _sharded_op_impl) + +# autograd related properties +_register_default_op(core.Tensor.requires_grad.__get__, _sharded_op_impl) # type: ignore[attr-defined] +# TODO: set grad with a ShardedTensor that consists of all local grads +_register_default_op(core.Tensor.grad.__get__, _sharded_op_impl) # type: ignore[union-attr] +# _register_default_op(core.Tensor.grad_fn.__get__, _sharded_op_impl) # type: ignore[union-attr] +# _register_default_op(core.Tensor.is_leaf.__get__, _sharded_op_impl) # type: ignore[attr-defined] + + +# device property is ambiguous as from a global prospective, +# ShardedTensor.device consists of multiple devices (might even across hosts) +# We choose to return the current device of the local tensor to represent +# the device property on each rank +# @_sharded_op_impl(core.Tensor.device.__get__) +def tensor_device(types, args=(), kwargs=None, pg=None): + self_st = args[0] + # Validate types + if not isinstance(self_st, ShardedTensor): + raise TypeError("input needs to be a ShardedTensor") + dev: core.device + if self_st._local_shards: + dev = self_st._local_shards[0].tensor.device + elif pg and pg._get_backend_name() == "gloo": + dev = core.device("cpu") + else: + dev = core.device(core.cuda.current_device()) + return dev + + +# @_sharded_op_impl(core.Tensor.is_meta.__get__) # type: ignore[attr-defined] +def st_is_meta(types, args=(), kwargs=None, pg=None): + return args[0].local_tensor().is_meta + + +def sharded_type_as_check(*args, **kwargs): + """ + Perform extra checks for the sharded_type_as op such as the input needs to + be either a Tensor or ShardedTensor. + + Args: same as ``core.Tensor.type_as``. + + Return: None + """ + if len(args) < 2: + raise ValueError("Needs to give a tensor to cast type as!") + if not isinstance(args[1], core.Tensor) and not isinstance(args[1], ShardedTensor): + raise ValueError("Needs to give a Tensor or ShardedTensor to cast type as!") + + +def same_dtype(*args, **kwargs): + """ + When the dtype is the same, return the original ShardedTensor. + + Args: same as ``core.Tensor.type_as``. + + Return (bool): Whether to return early or not. + """ + return args[0].dtype == args[1].dtype + + +def sharded_type_as(args, kwargs, pg): + """ + Handles ``__torch_function__`` dispatch for the ``core.Tensor.type_as`` op. + + Args: same as ``core.Tensor.type_as``. + + Return: + new_local_shards (List[Shard]): Local shards for the new sharded tensor. + st_meta (ShardedTensorMetadata): Metadata of the new sharded tensor. + """ + st = args[0] + tensor = args[1] + if isinstance(tensor, ShardedTensor): + tensor = tensor.local_tensor() + new_local_shards = [ + Shard(shard.tensor.type_as(tensor), shard.metadata) + for shard in st.local_shards() + ] + st_meta = copy.deepcopy(st._metadata) + st_meta.tensor_properties.dtype = tensor.dtype + return new_local_shards, st_meta + + +_register_sharded_op_on_local_shards( + core.Tensor.type_as, + early_stop_func=same_dtype, + extra_check=sharded_type_as_check, + customized_func=sharded_type_as, +) + + +def sharded_deepcopy(args, kwargs, pg): + # NOTE: we directly implement deepcopy magic method + # instead of using the default tensor.__deepcopy__ + # and implement clone(). This is because the default + # tensor deepcopy copies every attribute, but the + # process_group in ShardedTensor cannot be deep copied. + self_st = args[0] + new_local_shards = copy.deepcopy(self_st.local_shards()) + new_metadata = copy.deepcopy(self_st.metadata()) + return new_local_shards, new_metadata + + +_register_sharded_op_on_local_shards( + core.Tensor.__deepcopy__, + customized_func=sharded_deepcopy, +) + + +@_sharded_op_impl(core.Tensor.copy_) +def sharded_inplace_copy(types, args, kwargs, pg): + # NOTE: inplace op don't need to rewrap + kwargs = {} if kwargs is None else kwargs + self_st = args[0] + new_st = args[1] + nonblocking = kwargs.get("non_blocking", False) + for local_shard, new_shard in zip(self_st.local_shards(), new_st.local_shards()): + if local_shard.metadata != new_shard.metadata: + raise RuntimeError( + "inplace copy can only happen between two ShardedTensor with same metadata!" + ) + for local_shard, new_shard in zip(self_st.local_shards(), new_st.local_shards()): + local_shard.tensor.copy_(new_shard.tensor, nonblocking) + + return self_st + + +def sharded_clone(args, kwargs, pg): + self_st = args[0] + desire_memory_format = kwargs.get("memory_format", None) + if desire_memory_format and desire_memory_format != core.preserve_format: + raise RuntimeError("Only support core.preserve_format for ShardedTensor!") + cloned_local_shards = [ + Shard( + local_shard.tensor.clone(memory_format=desire_memory_format), + metadata=copy.deepcopy(local_shard.metadata), + ) + for local_shard in self_st.local_shards() + ] + new_metadata = copy.deepcopy(self_st.metadata()) + return cloned_local_shards, new_metadata + + +# _register_sharded_op_on_local_shards( +# core.Tensor.clone, +# customized_func=sharded_clone, +# ) + + +def sharded_detach(args, kwargs, pg): + self_st = args[0] + detached_local_shards = [ + Shard( + local_shard.tensor.detach(), + metadata=copy.deepcopy(local_shard.metadata), + ) + for local_shard in self_st.local_shards() + ] + new_metadata = copy.deepcopy(self_st.metadata()) + new_metadata.tensor_properties.requires_grad = False + return detached_local_shards, new_metadata + + +# _register_sharded_op_on_local_shards( +# core.Tensor.detach, +# customized_func=sharded_detach, +# ) + + +# @_sharded_op_impl(core.Tensor.requires_grad_) +def tensor_requires_grad_set(types, args=(), kwargs=None, pg=None): + self_st = args[0] + # Validate types + if not isinstance(self_st, ShardedTensor): + raise TypeError("input needs to be a ShardedTensor") + + if kwargs is None: + kwargs = {} + + requires_grad = args[1] if len(args) > 1 else kwargs.get("requires_grad", True) + if requires_grad == self_st.requires_grad: + return self_st + + for local_shard in self_st.local_shards(): + local_shard.tensor.requires_grad_(requires_grad) + + # update the wrapper class property + with core._C.DisableTorchFunctionSubclass(): + self_st.requires_grad_(requires_grad) + # update the metadata in the meanwhile + self_st._metadata.tensor_properties.requires_grad = requires_grad + return self_st diff --git a/mindnlp/core/distributed/_shard/sharded_tensor/api.py b/mindnlp/core/distributed/_shard/sharded_tensor/api.py new file mode 100644 index 000000000..a61fb504d --- /dev/null +++ b/mindnlp/core/distributed/_shard/sharded_tensor/api.py @@ -0,0 +1,1296 @@ +# mypy: allow-untyped-defs +from __future__ import annotations # type: ignore[attr-defined] + +import copy +import operator +import threading +import warnings +import weakref +from dataclasses import dataclass +from functools import reduce +from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING +from typing_extensions import deprecated + +from mindnlp import core +from mindnlp import core.distributed as dist +from mindnlp import core.distributed._shard.sharding_spec as shard_spec +from core.distributed import distributed_c10d, rpc +from core.distributed._shard._utils import DEPRECATE_MSG +from core.distributed._shard.sharding_spec._internals import ( + check_tensor, + validate_non_overlapping_shards_metadata, +) +from core.distributed._shard.sharding_spec.api import ( + _dispatch_custom_op, + _has_custom_op, +) +from core.distributed.remote_device import _remote_device +from core.utils import _pytree as pytree + +from .metadata import ShardedTensorMetadata, TensorProperties +from .reshard import reshard_local_shard, reshuffle_local_shard +from .shard import Shard +from .utils import ( + _flatten_tensor_size, + _parse_and_validate_remote_device, + _validate_output_tensor_for_gather, + build_global_metadata, + build_metadata_from_local_shards, +) + + +if TYPE_CHECKING: + from core.distributed._shard.metadata import ShardMetadata + + +# Tracking for sharded tensor objects. +_sharded_tensor_lock = threading.Lock() +_sharded_tensor_current_id = 0 +_sharded_tensor_map: Dict[int, weakref.ReferenceType[ShardedTensor]] = {} + +# Default sharded ops +_SHARDED_OPS: Dict[Callable, Callable] = {} + +# Customized user ops +_CUSTOM_SHARDED_OPS: Dict[Callable, Callable] = {} + + +def _register_remote_shards( + sharded_tensor_id: int, rrefs: List[rpc.RRef[Shard]], rpc_rank: int +): + with _sharded_tensor_lock: + if sharded_tensor_id not in _sharded_tensor_map: + raise RuntimeError( + f"Could not find sharded_tensor_id: {sharded_tensor_id} in map: {_sharded_tensor_map.keys()}" + ) + + sharded_tensor = _sharded_tensor_map[sharded_tensor_id]() + if sharded_tensor is None: + raise RuntimeError("ShardedTensor weakref has been deallocated") + else: + sharded_tensor._register_remote_shards(rrefs, rpc_rank) + + +class ShardedTensorBase(core.Tensor): + _sharding_spec: shard_spec.ShardingSpec + _metadata: ShardedTensorMetadata + _local_shards: List[Shard] + + def __new__(cls, sharding_spec: shard_spec.ShardingSpec, *size, **kwargs): + # Use __new__ to construct a wrapper tensor, for recording tensor + # properties and logging purposes. + + # check sharding spec and build sharded tensor metadata + if not isinstance(sharding_spec, shard_spec.ShardingSpec): + raise ValueError(f"Expecting ShardingSpec but got: {type(sharding_spec)}") + + sizes = _flatten_tensor_size(size) + dtype = kwargs["dtype"] + # layout = kwargs["layout"] + # pin_memory = kwargs["pin_memory"] + requires_grad = kwargs["requires_grad"] + + if dtype is None: + dtype = core.get_default_dtype() + + tensor_properties = TensorProperties( + dtype, requires_grad + ) + sharded_tensor_metadata = sharding_spec.build_metadata( + sizes, tensor_properties=tensor_properties + ) + + r = super().__new__(cls) + # set sharding spec + r._sharding_spec = sharding_spec + # set metadata + r._metadata = sharded_tensor_metadata + # set local shards + r._local_shards = [] + return r + + def metadata(self) -> ShardedTensorMetadata: + """ + Returns a :class:`ShardedTensorMetadata` object corresponding to the + metadata for the entire tensor. + """ + return self._metadata + + def local_shards(self) -> List[Shard]: + """ + Returns a list of :class:`Shard' corresponding to the + local shards for this rank. Returns an empty list if the current rank + does not host any shards for this Tensor. + """ + return self._local_shards + + @classmethod + def _init_from_local_shards_and_global_metadata( + cls, + local_shards: List[Shard], + sharded_tensor_metadata: ShardedTensorMetadata, + sharding_spec=None, + ) -> ShardedTensorBase: + """ + Initialize a ShardedTensorBase with local shards and a global + ShardedTensorMetadata built on each rank. + Warning: This API is experimental and subject to change. It does + not do cross rank validations, and fully rely on the user + for the correctness of sharded_tensor_metadata on each rank + """ + shards_metadata = sharded_tensor_metadata.shards_metadata + tensor_properties = sharded_tensor_metadata.tensor_properties + + if len(shards_metadata) == 0: + raise ValueError("shards_metadata must not be empty!") + + # if tensor_properties.layout != core.strided: + # raise ValueError("Only core.strided layout is currently supported") + + if sharding_spec is None: + spec = shard_spec._infer_sharding_spec_from_shards_metadata(shards_metadata) + else: + spec = sharding_spec + + sharded_tensor_base = ShardedTensorBase.__new__( + ShardedTensor, + spec, + sharded_tensor_metadata.size, + dtype=tensor_properties.dtype, + # layout=tensor_properties.layout, + # pin_memory=tensor_properties.pin_memory, + requires_grad=tensor_properties.requires_grad, + ) + + # check if shards_metadata have overlap shards + validate_non_overlapping_shards_metadata(shards_metadata) + + # check if the shards_metadata is compatible with overall size of the sharded tensor. + check_tensor(shards_metadata, list(sharded_tensor_metadata.size)) + + # done validation, add local_shards + sharded_tensor_base._local_shards = local_shards + return sharded_tensor_base + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + raise RuntimeError( + f"A {cls.__name__} object is being used from c++ while calling {func.__module__}.{func.__name__} " + "but the there is no custom __torch_dispatch__ implementation for it." + ) + + +class ShardedTensor(ShardedTensorBase): + """ + ShardedTensor is an core.Tensor subclass to represent Tensors that are sharded + across multiple devices and multiple processes. + + ShardedTensor is initialized in an SPMD like fashion where each rank + initializes the ShardedTensor. The ShardedTensor object on each rank + then only stores the local shard for the Tensor and provides global + metadata for all the shards. + + ShardedTensor doesn't provide any Tensor like operations but is a wrapper + providing the Tensor representing the local shard and the global metadata. + Using these, users can build their custom distributed._sharded computations + on top of this primitive. The local shards are all initialized using the + create_op specified by tensor_init_params.create_op, e.g., core.ones, or + core.empty + + Args: + sharding_spec (:class:`core.distributed._shard.sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + size (int...): a sequence of integers defining the shape of the output + tensor. Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + dtype (:class:`core.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`core.set_default_dtype`). + layout (:class:`core.layout`, optional): the desired layout of returned Tensor. + Default: ``core.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + memory_format (:class:`core.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``core.contiguous_format``. + init_rrefs (bool, optional): Whether or not to initialize + :class:`core.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + .. note:: ShardedTensor uses collectives to do various operations, i.e. it + uses all_gather to do cross rank validations. For NCCL-based process + groups, internal tensor representations of objects must be moved to the + GPU device before communication takes place. In this case, the device + used is given by ``core.cuda.current_device()`` and it is the user's + responsibility to ensure that this is set so that each rank has an + individual GPU, via ``core.cuda.set_device()`` + + """ + + def __new__(cls, sharding_spec: shard_spec.ShardingSpec, *size, **kwargs): + self = super().__new__(cls, sharding_spec, *size, **kwargs) + return self + + def __init__( + self, + sharding_spec: shard_spec.ShardingSpec, + *size, + dtype=None, + layout=core.strided, + requires_grad=False, + pin_memory=False, + memory_format=core.contiguous_format, + process_group=None, + init_rrefs=False, + ): + # prepare initialization, initialize fields like + # _process_group, _local_shards, etc. + self._prepare_init(process_group=process_group, init_rrefs=init_rrefs) + + if layout != core.strided: + raise ValueError("Only core.strided layout is currently supported") + + if memory_format != core.contiguous_format: + raise ValueError( + "Only core.contiguous_format memory_format is currently supported" + ) + + self._metadata.tensor_properties.memory_format = memory_format + + current_rank = dist.get_rank() # global rank + + for shard_metadata in self._metadata.shards_metadata: + rank, device = _parse_and_validate_remote_device( + self._process_group, shard_metadata.placement + ) + if rank == current_rank: + local_tensor = _create_tensor_from_params( + shard_metadata.shard_sizes, + local_device=device, + tensor_properties=self._metadata.tensor_properties, + ) + self._local_shards.append(Shard(local_tensor, shard_metadata)) + + # do post initialization (i.e. register sharded_tensor_id, initialize_rpc) + self._post_init() + + def _prepare_init(self, process_group=None, init_rrefs=False): + self._init_rrefs = init_rrefs + self._sharded_tensor_id = None + + self._process_group = self._normalize_pg(process_group) + self._remote_shards: Dict[int, List[rpc.RRef[Shard]]] = {} + + def _post_init(self): + # Initialize RPC if available. + if self._init_rrefs: + with _sharded_tensor_lock: + global _sharded_tensor_current_id, _sharded_tensor_map + self._sharded_tensor_id = _sharded_tensor_current_id + _sharded_tensor_map[self._sharded_tensor_id] = weakref.ref(self) + _sharded_tensor_current_id += 1 + + if not rpc._is_current_rpc_agent_set(): + raise RuntimeError( + "RPC Framework needs to be initialized using" + " core.distributed.rpc.init_rpc if init_rrefs is set to True" + ) + self._init_rpc() + + def __del__(self): + # Clean up the global map. + with _sharded_tensor_lock: + global _sharded_tensor_current_id, _sharded_tensor_map + if ( + hasattr(self, "_sharded_tensor_id") + and self._sharded_tensor_id in _sharded_tensor_map + ): + _sharded_tensor_map.pop(self._sharded_tensor_id) # type: ignore[call-overload] + + def _init_rpc(self): + # Validate PG and RPC ranks match. + pg_rank = dist.get_rank() + rpc_rank = rpc.get_worker_info().id + if pg_rank != rpc_rank: + raise ValueError( + f"Default ProcessGroup and RPC ranks must be " + f"the same for ShardedTensor, found process group rank: " + f"{pg_rank} and RPC rank: {rpc_rank}" + ) + + self._remote_shards = {} + + # Gather all the sharded tensor ids. + worker_infos = rpc._get_current_rpc_agent().get_worker_infos() + rank_to_name = {} + name_to_rank = {} + + for worker_info in worker_infos: + rank_to_name[worker_info.id] = worker_info.name + name_to_rank[worker_info.name] = worker_info.id + + all_tensor_ids = rpc.api._all_gather(self._sharded_tensor_id) + + # Share the local shards to the entire world. + futs = [] + rpc_rank = rpc.get_worker_info().id + for rank in range(dist.get_world_size()): + # Skip self. + if rank == dist.get_rank(): + continue + + if len(self.local_shards()) != 0: + rrefs: List[rpc.RRef[Shard]] = [ + rpc.RRef(shard) for shard in self.local_shards() + ] + fut = rpc.rpc_async( + rank, + _register_remote_shards, + args=(all_tensor_ids[rank_to_name[rank]], rrefs, rpc_rank), + ) + futs.append(fut) + + core.futures.wait_all(futs) + + # Barrier for all RPCs to finish on all ranks. + rpc.api._all_gather(None) + + def _get_preferred_device(self) -> core.device: + """ + Return the preferred device to be used when creating tensors for collectives. + This method takes into account the associated process group + """ + if dist.get_backend(self._process_group) == dist.Backend.NCCL: + return core.device(core.cuda.current_device()) + return core.device("cpu") + + def gather( # type: ignore[override] + self, + dst: int = 0, + out: Optional[core.Tensor] = None, + enforce_dtype: bool = False, + dtype: Optional[core.dtype] = None, + ) -> None: + """ + Creates a full :class:`Tensor` on rank ``dst`` by gathering all shards of the + sharded tensor. + + The API needs to be called on all ranks in SPMD fashion. All ranks should have + the same ``dst``. ``out`` should be a tensor of the same size as the overall + size of the sharded tensor on ``dst`` and ``None`` on all other ranks. + + Args: + dst(int): The rank where full tensor is constructed. + Default: 0 + out (:class `core.Tensor`, optional): The output full tensor. + Must to be provided ONLY on ``dst`` rank. + Default: ``None`` + enforce_dtype (bool): Deprecated, please use dtype instead. Force the + gathered tensors to be the same type as input and output. + dtype (core.dtype): Force the gathered tensors to be this dtype. + Default: ``None`` + """ + + def shard_size(shard_md): + return reduce(operator.mul, shard_md.shard_sizes) # type: ignore[attr-defined] + + if enforce_dtype: + warnings.warn( + "`enforce_dtype` is deprecated. Please use `dtype` instead.", + FutureWarning, + stacklevel=2, + ) + + rank = dist.get_rank(self._process_group) + full_size = self.metadata().size + _validate_output_tensor_for_gather(rank, dst, full_size, out) + + local_shards = self.local_shards() + world_size = dist.get_world_size(self._process_group) + rank_sizes = [0 for _ in range(world_size)] + max_rank_size = 0 + shard_placement: Dict[ShardMetadata, Tuple[int, int]] = {} + # collect sizes + for shard_md in self.metadata().shards_metadata: + shard_rank = cast(_remote_device, shard_md.placement).rank() + assert shard_rank is not None + + shard_placement[shard_md] = (shard_rank, rank_sizes[shard_rank]) + rank_sizes[shard_rank] += shard_size(shard_md) + max_rank_size = max(max_rank_size, rank_sizes[shard_rank]) + + gather_list: Optional[List[core.Tensor]] + if rank == dst: + assert out is not None + if enforce_dtype: + # enforce_dtype is deprecated. Do it for backward compatibility. + dtype = out.dtype + # TODO make it as a view of out tensor + gather_list = [ + core.empty((max_rank_size,), device=out.device, dtype=dtype) + for _ in range(world_size) + ] + else: + gather_list = None + + with core.no_grad(): + if enforce_dtype and len(local_shards) > 0: + # enforce_dtype is deprecated. Do it for backward compatibility. + dtype = local_shards[0].tensor.dtype + data = core.empty( + max_rank_size, device=self._get_preferred_device(), dtype=dtype + ) + + for shard in local_shards: + src = shard.tensor.flatten() + if src.nelement() == 0: + warnings.warn( + "Gathering a tensor with zero elements on rank " + str(rank) + ) + return + shard_offset = shard_placement[shard.metadata][1] + data[shard_offset : shard_offset + src.numel()].copy_(src) + + dist.gather( + tensor=data, + gather_list=gather_list, + dst=dst, + group=self._process_group, + ) + if rank != dst: + return + # In _validate_output_tensor_for_gather, we raise if out == None and rank == dst + out = cast(core.Tensor, out) + assert gather_list is not None + + full_size = self.metadata().size + dims = len(full_size) + for shard_md in self.metadata().shards_metadata: + rank, rank_offset = shard_placement[shard_md] + tensor = gather_list[rank] + tensor = tensor[rank_offset : rank_offset + shard_size(shard_md)] + tensor = tensor.view(shard_md.shard_sizes) + + out_narrow_view = out + for dim in range(dims): + out_narrow_view = out_narrow_view.narrow( + dim, + shard_md.shard_offsets[dim], + shard_md.shard_sizes[dim], + ) + + out_narrow_view.copy_(tensor) + + def cpu( + self, memory_format=core.preserve_format, process_group=None + ) -> ShardedTensor: + """ + Returns a copy of this object in CPU memory. + + If this ShardedTensor is already on CPU memory, then no copy is + performed and original object is returned. + + .. note:: When moving a ShardedTensor from GPU to CPU, the ShardedTensor might + need to be managed by a different type of ProcessGroup(i.e. ProcessGroupGloo), + it is the user's responsiblity to explicitly pass in a new process_group that + is compatible with CPU. + """ + # TODO: make this a __torch_function__ op once ShardedTensor becomes a + # core.Tensor subclass, see https://github.com/pytorch/pytorch/issues/75402 + if ( + memory_format != core.preserve_format + and memory_format != core.contiguous_format + ): + raise RuntimeError( + "Only `core.contiguous_format` or " + "`core.preserve_format` is supported!" + ) + all_on_cpu = True + for meta in self.metadata().shards_metadata: + all_on_cpu &= meta.placement.device().type == "cpu" # type: ignore[union-attr] + + # if every shard is already on CPU, return the original object + if all_on_cpu: + return self + + # if not, returns a copy of this object on CPU + list_shards: List[Shard] = [] + # move all local shards to cpu, and change metadata + for shard in self._local_shards: + cpu_tensor = shard.tensor.cpu(memory_format=memory_format) # type: ignore[call-arg] + metadata = copy.deepcopy(shard.metadata) + metadata.placement._device = core.device("cpu") # type: ignore[union-attr] + list_shards.append(Shard(cpu_tensor, metadata)) + + st_meta = copy.deepcopy(self.metadata()) + for meta in st_meta.shards_metadata: + if meta.placement.device().type != "cpu": # type: ignore[union-attr] + meta.placement._device = core.device("cpu") # type: ignore[union-attr] + + pg = self._process_group if process_group is None else process_group + st_cpu = ShardedTensor._init_from_local_shards_and_global_metadata( + list_shards, + sharded_tensor_metadata=st_meta, + process_group=pg, + init_rrefs=self._init_rrefs, + ) + return st_cpu + + def cuda( + self, + device=None, + non_blocking=False, + memory_format=core.preserve_format, + process_group=None, + ) -> ShardedTensor: + """ + Returns a copy of this object in CUDA memory, if the original ShardedTensor + is on CPU, we will move the local shard to the current GPU device of each + process in a SPMD fashion. + If this ShardedTensor is already on CUDA memory and local shards on each rank are + already on current device, we still returns a new ShardedTensor object with new + metadata, but no underlying data movements are performed. + .. note:: When moving a ShardedTensor from CPU to GPU, the ShardedTensor might + need to be managed by a different type of ProcessGroup(i.e. ProcessGroupNCCL), + it is the user's responsiblity to explicitly pass in a new process_group that + is compatible with GPU. + """ + if ( + memory_format != core.preserve_format + and memory_format != core.contiguous_format + ): + raise RuntimeError( + "Only `core.contiguous_format` or " + "`core.preserve_format` is supported!" + ) + + if device is not None: + device = core.device(device) if isinstance(device, str) else device + assert ( + isinstance(device, core.device) + and device.index == core.cuda.current_device() + ), """Only device without device id (e.g. "cpu" or "cuda") is expected for ShardedTensor!""" + + current_device = core.device(core.cuda.current_device()) + # returns a copy of ShardedTensor on CUDA current device + list_shards: List[Shard] = [] + # move all local shards to current device, and change metadata + # if local shards already on the current device, there's no + # real data movement, only the metadata are copied. + for shard in self._local_shards: + cuda_tensor = shard.tensor.cuda( + device=current_device, + non_blocking=non_blocking, + memory_format=memory_format, + ) # type: ignore[call-arg] + metadata = copy.deepcopy(shard.metadata) + metadata.placement._device = current_device # type: ignore[union-attr] + + list_shards.append(Shard(cuda_tensor, metadata)) + + st_meta = copy.deepcopy(self.metadata()) + for meta in st_meta.shards_metadata: + if meta.placement.device().type != "cuda": # type: ignore[union-attr] + meta.placement._device = current_device # type: ignore[union-attr] + + pg = self._process_group if process_group is None else process_group + # we need to use `init_from_local_shards` to communicate between ranks + # and update the sharding spec/shards metadata. + st_cuda = ShardedTensor._init_from_local_shards_and_global_metadata( + list_shards, + sharded_tensor_metadata=st_meta, + process_group=pg, + init_rrefs=self._init_rrefs, + ) + return st_cuda + + def to(self, *args, **kwargs) -> ShardedTensor: + current_device: core.device + if self._local_shards: + current_device = self._local_shards[0].tensor.device + elif self._process_group._get_backend_name() == "gloo": + current_device = core.device("cpu") + else: + current_device = core.device(core.cuda.current_device()) + current_dtype = self.dtype + device_to = current_device + dtype_to = current_dtype + if len(args) == 1: + if isinstance(args[0], core.dtype): + dtype_to = args[0] + elif isinstance(args[0], core.device): + device_to = args[0] + elif isinstance(args[0], (str, int)): + device_to = core.device(args[0]) + elif isinstance(args[0], core.Tensor): + dtype_to = args[0].dtype + device_to = args[0].device + else: + raise RuntimeError(f"ShardedTensor.to() have wrong arguments: {args}") + elif len(args) == 2: + device_to, dtype_to = args + else: + dtype_to = kwargs.get("dtype", current_dtype) + device_to = kwargs.get("device", current_device) + + device_to = ( + core.device(device_to) if isinstance(device_to, (str, int)) else device_to + ) + + if device_to.type == "cuda": + # if device_to set to cuda, set to current device even + # if user specify the device index. + current_idx = core.cuda.current_device() + if device_to.index != current_idx: + warnings.warn( + "ShardedTensor.to only move tensor to its current device" + "If you want to put to different device, use `reshard` instead." + ) + device_to = core.device(current_idx) + + copy_tensor = kwargs.get("copy", False) + non_blocking = kwargs.get("non_blocking", False) + memory_format = kwargs.get("memory_format", core.preserve_format) + process_group = kwargs.get("process_group", None) + + if ( + not copy_tensor + and dtype_to == current_dtype + and device_to == current_device + ): + # already have correct dtype and device, return itself + return self + + # returns a copy of ShardedTensor on CUDA current device + list_shards: List[Shard] = [] + + for shard in self._local_shards: + new_tensor = shard.tensor.to( # type: ignore[call-overload] + device=device_to, + dtype=dtype_to, + non_blocking=non_blocking, + copy=copy_tensor, + memory_format=memory_format, + ) + metadata = copy.deepcopy(shard.metadata) + if metadata.placement is not None: + metadata.placement._device = device_to + list_shards.append(Shard(new_tensor, metadata)) + + # update metadata + st_meta = copy.deepcopy(self.metadata()) + st_meta.tensor_properties.dtype = dtype_to + for meta in st_meta.shards_metadata: + meta.placement._device = device_to # type: ignore[union-attr] + + pg = self._process_group if process_group is None else process_group + # we need to use `init_from_local_shards` to communicate between ranks + # and update the sharding spec/shards metadata. + st_to = ShardedTensor._init_from_local_shards_and_global_metadata( + list_shards, + sharded_tensor_metadata=st_meta, + process_group=pg, + init_rrefs=self._init_rrefs, + ) + return st_to + + @classmethod + def _normalize_pg( + cls, process_group: Optional[dist.ProcessGroup] + ) -> dist.ProcessGroup: + if process_group is not None: + return process_group + return distributed_c10d._get_default_group() + + @classmethod + def _init_from_local_shards( + cls, + local_shards: List[Shard], + *global_size, + process_group=None, + init_rrefs=False, + ): + # STEP 1: Validate the Shardmetadatas locally + process_group = cls._normalize_pg(process_group) + current_rank = dist.get_rank() # intentional to get global rank + world_size = dist.get_world_size(process_group) + + local_sharded_tensor_metadata: Optional[ShardedTensorMetadata] = None + global_tensor_size = _flatten_tensor_size(global_size) + + if len(local_shards) > 0: + local_sharded_tensor_metadata = build_metadata_from_local_shards( + local_shards, global_tensor_size, current_rank, process_group + ) + + # STEP 2. Validate metadata across ranks, and build a global sharded tensor + # metadata by gathering local ShardedTensorMetadata + gathered_metadatas: List[Optional[ShardedTensorMetadata]] = [] + if world_size > 1: + gathered_metadatas = [None for _ in range(world_size)] + + dist.all_gather_object( + gathered_metadatas, local_sharded_tensor_metadata, group=process_group + ) + else: + gathered_metadatas = [local_sharded_tensor_metadata] + + global_sharded_tensor_metadata = build_global_metadata(gathered_metadatas) + tensor_properties = global_sharded_tensor_metadata.tensor_properties + + # STEP 3: Validation done, create the actual ShardedTensor and populate fields + # prepare initialization + spec = shard_spec._infer_sharding_spec_from_shards_metadata( + global_sharded_tensor_metadata.shards_metadata + ) + sharded_tensor = cls.__new__( + cls, + spec, + global_sharded_tensor_metadata.size, + dtype=tensor_properties.dtype, + layout=tensor_properties.layout, + pin_memory=tensor_properties.pin_memory, + requires_grad=tensor_properties.requires_grad, + ) + sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs) + + # attach local_shards to the ShardedTensor created + sharded_tensor._local_shards = local_shards + + # run post initialization, i.e. map registration, rpc initialization + sharded_tensor._post_init() + return sharded_tensor + + @classmethod + @deprecated(DEPRECATE_MSG, category=FutureWarning) + def _init_from_local_tensor( + cls, + local_tensor: core.Tensor, + sharding_spec: shard_spec.ShardingSpec, + *global_size: Sequence[int], + process_group: Optional[dist.ProcessGroup] = None, + init_rrefs=False, + ) -> ShardedTensor: + """ + Initialize a ShardedTensor given only one local tensor, global sharded tensor + size and sharding spec on each rank. + + Args: + local_tensor (Tensor): Single tensor of local shard stored in each rank. + sharding_spec (:class:`core.distributed._shard.sharding_spec.ShardingSpec`): + The specification describing how to shard the Tensor. + global_size (Sequence[int]): Size of the sharded tensor. + process_group (ProcessGroup, optional): The process group to aggregate on. + Default: None + init_rrefs (bool, optional): Whether or not to initialize + :class:`core.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + Returns: + A :class:`ShardedTensor` sharded based on the given sharding_spec with local + tensor stored in the current rank. + + Examples: + >>> # xdoctest: +SKIP + >>> # All tensors below are of core.int64 type. + >>> # We have 2 process groups, 2 ranks. + >>> tensor = core.arange(2, dtype=core.int64) + 1 + 2 * rank + >>> local_tensor = core.unsqueeze(core.cat([tensor, tensor + 2])) + >>> local_tensor + tensor([[1, 2, 3, 4]]) # Rank 0 + tensor([[3, 4, 5, 6]]) # Rank 1 + >>> sharding_dim = 0 + >>> sharding_spec = ChunkShardingSpec( + dim=sharding_dim, + placements=[ + "rank:0/cuda:0", + "rank:1/cuda:1", + ], + ) + >>> st = ShardedTensor._init_from_local_tensor(local_tensor, sharding_spec, [2, 4]) + >>> st + ShardedTensor( + ShardedTensorMetadata( + shards_metadata=[ + ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1, 4], placement=rank:0/cuda:0), + ShardMetadata(shard_offsets=[1, 0], shard_sizes=[1, 4], placement=rank:1/cuda:1), + ], + size=core.Size([2, 4]) + ) + >>> st.local_tensor() + tensor([1, 2, 3, 4]) # Rank 0 + tensor([3, 4, 5, 6]) # Rank 1 + + Warning: This API is experimental and subject to change. It lacks of a fully across + rank validations, and we only validate the local shard on the current rank. + We fully rely on the user to ensure local tensor is sharded based on the + sharding spec. + """ + if not local_tensor.is_contiguous(): + raise ValueError("local_tensor is not a contiguous Tensor.") + + global_tensor_size = _flatten_tensor_size(global_size) + tensor_properties = TensorProperties( + dtype=local_tensor.dtype, + layout=local_tensor.layout, + requires_grad=local_tensor.requires_grad, + memory_format=core.contiguous_format, + pin_memory=local_tensor.is_pinned(), + ) + sharded_tensor_metadata = sharding_spec.build_metadata( + global_tensor_size, tensor_properties + ) + + process_group = cls._normalize_pg(process_group) + current_rank = dist.get_rank() # intentional to get global rank + + local_shards: List[Shard] = [] + for shard_metadata in sharded_tensor_metadata.shards_metadata: + rank, _device = _parse_and_validate_remote_device( + process_group, shard_metadata.placement + ) + if rank == current_rank: + local_shards.append(Shard(local_tensor, shard_metadata)) + + # TODO: figure out what the API should behave when some rank have no shard + # see https://github.com/pytorch/pytorch/issues/7313 + return ShardedTensor._init_from_local_shards_and_global_metadata( + local_shards, + sharded_tensor_metadata, + process_group=process_group, + init_rrefs=init_rrefs, + sharding_spec=sharding_spec, + ) + + @classmethod + def _init_from_local_shards_and_global_metadata( # type: ignore[override] + cls, + local_shards: List[Shard], + sharded_tensor_metadata: ShardedTensorMetadata, + process_group=None, + init_rrefs=False, + sharding_spec=None, + ) -> ShardedTensor: + """ + Initialize a ShardedTensor with local shards and a global + ShardedTensorMetadata built on each rank. + + Warning: This API is experimental and subject to change. It does + not do cross rank validations, and fully rely on the user + for the correctness of sharded_tensor_metadata on each rank + """ + process_group = cls._normalize_pg(process_group) + current_rank = dist.get_rank() # intentional to get global rank + + shards_metadata = sharded_tensor_metadata.shards_metadata + + local_shard_metadatas = [] + + # collect local shard metadatas from the global sharded_tensor_metadata + for shard_metadata in shards_metadata: # type: ignore[attr-defined] + rank, local_device = _parse_and_validate_remote_device( + process_group, shard_metadata.placement + ) + + if current_rank == rank: + local_shard_metadatas.append(shard_metadata) + + if len(local_shards) != len(local_shard_metadatas): + raise RuntimeError( + f"Number of local shards ({len(local_shards)}) does not match number of local " + f"shards metadata in sharded_tensor_metadata ({len(local_shard_metadatas)}) " + f"on rank ({current_rank}) " + ) + + shards_metadata = sharded_tensor_metadata.shards_metadata + tensor_properties = sharded_tensor_metadata.tensor_properties + + if len(shards_metadata) == 0: + raise ValueError("shards_metadata must not be empty!") + + # if tensor_properties.layout != core.strided: + # raise ValueError("Only core.strided layout is currently supported") + + if sharding_spec is None: + spec = shard_spec._infer_sharding_spec_from_shards_metadata(shards_metadata) + else: + spec = sharding_spec + + sharded_tensor = ShardedTensor.__new__( + ShardedTensor, + spec, + sharded_tensor_metadata.size, + dtype=tensor_properties.dtype, + # layout=tensor_properties.layout, + # pin_memory=tensor_properties.pin_memory, + requires_grad=tensor_properties.requires_grad, + ) + + def _raise_if_mismatch(expected, actual, prop_name, rank, is_property=False): + tensor_property_or_metadata = ( + "tensor property" if is_property else "local ShardMetadata" + ) + if expected != actual: + raise ValueError( + f"Local shards' tensor {prop_name} property is incompatible with " + f"{tensor_property_or_metadata} on rank {rank}: " + f"{tensor_property_or_metadata} {prop_name}={expected}, " + f"local shard tensor {prop_name}={actual}." + ) + + for shard in local_shards: + shard_meta = shard.metadata + local_shard_tensor = shard.tensor + placement = shard_meta.placement + assert placement is not None, "Must specify placement for `Shard`!" + rank = placement.rank() + local_device = placement.device() + + # _raise_if_mismatch( + # tensor_properties.layout, + # local_shard_tensor.layout, + # "layout", + # rank, + # True, + # ) + if not local_shard_tensor.is_contiguous(): + raise ValueError( + "Only core.contiguous_format memory_format is currently supported" + ) + + _raise_if_mismatch( + shard_meta.shard_sizes, + list(local_shard_tensor.size()), + "size", + rank, + ) + # _raise_if_mismatch( + # tensor_properties.pin_memory, + # local_shard_tensor.is_pinned(), + # "pin_memory", + # rank, + # True, + # ) + # _raise_if_mismatch(local_device, local_shard_tensor.device, "device", rank) + _raise_if_mismatch( + tensor_properties.dtype, + local_shard_tensor.dtype, + "dtype", + rank, + True, + ) + _raise_if_mismatch( + tensor_properties.requires_grad, + local_shard_tensor.requires_grad, + "requires_grad", + rank, + True, + ) + + # check if shards_metadata have overlap shards + validate_non_overlapping_shards_metadata(shards_metadata) + + # check if the shards_metadata is compatible with overall size of the sharded tensor. + check_tensor(shards_metadata, list(sharded_tensor_metadata.size)) + + # done validation, add local_shards + sharded_tensor._local_shards = local_shards + sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs) + + # run post initialization, i.e. map registration, rpc initialization + sharded_tensor._post_init() + return sharded_tensor + + def sharding_spec(self) -> shard_spec.ShardingSpec: + """ + Returns the ShardingSpec for the tensor. + """ + return self._sharding_spec + + @deprecated(DEPRECATE_MSG, category=FutureWarning) + def reshard(self, resharding_spec: shard_spec.ShardingSpec) -> ShardedTensor: + """ + Reshard a sharded tensor given the ``resharding_spec``. For now, we only support + single local shard. + + If ``resharding_spec`` is same as the original one, this becomes a no-op. + If only ``resharding_spec`` shares the same sharding dim with the original one, + we swap local shards directly. + For more generic cases, we merge different shards across different ranks and split + the local shards based on the ``resharding_spec`` via `all_to_all` collective API. + + Args: + resharding_spec (:class:`core.distributed._shard.sharding_spec.ShardingSpec`): The + specification describing how the tensor is sharded. + + Returns: + A :class:`ShardedTensor` object whose local shards are resharded. + + Examples: + >>> # xdoctest: +SKIP + >>> # We have 2 process groups, 2 ranks. + >>> tensor = core.arange(4, dtype=core.int64) + 1 + 2 * rank + >>> tensor = core.stack([tensor, tensor]) + >>> tensor + tensor([[1, 2, 3, 4], [1, 2, 3, 4]]) # Rank 0 + tensor([[3, 4, 5, 6], [3, 4, 5, 6]]) # Rank 1 + tensor([[5, 6, 7, 8], [5, 6, 7, 8]]) # Rank 2 + tensor([[7, 8, 9, 10], [7, 8, 9, 10]]) # Rank 3 + >>> sharding_dim = 0 + >>> spec = ChunkShardingSpec( + dim=sharding_dim, + placements=[ + "rank:0/cuda:0", + "rank:1/cuda:1", + "rank:2/cuda:2", + "rank:3/cuda:3", + ], + ) + >>> current_offsets = [0] * 2 + >>> current_offsets[0] = rank * 2 + >>> shard_metadata = ShardMetadata( + shard_offsets=copy.deepcopy(current_offsets), + shard_sizes=tensor.size(), + placement=spec.placements[rank], + ) + >>> local_shards = [ + Shard( + tensor=tensor, + metadata=shard_metadata, + ) + ] + >>> st = ShardedTensor._init_from_local_shards(local_shards, tensor.size()) + >>> sharding_dim = 1 + >>> resharding_spec = ChunkShardingSpec( + dim=sharding_dim, + placements=[ + "rank:0/cuda:0", + "rank:1/cuda:1", + "rank:2/cuda:2", + "rank:3/cuda:3", + ], + ) + >>> st.reshard(resharding_spec) + >>> tensor = st.local_shards()[0].tensor + >>> tensor + tensor([[1], [1], [3], [3], [5], [5], [7], [7]]) # Rank 0 + tensor([[2], [2], [4], [4], [6], [6], [8], [8]]) # Rank 1 + tensor([[3], [3], [5], [5], [7], [7], [9], [9]]) # Rank 2 + tensor([[4], [4], [6], [6], [8], [8], [10], [10]]) # Rank 3 + """ + if not isinstance( + resharding_spec, shard_spec.ChunkShardingSpec + ) or not isinstance(self._sharding_spec, shard_spec.ChunkShardingSpec): + raise NotImplementedError("Only ChunkShardingSpec supported for reshard.") + if len(self.local_shards()) != 1: + raise NotImplementedError("Only single local shard supported for reshard.") + + if self._sharding_spec.dim == resharding_spec.dim: # type: ignore[attr-defined] + if self._sharding_spec.placements == resharding_spec.placements: # type: ignore[attr-defined] + return self + else: + local_shards, shards_metadata = reshuffle_local_shard( + self.local_tensor(), + self.size(), # type: ignore[arg-type] + self._sharding_spec, + resharding_spec, + self._process_group, + ) + else: + local_shards, shards_metadata = reshard_local_shard( + self.local_tensor(), + self.size(), # type: ignore[arg-type] + self._sharding_spec, + resharding_spec, + self._process_group, + ) + self._local_shards = local_shards + self._metadata.shards_metadata = shards_metadata + self._sharding_spec = resharding_spec + return self + + def local_tensor(self) -> core.Tensor: + """ + Return local tensor for a sharded_tensor. For now we only support single local shard. + + Returns: + A :class:`core.Tensor` of the local shard. + """ + if len(self.local_shards()) != 1: + raise NotImplementedError("Only single local shard is supported.") + return self.local_shards()[0].tensor + + @classmethod + @deprecated(DEPRECATE_MSG, category=FutureWarning) + def __torch_function__(cls, func, types, args=(), kwargs=None): + def dispatch(st: ShardedTensor, func: Callable): + # Dispatch to custom user provided op first if it exists. + if func in _CUSTOM_SHARDED_OPS: + return _CUSTOM_SHARDED_OPS[func](types, args, kwargs, st._process_group) + + # Dispatch to custom sharding spec op if it has one. + if _has_custom_op(st._sharding_spec, func): + return _dispatch_custom_op( + st._sharding_spec, func, types, args, kwargs, st._process_group + ) + + if func in _SHARDED_OPS: + return _SHARDED_OPS[func](types, args, kwargs, st._process_group) + + raise RuntimeError( + f"torch function '{func.__name__}', with args: {args} and " + f"kwargs: {kwargs} not supported for ShardedTensor!" + ) + + # Find ShardedTensor instance to get process_group and sharding_spec. + st_instance = None + + def find_sharded_tensor(e): + nonlocal st_instance + if st_instance is None and isinstance(e, ShardedTensor): + st_instance = e + + pytree.tree_map_(find_sharded_tensor, args) + pytree.tree_map_(find_sharded_tensor, kwargs) + + if st_instance is not None: + return dispatch(st_instance, func) + + raise RuntimeError( + f"torch function '{func.__name__}', with args: {args} and " + f"kwargs: {kwargs} not supported for ShardedTensor!" + ) + + def is_pinned(self) -> bool: # type: ignore[override] + """ + Returns True if the sharded tensor (each local shard) resides in pinned memory. + """ + return self._metadata.tensor_properties.pin_memory + + def _register_remote_shards( + self, remote_shards: List[rpc.RRef[Shard]], rpc_rank: int + ): + self._remote_shards[rpc_rank] = remote_shards + + def remote_shards(self) -> Dict[int, List[rpc.RRef[Shard]]]: + """ + Returns a Dict[int, RRef] with keys being the RPC rank and values + being RRefs to shards on that rank. Need to initialize the + RPC framework for this functionality. + + Raises an exception if ShardedTensor was created with ``init_rrefs=False`` + """ + if not self._init_rrefs: + raise RuntimeError( + "ShardedTensor created with init_rrefs=False, no RRefs to remote shards available" + ) + return self._remote_shards + + def __hash__(self): + return id(self) + + def __repr__(self) -> str: # type: ignore[override] + return f"ShardedTensor({self._metadata})" + + @dataclass + class ProcessGroupState: + """ + State for ser-de of process group + """ + + local_rank: int + global_rank: int + local_world_size: int + global_world_size: int + + def __getstate__(self): + pg_state = ShardedTensor.ProcessGroupState( + distributed_c10d.get_rank(self._process_group), + distributed_c10d.get_rank(), + distributed_c10d.get_world_size(self._process_group), + distributed_c10d.get_world_size(), + ) + + return ( + self._local_shards, + self._metadata, + pg_state, + self._sharding_spec, + self._init_rrefs, + ) + + def __setstate__(self, state): + self._sharded_tensor_id = None + if not distributed_c10d.is_initialized(): + raise RuntimeError( + "Need to initialize default process group using " + '"init_process_group" before loading ShardedTensor' + ) + + ( + self._local_shards, + self._metadata, + pg_state, + self._sharding_spec, + self._init_rrefs, + ) = state + + # Setup process group + from core.distributed._shard.api import _get_current_process_group + + self._process_group = _get_current_process_group() + + # Validate process group. + local_rank = distributed_c10d.get_rank(self._process_group) + if pg_state.local_rank != local_rank: + raise RuntimeError( + f"Local rank at save time was {pg_state.local_rank}, but at " + f"load time was {local_rank}" + ) + + global_rank = distributed_c10d.get_rank() + if pg_state.global_rank != global_rank: + raise RuntimeError( + f"Global rank at save time was {pg_state.global_rank}, but at " + f"load time was {global_rank}" + ) + + local_world_size = distributed_c10d.get_world_size(self._process_group) + if pg_state.local_world_size != local_world_size: + raise RuntimeError( + f"Local world size at save time was {pg_state.local_world_size}, " + f"but at load time was {local_world_size}" + ) + + global_world_size = distributed_c10d.get_world_size() + if pg_state.global_world_size != global_world_size: + raise RuntimeError( + f"Global world size at save time was {pg_state.global_world_size}, " + f"but at load time was {global_world_size}" + ) + + self._post_init() + + +def _create_tensor_from_params( + *size, local_device, tensor_properties: TensorProperties +): + """Helper to construct tensor from size, device and common params.""" + dtype = tensor_properties.dtype + layout = tensor_properties.layout + requires_grad = tensor_properties.requires_grad + memory_format = tensor_properties.memory_format + pin_memory = tensor_properties.pin_memory + + return core.empty( + *size, + dtype=dtype, + layout=layout, + device=local_device, + requires_grad=requires_grad, + memory_format=memory_format, + pin_memory=pin_memory, + ) diff --git a/mindnlp/core/distributed/_shard/sharded_tensor/logger.py b/mindnlp/core/distributed/_shard/sharded_tensor/logger.py new file mode 100644 index 000000000..ae56a5c8f --- /dev/null +++ b/mindnlp/core/distributed/_shard/sharded_tensor/logger.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import List, Tuple + +from core.distributed._shard.sharded_tensor.logging_handlers import _log_handlers + + +__all__: List[str] = [] + + +def _get_or_create_logger() -> logging.Logger: + logging_handler, log_handler_name = _get_logging_handler() + logger = logging.getLogger(f"sharding-spec-{log_handler_name}") + logger.setLevel(logging.DEBUG) + formatter = logging.Formatter( + "%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s" + ) + logging_handler.setFormatter(formatter) + logger.propagate = False + logger.addHandler(logging_handler) + return logger + + +def _get_logging_handler( + destination: str = "default", +) -> Tuple[logging.Handler, str]: + log_handler = _log_handlers[destination] + log_handler_name = type(log_handler).__name__ + return (log_handler, log_handler_name) diff --git a/mindnlp/core/distributed/_shard/sharded_tensor/logging_handlers.py b/mindnlp/core/distributed/_shard/sharded_tensor/logging_handlers.py new file mode 100644 index 000000000..b1b02a635 --- /dev/null +++ b/mindnlp/core/distributed/_shard/sharded_tensor/logging_handlers.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Dict, List + + +__all__: List[str] = [] + +_log_handlers: Dict[str, logging.Handler] = { + "default": logging.NullHandler(), +} diff --git a/mindnlp/core/distributed/_shard/sharded_tensor/metadata.py b/mindnlp/core/distributed/_shard/sharded_tensor/metadata.py new file mode 100644 index 000000000..8ed6c04f1 --- /dev/null +++ b/mindnlp/core/distributed/_shard/sharded_tensor/metadata.py @@ -0,0 +1,95 @@ +# mypy: allow-untyped-defs +from dataclasses import dataclass, field +from enum import Enum +from typing import List + +from mindnlp import core +from core.distributed._shard.metadata import ShardMetadata + + +class MEM_FORMAT_ENCODING(Enum): + TORCH_CONTIGUOUS_FORMAT = 0 + TORCH_CHANNELS_LAST = 1 + TORCH_PRESERVE_FORMAT = 2 + + +@dataclass +class TensorProperties: + """Properties used to create :class:`Tensor`""" + + # Regular tensor fields + dtype: core.dtype = field(default=core.get_default_dtype()) + # layout: core.layout = field(default=core.strided) + requires_grad: bool = False + # memory_format: core.memory_format = field(default=core.contiguous_format) + pin_memory: bool = False + + def __getstate__(self): + # Since core.memory_format cannot be pickled! + # memory_format = self.memory_format + # if memory_format == core.contiguous_format: + # mem_format_encoding = MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT + # elif memory_format == core.channels_last: + # mem_format_encoding = MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST + # elif memory_format == core.preserve_format: + # mem_format_encoding = MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT + # else: + # raise RuntimeError(f"Invalid core.memory_format: {memory_format}") + + return ( + self.dtype, + # self.layout, + self.requires_grad, + # mem_format_encoding, + # self.pin_memory, + ) + + def __setstate__( + self, + state, + ): + ( + self.dtype, + # self.layout, + self.requires_grad, + # mem_format_encoding, + # self.pin_memory, + ) = state + + # if mem_format_encoding == MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT: + # memory_format = core.contiguous_format + # elif mem_format_encoding == MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST: + # memory_format = core.channels_last + # elif mem_format_encoding == MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT: + # memory_format = core.preserve_format + # else: + # raise RuntimeError( + # f"Invalid core.memory_format encoding: {mem_format_encoding}" + # ) + + # self.memory_format = memory_format + + @staticmethod + def create_from_tensor(tensor: core.Tensor) -> "TensorProperties": + return TensorProperties( + dtype=tensor.dtype, + # layout=tensor.layout, + requires_grad=tensor.requires_grad, + # memory_format=core.contiguous_format, + # pin_memory=tensor.is_pinned(), + ) + + +@dataclass +class ShardedTensorMetadata: + """ + Represents metadata for :class:`ShardedTensor` + """ + + # Metadata about each shard of the Tensor + shards_metadata: List[ShardMetadata] = field(default_factory=list) + + # Size of each dim of the overall Tensor. + size: core.Size = field(default=core.Size([])) + + tensor_properties: TensorProperties = field(default_factory=TensorProperties) diff --git a/mindnlp/core/distributed/_shard/sharded_tensor/reshard.py b/mindnlp/core/distributed/_shard/sharded_tensor/reshard.py new file mode 100644 index 000000000..11100b815 --- /dev/null +++ b/mindnlp/core/distributed/_shard/sharded_tensor/reshard.py @@ -0,0 +1,246 @@ +# mypy: allow-untyped-defs +import copy +from typing import List, Tuple + +from mindnlp import core +from mindnlp import core.distributed as dist +from mindnlp import core.distributed._shard.sharding_spec as shard_spec +from core.distributed._shard.metadata import ShardMetadata +from core.distributed._shard.sharding_spec._internals import ( + get_chunked_dim_size, + get_split_size, +) +from core.distributed.nn.functional import all_to_all, all_to_all_single +from ...c10d import ProcessGroup + +from .shard import Shard + + +def get_idx_from_placements(placements, current_rank) -> int: + """ + Return the position of the current rank in the given placements. + + Args: + placements(List[Union[_remote_device, str]]): + Specifies the placement of each shard of the Tensor. The size of + the list represents the number of shards to be created. This could + be a list of + :class:`core.distributed._remote_device`'s. This list + could also contain a string which represents remote + device as accepted by + :class:`core.distributed._remote_device` + current_rank (int): number of current device. + + Returns: + A int which contains the position of current device in the placement list. + """ + for idx, placement in enumerate(placements): # type: ignore[attr-defined] + if current_rank == placement.rank(): # type: ignore[union-attr] + return idx + raise RuntimeError("current_rank not in the placement.") + + +def build_reshard_metadata( + st_size: core.Size, + sharding_spec: shard_spec.ShardingSpec, + world_size: int, +) -> Tuple[List[ShardMetadata], List[int]]: + """ + Based the given sharding spec, we calculate the offset and local shard size. + We then build a ShardMetadata on top of the calculation result. + + Args: + st_size (core.Size): The size of the sharded tensor. + sharding_spec (:class:`core.distributed._shard.sharding_spec.ShardingSpec`): The + specification describing how the tensor is sharded. + world_size (int): number of ranks. + + Returns: + A Tuple of the followings: + A List[`ShardMetadata`] which contains the metadata for the shard, including + offsets, lengths and device placement. + A List[int] which contains the ranks in the order of placement. + """ + shard_dim = int(sharding_spec.dim) # type: ignore[attr-defined] + shards_metadata = [None] * world_size + ranks = [] + offsets = [0] * len(st_size) + split_size = get_split_size(st_size[shard_dim], world_size) + for idx, placement in enumerate(sharding_spec.placements): # type: ignore[attr-defined] + ranks.append(placement.rank()) + sharded_dim_size = get_chunked_dim_size(st_size[shard_dim], split_size, idx) + local_tensor_size = list(st_size) + local_tensor_size[shard_dim] = sharded_dim_size + shards_metadata[placement.rank()] = ShardMetadata( # type: ignore[call-overload] + shard_offsets=copy.deepcopy(offsets), + shard_sizes=local_tensor_size, + placement=placement, + ) + offsets[shard_dim] += sharded_dim_size + return shards_metadata, ranks # type: ignore[return-value] + + +def reshuffle_local_shard( + local_shard: core.Tensor, + st_size: core.Size, + sharding_spec: shard_spec.ShardingSpec, + resharding_spec: shard_spec.ShardingSpec, + pg: ProcessGroup, +) -> Tuple[List[Shard], List[ShardMetadata]]: + """ + Reshuffle the local shard directly when the reshard dim is same as the original + sharding dim. Logically we do this in two step: + 1. To collect all shards based on original sharding spec. + 2. Reshard the tensor based on the given resharding spec. + + In reality, we consolidate the two steps into one by sending the local tensor to + the new shard directly based on the resharding spec. + + Args: + local_shard (Tensor): Local tensor stored in the current rank. + st_size (core.Size): The size of the sharded tensor. + sharding_spec (:class:`core.distributed._shard.sharding_spec.ShardingSpec`): The + specification describing how the tensor is sharded originally. + resharding_spec (:class:`core.distributed._shard.sharding_spec.ShardingSpec`): The + specification describing how the tensor will be resharded. + pg (ProcessGroup): The process group to aggregate on. + + Returns: + A Tuple of the followings: + A List[`Shard`] which contains the local tensor and its metadata. + A List[`ShardMetadata`] which contains the metadata for the shard, including + offsets, lengths and device placement. + """ + current_rank = dist.get_rank(pg) + world_size = dist.get_world_size(pg) + # Build shards_metadata first. + shards_metadata, ranks = build_reshard_metadata( + st_size, resharding_spec, world_size + ) + # Get input split size for all2all. + reshard_dim = int(resharding_spec.dim) # type: ignore[attr-defined] + split_size = get_split_size(st_size[reshard_dim], world_size) + input_split_sizes = [0] * world_size + idx = get_idx_from_placements(sharding_spec.placements, current_rank) # type: ignore[attr-defined] + new_rank = resharding_spec.placements[idx].rank() # type: ignore[union-attr, attr-defined] + input_split_sizes[new_rank] = local_shard.size(reshard_dim) + # Get output split size for all2all. + output_split_sizes = [0] * world_size + new_idx = ranks.index(current_rank) + sharded_dim_size = get_chunked_dim_size(st_size[reshard_dim], split_size, new_idx) + output_split_sizes[new_rank] = sharded_dim_size + # Get gathered_input for all2all. + local_shard = local_shard.transpose(0, reshard_dim).contiguous() + gathered_input_size = list(local_shard.size()) + gathered_input_size[0] = sharded_dim_size + gathered_input = core.empty( + gathered_input_size, device=local_shard.device, dtype=local_shard.dtype + ) + # all2all. + local_shard = all_to_all_single( + gathered_input, + local_shard, + input_split_sizes=input_split_sizes, + output_split_sizes=output_split_sizes, + group=pg, + ) + local_tensor = local_shard.transpose(0, reshard_dim).contiguous() + local_shards = [Shard(local_tensor, shards_metadata[current_rank])] + return local_shards, shards_metadata + + +def reshard_local_shard( + local_tensor: core.Tensor, + st_size: core.Size, + sharding_spec: shard_spec.ShardingSpec, + resharding_spec: shard_spec.ShardingSpec, + pg: ProcessGroup, +) -> Tuple[List[Shard], List[ShardMetadata]]: + """ + Reshard a sharded tensor given the ``resharding_spec``. When the reshard dim is + different from the original sharding dim, we need to do two steps logically: + 1. To collect all shards based on original sharding spec. + 2. Reshard the tensor based on the given resharding spec. + + In reality, we consolidate the two steps into one by sending each rank the new + shard based on the resharding spec. + + Args: + local_tensor (Tensor): Local tensor stored in the current rank. + st_size (core.Size): The size of the sharded tensor. + sharding_spec (:class:`core.distributed._shard.sharding_spec.ShardingSpec`): The + specification describing how the tensor is sharded originally. + resharding_spec (:class:`core.distributed._shard.sharding_spec.ShardingSpec`): The + specification describing how the tensor will be resharded. + pg (ProcessGroup): The process group to aggregate on. + + Returns: + A Tuple of the followings: + A List[`Shard`] which contains the local tensor and its metadata. + A List[`ShardMetadata`] which contains the metadata for the shard, including + offsets, lengths and device placement. + """ + current_rank = dist.get_rank(pg) + world_size = dist.get_world_size(pg) + current_sharding_dim = int(sharding_spec.dim) # type: ignore[attr-defined] + reshard_dim = int(resharding_spec.dim) # type: ignore[attr-defined] + + # Build shards_metadata first. + shards_metadata, ranks = build_reshard_metadata( + st_size, resharding_spec, world_size + ) + + # Compute expected size + input_split_sizes = [ + metadata.shard_sizes[reshard_dim] for metadata in shards_metadata + ] + rearrange_input = any(ranks[i] > ranks[i + 1] for i in range(len(ranks) - 1)) + + if rearrange_input: + # Need to re-arrange reshard_dim of local_tensor before all2all. + indices: List[int] = [] + for metadata in shards_metadata: + offset_start_idx = metadata.shard_offsets[reshard_dim] + split_size = metadata.shard_sizes[reshard_dim] + indices += range(offset_start_idx, offset_start_idx + split_size) + local_tensor = local_tensor.index_select( + reshard_dim, core.tensor(indices, device=local_tensor.device) + ) + + # Because reshard_dim != original shard_dim. We need to compute the + # size of tensor from each rank. + output_tensor_list = [core.tensor(1)] * world_size + split_size = get_split_size(st_size[current_sharding_dim], world_size) + rearrange_output_list = False + indices = [] + for idx, placement in enumerate(sharding_spec.placements): # type: ignore[attr-defined] + sharded_dim_size = get_chunked_dim_size( + st_size[current_sharding_dim], split_size, idx + ) + output_tensor_size = list(st_size) + output_tensor_size[current_sharding_dim] = sharded_dim_size + output_tensor_size[reshard_dim] = input_split_sizes[current_rank] + output_tensor_list[ + placement.rank() + ] = core.empty( # type: ignore[union-attr, index] + output_tensor_size, device=local_tensor.device, dtype=local_tensor.dtype + ) + indices.append(placement.rank()) # type: ignore[union-attr, index, arg-type] + if idx != placement.rank(): # type: ignore[union-attr] + rearrange_output_list = True + + # Perform autograd enabled all2all. + input_tensor_tuple = core.split(local_tensor, input_split_sizes, dim=reshard_dim) + input_tensor_list = [tensor.contiguous() for tensor in input_tensor_tuple] + output_tensor_list = all_to_all( + output_tensor_list, + input_tensor_list, + group=pg, + ) + + if rearrange_output_list: + # Need to re-arrange original shard_dim of output_tensor_list. + output_tensor_list = [output_tensor_list[idx] for idx in indices] # type: ignore[call-overload] + local_tensor = core.cat(output_tensor_list, dim=current_sharding_dim) + local_shards = [Shard(local_tensor, shards_metadata[current_rank])] + return local_shards, shards_metadata diff --git a/mindnlp/core/distributed/_shard/sharded_tensor/shard.py b/mindnlp/core/distributed/_shard/sharded_tensor/shard.py new file mode 100644 index 000000000..89ab8d9bf --- /dev/null +++ b/mindnlp/core/distributed/_shard/sharded_tensor/shard.py @@ -0,0 +1,63 @@ +# mypy: allow-untyped-defs +from dataclasses import dataclass +from typing import List + +from mindnlp import core +from core.distributed._shard.metadata import ShardMetadata +from core.distributed.remote_device import _remote_device + + +@dataclass +class Shard: + """ + Container which holds the data for a shard as a Tensor and also + the associated metadata for that shard. + + Args: + tensor(core.Tensor): Local tensor for the shard. + metadata(:class `core.distributed._shard.sharded_tensor.ShardMetadata`): + The metadata for the shard, including offsets, lengths and device placement. + """ + + __slots__ = ["tensor", "metadata"] + tensor: core.Tensor + metadata: ShardMetadata + + def __post_init__(self): + # verification between local tensor and metadata + if list(self.tensor.size()) != self.metadata.shard_sizes: + raise ValueError( + "Shard tensor size does not match with metadata.shard_lengths! " + f"Found shard tensor size: {list(self.tensor.size())}, " + f"metadata.shard_lengths: {self.metadata.shard_sizes}, " + ) + placement_device = self.metadata.placement + # if ( + # placement_device is not None + # and placement_device.device() != self.tensor.device + # ): + # raise ValueError( + # f"Local shard tensor device does not match with local Shard's placement! " + # f"Found local shard tensor device: {self.tensor.device}, " + # f"local shard metadata placement device: {placement_device.device()}" + # ) + + @classmethod + def from_tensor_and_offsets( + cls, tensor: core.Tensor, shard_offsets: List[int], rank: int + ): + """ + Creates a Shard of a ShardedTensor from a local core.Tensor, shard_offsets and rank. + + Args: + tensor(core.Tensor): Local tensor for the shard. + shard_offsets(List[int]): List of integers specify the offset + of the shard on each dimension. + rank(int): Specify the rank for the shard. + """ + shard_sizes = list(tensor.size()) + placement = _remote_device(f"rank:{rank}/{str(tensor.device)}") + shard_meta = ShardMetadata( + shard_offsets=shard_offsets, shard_sizes=shard_sizes, placement=placement + ) + return Shard(tensor, shard_meta) diff --git a/mindnlp/core/distributed/_shard/sharded_tensor/utils.py b/mindnlp/core/distributed/_shard/sharded_tensor/utils.py new file mode 100644 index 000000000..1b7715f1a --- /dev/null +++ b/mindnlp/core/distributed/_shard/sharded_tensor/utils.py @@ -0,0 +1,267 @@ +# mypy: allow-untyped-defs +import collections.abc +import copy +from typing import List, Optional, Sequence, TYPE_CHECKING + +from mindnlp import core +from core.distributed import distributed_c10d as c10d +from core.distributed._shard.sharding_spec._internals import ( + check_tensor, + validate_non_overlapping_shards_metadata, +) + +from .metadata import ShardedTensorMetadata, TensorProperties +from .shard import Shard + + +if TYPE_CHECKING: + from core.distributed._shard.metadata import ShardMetadata + + +def _parse_and_validate_remote_device(pg, remote_device): + if remote_device is None: + raise ValueError("remote device is None") + + worker_name = remote_device.worker_name() + rank = remote_device.rank() + device = remote_device.device() + + # Validate rank, skip validation if rank is not part of process group. + if rank is not None and not c10d._rank_not_in_group(pg): + pg_global_ranks = c10d.get_process_group_ranks(pg) + if rank not in pg_global_ranks: + raise ValueError( + f"Global rank {rank} does not exist in input process group: {pg_global_ranks}" + ) + + # if worker_name is not None: + # if not rpc._is_current_rpc_agent_set(): + # raise RuntimeError( + # f"RPC framework needs to be initialized for using worker names: {worker_name}" + # ) + + # workers = rpc._get_current_rpc_agent().get_worker_infos() + # for worker in workers: + # if worker.name == worker_name: + # return worker.id, device + + # raise ValueError(f"Invalid worker name: {worker_name}") + + return rank, device + + +def _validate_output_tensor_for_gather( + my_rank: int, + dst_rank: int, + size: core.Size, + dst_tensor: Optional[core.Tensor], +) -> None: + if dst_rank == my_rank: + if dst_tensor is None: + raise ValueError( + f"Argument ``dst_tensor`` must be specified on destination rank {dst_rank}" + ) + if tuple(size) != (dst_tensor.size()): + raise ValueError( + f"Argument ``dst_tensor`` have size {tuple(dst_tensor.size())}," + f"but should be {tuple(size)}" + ) + elif dst_tensor: + raise ValueError( + "Argument ``dst_tensor`` must NOT be specified " "on non-destination ranks." + ) + + +def _flatten_tensor_size(size) -> core.Size: + """ + Checks if tensor size is valid, then flatten/return a core.Size object. + """ + if len(size) == 1 and isinstance(size[0], collections.abc.Sequence): + dims = list(*size) + else: + dims = list(size) + + for dim in dims: + if not isinstance(dim, int): + raise TypeError(f"size has to be a sequence of ints, found: {dims}") + + return core.Size(dims) + + +def _raise_if_mismatch(expected, actual, prop_name, ranks, is_local=True): + if is_local: + assert isinstance(ranks, int) + if expected != actual: + raise ValueError( + f"Local shards' tensor {prop_name} property need to be the same on rank:{ranks}! " + f"Found one local shard tensor {prop_name}={expected}, " + f"the other local shard tensor {prop_name}={actual}." + ) + else: + # compare failure check across ranks, ranks list should have two rank + assert len(ranks) == 2 + if expected != actual: + raise ValueError( + f"ShardedTensor {prop_name} property does not match from different ranks! " + f"Found {prop_name}={expected} on rank:{ranks[0]}, " + f"and {prop_name}={actual} on rank:{ranks[1]}." + ) + + +def build_metadata_from_local_shards( + local_shards: List[Shard], + global_size: core.Size, + current_rank: int, + pg: c10d.ProcessGroup, +) -> ShardedTensorMetadata: + assert len(local_shards) > 0, "must have local shards!" + local_shard_metadatas: List[ShardMetadata] = [] + + first_shard_dtype = local_shards[0].tensor.dtype + first_shard_layout = local_shards[0].tensor.layout + first_shard_requires_grad = local_shards[0].tensor.requires_grad + first_shard_is_pinned = local_shards[0].tensor.is_pinned() + + # 1). Validate local tensors and associated metadatas + for local_shard in local_shards: + local_shard_tensor = local_shard.tensor + local_shard_meta = local_shard.metadata + local_shard_metadatas.append(local_shard_meta) + rank, local_device = _parse_and_validate_remote_device( + pg, local_shard_meta.placement + ) + + if ( + local_shard_tensor.layout != core.strided + or local_shard_tensor.layout != first_shard_layout + ): + raise ValueError( + f"Only core.strided layout is currently supported, but found " + f"{local_shard_tensor.layout} on rank:{current_rank}!" + ) + + if not local_shard_tensor.is_contiguous(): + raise ValueError( + "Only core.contiguous_format memory_format is currently supported!" + ) + + if rank != current_rank: + raise ValueError( + f"Local shard metadata's rank does not match with the rank in its process group! " + f"Found current rank in the process group: {current_rank}, " + f"local ShardMetadata placement's rank: {rank}" + ) + if local_shard_tensor.device != local_device: + raise ValueError( + f"Local shard tensor device does not match with local Shard's placement! " + f"Found local shard tensor device: {local_shard_tensor.device}, " + f"local shard metadata placement device: {local_device}" + ) + + _raise_if_mismatch( + local_shard_meta.shard_sizes, + list(local_shard_tensor.size()), + "size", + current_rank, + ) + _raise_if_mismatch( + local_shard_tensor.is_pinned(), + first_shard_is_pinned, + "pin_memory", + current_rank, + ) + _raise_if_mismatch( + local_shard_tensor.dtype, first_shard_dtype, "dtype", current_rank + ) + _raise_if_mismatch( + local_shard_tensor.requires_grad, + first_shard_requires_grad, + "requires_grad", + current_rank, + ) + + # 2). Build a "local" ShardedTensorMetadata with all local shards on this rank, then + # do all_gather to collect local_sharded_tensor_metadata from all ranks + local_tensor_properties = TensorProperties( + dtype=first_shard_dtype, + layout=first_shard_layout, + requires_grad=first_shard_requires_grad, + memory_format=core.contiguous_format, + pin_memory=first_shard_is_pinned, + ) + + local_sharded_tensor_metadata = ShardedTensorMetadata( + shards_metadata=local_shard_metadatas, + size=global_size, + tensor_properties=local_tensor_properties, + ) + + return local_sharded_tensor_metadata + + +def build_global_metadata( + gathered_metadatas: Sequence[Optional[ShardedTensorMetadata]], +): + global_sharded_tensor_metadata = None + global_metadata_rank = 0 + + for rank, rank_metadata in enumerate(gathered_metadatas): + if rank_metadata is None: + continue + + if global_sharded_tensor_metadata is None: + global_sharded_tensor_metadata = copy.deepcopy(rank_metadata) + global_metadata_rank = rank + else: + _raise_if_mismatch( + global_sharded_tensor_metadata.size, + rank_metadata.size, + "global_size", + [global_metadata_rank, rank], + is_local=False, + ) + + # don't need to check layout and memory format as we already checked in local shards validation stage + _raise_if_mismatch( + global_sharded_tensor_metadata.tensor_properties.dtype, + rank_metadata.tensor_properties.dtype, + "dtype", + [global_metadata_rank, rank], + is_local=False, + ) + + _raise_if_mismatch( + global_sharded_tensor_metadata.tensor_properties.requires_grad, + rank_metadata.tensor_properties.requires_grad, + "requires_grad", + [global_metadata_rank, rank], + is_local=False, + ) + + _raise_if_mismatch( + global_sharded_tensor_metadata.tensor_properties.pin_memory, + rank_metadata.tensor_properties.pin_memory, + "pin_memory", + [global_metadata_rank, rank], + is_local=False, + ) + # pass all validations, extend shards metadata + global_sharded_tensor_metadata.shards_metadata.extend( + rank_metadata.shards_metadata + ) + + if global_sharded_tensor_metadata is not None: + # check if shards_metadata have overlap shards + validate_non_overlapping_shards_metadata( + global_sharded_tensor_metadata.shards_metadata + ) + + # check if the shards_metadata is compatible with global size of the sharded tensor. + check_tensor( + global_sharded_tensor_metadata.shards_metadata, + global_sharded_tensor_metadata.size, + ) + else: + raise ValueError("ShardedTensor have no local shards on all ranks!") + + return global_sharded_tensor_metadata diff --git a/mindnlp/core/distributed/_shard/sharder.py b/mindnlp/core/distributed/_shard/sharder.py new file mode 100644 index 000000000..1c17b427d --- /dev/null +++ b/mindnlp/core/distributed/_shard/sharder.py @@ -0,0 +1,29 @@ +import abc + +from mindnlp import core.nn as nn + + +class Sharder(abc.ABC): + """ + This is an interface which allows user to create more advanced + sharding strategies that are not easily be composed by the + `ShardingSpec`. + + :class:`core.distributed._shard.sharding_plan.ShardingPlan` could + take an object of the `Sharder` and call `shard` to shard the module, + then replace the original module with sharded module returned. + """ + + @abc.abstractmethod + def shard(self, module: nn.Module) -> nn.Module: + """ + Shard a module base on the implementation of this method, and + return the sharded version of the module. + + Args: + module (:class:`core.nn.Module`): + The module to apply sharding to. + Returns: + A :class:`core.nn.Module` object that represents a module + that's already been sharded. + """ diff --git a/mindnlp/core/distributed/_shard/sharding_plan/__init__.py b/mindnlp/core/distributed/_shard/sharding_plan/__init__.py new file mode 100644 index 000000000..3c8662fba --- /dev/null +++ b/mindnlp/core/distributed/_shard/sharding_plan/__init__.py @@ -0,0 +1 @@ +from .api import ShardingPlan, ShardingPlanner diff --git a/mindnlp/core/distributed/_shard/sharding_plan/api.py b/mindnlp/core/distributed/_shard/sharding_plan/api.py new file mode 100644 index 000000000..de8d3e36d --- /dev/null +++ b/mindnlp/core/distributed/_shard/sharding_plan/api.py @@ -0,0 +1,87 @@ +import abc +from dataclasses import dataclass +from typing import Dict, List, Optional, Union + +from mindnlp import core.nn as nn +from core.distributed._shard.sharder import Sharder +from core.distributed._shard.sharding_spec import ShardingSpec + + +@dataclass +class ShardingPlan: + """ + Representation of a sharding plan, describes how to shard a module + across hosts. `plan` is used to shard module parameters according to the spec provided, + `output_plan` and `return_local_tensor` are optional, they are used to specify the output + layout of a module with a spec, and when to convert back to data parallel fashion. + + Args: + plan (Dict[str, Union[:class:`core.distributed._shard.sharding_spec.ShardingSpec`, + :class:`core.distributed._shard.sharder.Sharder`]): + a dict describes how to shard a module, there're currently two ways to shard a module: + 1. directly shard a module parameter by a `ShardingSpec`, keyed by the name of + a parameter to a `ShardingSpec`. + 2. shard a submodule by applying a `Sharder` on it, keyed by the name of a module + to a `Sharder` object. + output_plan (Dict[str, :class:`core.distributed._shard.sharding_spec.ShardingSpec`), optional): + a dict specifies the layout of a module's output which produces a ShardedTensor, + keyed by the name of module to ShardingSpec("" in key means the root module). + Default: `None` + return_local_tensor (List[str], optional): a list of string, each element enables + a module's sharded output to be returned as a Tensor from its local shards to + ensure further processing in a data parallel fashion. ("" in list means the + root module). + Default: None + Example: + Suppose we want to shard a module with two linear layers and then run it with DDP, we also + want to convert the output of the second linear layer back to DDP, we can do it as follows: + + >>> # xdoctest: +REQUIRES(module:core._C._distributed_c10d) + >>> class MyModule(nn.Module): + >>> def __init__(self) -> None: + >>> super().__init__() + >>> self.fc1 = nn.Linear() + >>> self.gelu = nn.GELU() + >>> self.fc2 = nn.Linear() + >>> self.relu = nn.Linear() + >>> + >>> def forward(self, input): + >>> return self.relu(self.fc2(self.gelu(self.fc1(input)))) + + + >>> # xdoctest: +SKIP("Undefined spec1, spec2) + >>> sharding_plan = ShardingPlan( + >>> plan={ + >>> "fc1.weight": spec1, + >>> "fc2.weight": spec2 + >>> }, + >>> output_plan={ + >>> "fc2": output_spec + >>> }, + >>> return_local_tensor=["fc2"] + >>> ) + """ + + plan: Dict[str, Union[ShardingSpec, Sharder]] + output_plan: Optional[Dict[str, ShardingSpec]] = None + return_local_tensor: Optional[List[str]] = None + + +class ShardingPlanner(abc.ABC): + """ + Default ShardingPlanner interface, can be extended and + implement advanced sharding strategies. + """ + + @abc.abstractmethod + def build_plan(self, module: nn.Module) -> ShardingPlan: + """ + Given a nn.Module, define how to shard the module across + ranks, return a ShardingPlan + Args: + module (:class:`core.nn.Module`): + The module to apply sharding to. + Returns: + A :class:`core.distributed._shard.sharding_plan.ShardingPlan` object that + represents how to shard the module. + """ diff --git a/mindnlp/core/distributed/_shard/sharding_spec/__init__.py b/mindnlp/core/distributed/_shard/sharding_spec/__init__.py new file mode 100644 index 000000000..d2b5fbe9d --- /dev/null +++ b/mindnlp/core/distributed/_shard/sharding_spec/__init__.py @@ -0,0 +1,10 @@ +from core.distributed._shard.metadata import ShardMetadata + +from .api import ( + _infer_sharding_spec_from_shards_metadata, + DevicePlacementSpec, + EnumerableShardingSpec, + PlacementSpec, + ShardingSpec, +) +from .chunk_sharding_spec import ChunkShardingSpec as ChunkShardingSpec diff --git a/mindnlp/core/distributed/_shard/sharding_spec/_internals.py b/mindnlp/core/distributed/_shard/sharding_spec/_internals.py new file mode 100644 index 000000000..74ae539d3 --- /dev/null +++ b/mindnlp/core/distributed/_shard/sharding_spec/_internals.py @@ -0,0 +1,217 @@ +# mypy: allow-untyped-defs +from typing import List, Optional, Tuple + +from core.distributed._shard.metadata import ShardMetadata + + +def _check_shard_metadata_pair_overlap(shard1: ShardMetadata, shard2: ShardMetadata): + """ + Checks if two shards overlap. + """ + + # For each dim of each shard, check if one shard resides on the other + # end of second shard with respect to that dim. As an example for a 2D + # shard, we would check if one shard is above or on the left of the + # other shard. + ndims = len(shard1.shard_offsets) + for i in range(ndims): + if shard1.shard_offsets[i] >= shard2.shard_offsets[i] + shard2.shard_sizes[i]: + return False + if shard2.shard_offsets[i] >= shard1.shard_offsets[i] + shard1.shard_sizes[i]: + return False + + return True + + +def _find_nd_overlapping_shards( + shards: List[ShardMetadata], sharded_dims: List[int] +) -> Optional[Tuple[int, int]]: + # Each rank has len(sharded_dims) tuples. Each tuple represent the + # [begin, end] (inclusive) pair of that dimension. + shard_intervals = [ + [ + (s.shard_offsets[dim], s.shard_offsets[dim] + s.shard_sizes[dim] - 1) + for dim in sharded_dims + ] + for s in shards + ] + + for i in range(len(shards)): + shard_i = shard_intervals[i] + for j in range(i + 1, len(shards)): + shard_j = shard_intervals[j] + # For each dim of each shard, check if one shard resides on the other + # end of second shard with respect to that dim. As an example for a 2D + # shard, we would check if one shard is above or on the left of the + # other shard. + overlap = True + for interval_i, interval_j in zip(shard_i, shard_j): + if interval_i[0] > interval_j[1] or interval_j[0] > interval_i[1]: + overlap = False + break + if overlap: + return (i, j) + return None + + +def _find_1d_overlapping_shards( + shards: List[ShardMetadata], dim: int +) -> Optional[Tuple[int, int]]: + # (begin, end, index_in_shards). Begin and end are inclusive. + intervals = [ + (s.shard_offsets[dim], s.shard_offsets[dim] + s.shard_sizes[dim] - 1, i) + for i, s in enumerate(shards) + ] + intervals.sort() + for i in range(len(shards) - 1): + if intervals[i][1] >= intervals[i + 1][0]: + return (intervals[i][2], intervals[i + 1][2]) + return None + + +def validate_non_overlapping_shards_metadata(shards: List[ShardMetadata]): + """ + Ensures none of the shards overlap with each other. + + Args: + shards(List[ShardMetadata]): List of :class:`ShardMetadata` objects representing + each shard. + Raises: + ``ValueError`` if there's overlap in any two shards. + """ + if not shards or len(shards) == 1: + return + + sharded_dims: List[int] = [] + for dim in range(len(shards[0].shard_offsets)): + for i in range(1, len(shards)): + if ( + shards[i].shard_offsets[dim] != shards[0].shard_offsets[dim] + or shards[i].shard_sizes[dim] != shards[0].shard_sizes[dim] + ): + sharded_dims.append(dim) + break + + pair: Optional[Tuple[int, int]] = None + if len(sharded_dims) == 0: + # All shards are the same, all dims are not partitioned. Choose any 2. + pair = (0, 1) + elif len(sharded_dims) == 1: + # Shards are partitioned over only one dimension. Overlap can be found + # using a O(nlogn) overlapping interval algorithm. + pair = _find_1d_overlapping_shards(shards, sharded_dims[0]) + else: + # Shards are partitioned over more than one dimension. Fall back to + # pair-wise check. Even though O(nlogn) algorithms (line sweep) exist + # for 2D overlap, the implementation is not trivial and may not justify + # the time saving in most cases. + pair = _find_nd_overlapping_shards(shards, sharded_dims) + + if pair: + raise ValueError(f"Shards {shards[pair[0]]} and {shards[pair[1]]} overlap") + + +def check_tensor(shards_metadata, tensor_dims) -> None: + """ + Checks if the shards_metadata is compatible with the provided tensor dims. + + Args: + shards_metadata(List[ShardMetadata]): List of :class:`ShardMetadata` + objects representing each shard of the tensor. + tensor_dims(Sequence of int): Dimensions of tensor to verify + Raises: + ``ValueError`` if not compatible. + """ + + # If the tensor's volume matches the total volume of all shards and + # all shard boundaries are within tensor dims, we have a compatible + # sharding spec for this tensor. Note that we have already verified + # we don't have overlapping shards. + tensor_rank = len(tensor_dims) + shards_rank = len(shards_metadata[0].shard_offsets) + if tensor_rank != shards_rank: + raise ValueError( + f"Rank of tensor is {tensor_rank}, but shards rank is {shards_rank}" + ) + + total_shard_volume = 0 + for shard in shards_metadata: + shard_volume = 1 + for i, shard_length in enumerate(shard.shard_sizes): + shard_volume *= shard_length + if shard.shard_offsets[i] + shard.shard_sizes[i] > tensor_dims[i]: + raise ValueError( + f"Shard offset {shard.shard_offsets[i]} and length " + f"{shard.shard_sizes[i]} exceeds tensor dim: {tensor_dims[i]} for shard {shard}" + ) + total_shard_volume += shard_volume + + tensor_volume = 1 + for size in tensor_dims: + tensor_volume *= size + + if total_shard_volume != tensor_volume: + # TODO: Can we improve this error message to point out the gaps? + raise ValueError( + f"Total volume of shards: {total_shard_volume} " + f"does not match tensor volume: {tensor_volume}, in other words " + f"all the individual shards do not cover the entire tensor" + ) + + +def get_split_size(dim_size, chunks): + """ + Computes the split size inline with ``core.chunk`` + + Args: + dim_size(int): Size of the dimension being chunked. + chunks(int): Number of chunks to create for ``dim_size``. + + Returns: + An int indicating the split size to use. + """ + return (dim_size + chunks - 1) // chunks + + +def get_chunked_dim_size(dim_size, split_size, idx): + """ + Computes the dim size of the chunk for provided ``idx`` given ``dim_size`` + and ``split_size``. + + Args: + dim_size(int): Size of the dimension being chunked. + split_size(int): The chunk size for each chunk of ``dim_size``. + idx(int): The index of chunk whose dim size is being requested. + + Returns: + An int indicating the dim size of the chunk. + """ + return max(min(dim_size, split_size * (idx + 1)) - split_size * idx, 0) + + +def get_chunk_sharding_params(sharding_dim_size, world_size, spec, rank): + """ + Generate the start pos and offset length for the current rank for + chunk sharding. + + Args: + sharding_dim_size(int): The dimension length which we shard on. + world_size(int): number of ranks. + spec (:class:`core.distributed._shard.sharding_spec.ChunkShardingSpec`): + sharding spec. + rank(int): # of cuda process. + + Returns: + start_pos(int): start position of sharded tensor on the given rank. + chunk_size(int): chunk size of sharded tensor on the given rank. + """ + split_size = get_split_size(sharding_dim_size, world_size) + current_offsets = 0 + start_pos = current_offsets + for idx, placement in enumerate(spec.placements): + chunk_size = get_chunked_dim_size(sharding_dim_size, split_size, idx) + if rank == placement.rank(): + start_pos = current_offsets + break + current_offsets += chunk_size + return start_pos, chunk_size # type: ignore[possibly-undefined] diff --git a/mindnlp/core/distributed/_shard/sharding_spec/api.py b/mindnlp/core/distributed/_shard/sharding_spec/api.py new file mode 100644 index 000000000..d143990b3 --- /dev/null +++ b/mindnlp/core/distributed/_shard/sharding_spec/api.py @@ -0,0 +1,263 @@ +# mypy: allow-untyped-defs +import functools +import operator +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Callable, Dict, List, TYPE_CHECKING + +from mindnlp import core +from mindnlp import core.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta +from core.distributed._shard.metadata import ShardMetadata +from core.distributed._shard.op_registry_utils import _decorator_func + +from ._internals import ( + check_tensor, + get_chunked_dim_size, + get_split_size, + validate_non_overlapping_shards_metadata, +) + + +if TYPE_CHECKING: + # Only include ShardedTensor when do type checking, exclude it + # from run-time to resolve circular dependency. + from core.distributed._shard.sharded_tensor import ShardedTensor + + +class PlacementSpec(ABC): # noqa: B024 + """ + Base class representing the placement of an entity. Subclasses of this + class can be used to specify customized placements which might not be + covered by existing APIs. + """ + + +@dataclass +class DevicePlacementSpec(PlacementSpec): + """ + Associates placement of an entity with a single device. + + Args: + device(:class:`core.distributed._remote_device`): The device to place the entity on. + """ + + device: core.distributed._remote_device + + def __post_init__(self): + if not isinstance(self.device, core.distributed._remote_device): + self.device = core.distributed._remote_device(self.device) + + +class ShardingSpec(ABC): + """ + Base class representing sharding specifications. + """ + + @abstractmethod + def build_metadata( + self, + tensor_sizes: core.Size, + tensor_properties: sharded_tensor_meta.TensorProperties, + ) -> sharded_tensor_meta.ShardedTensorMetadata: + """ + Given a global tensor size, define how to shard a tensor like this shape + across ranks, return ShardedTensorMetadata + Args: + tensor_sizes (:class:`core.Size`): + The tensor shape to shard on, a `core.Size` object that represents the + tensor shape to be sharded according to the ShardingSpec. + tensor_properties(:class:`core.distributed._shard.sharded_tensor.TensorProperties): + Tensor properties used to create a ShardedTensor. + Returns: + A :class:`ShardedTensorMetadata` object that encodes the information about + the layout of the ShardedTensor and its properties. + """ + + @abstractmethod + def shard( + self, tensor: core.Tensor, src_rank: int = 0, process_group=None + ) -> "ShardedTensor": + """ + Given a global tensor on src_rank, shard this tensor + across ranks within the process group, return a ShardedTensor. + Args: + tensor (:class:`core.Tensor`): Tensor needs to be sharded. + Keyword args: + src_rank (int, optional): The source rank which is used as the ground truth of + the data for the parameter that would be sharded and scattered + across the rest of the ranks. + Default: 0. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + Returns: + A :class:`ShardedTensor` sharded from the given tensor. + """ + + +# Ops customized for a particular ShardingSpec. +_CUSTOM_SHARDING_SPEC_OPS: Dict[str, Dict[Callable, Callable]] = {} + + +def _has_custom_op(sharding_spec, op): + """ + Returns whether or not the ShardingSpec has a custom op implementation. + """ + class_name = type(sharding_spec).__qualname__ + return ( + class_name in _CUSTOM_SHARDING_SPEC_OPS + and op in _CUSTOM_SHARDING_SPEC_OPS[class_name] + ) + + +def _dispatch_custom_op( + sharding_spec, op: Callable, types, args, kwargs, process_group +): + """ + Calls the custom op for this ShardingSpec if it exists. + """ + class_name = type(sharding_spec).__qualname__ + if not _has_custom_op(sharding_spec, op): + raise RuntimeError(f"Custom op: {op} not registered for {class_name}") + func = _CUSTOM_SHARDING_SPEC_OPS[class_name][op] + return func(types, args, kwargs, process_group) + + +def custom_sharding_spec_op(sharding_spec_class, func): + """ + Decorator to allow custom registration of ops. + Args: + sharding_spec_class(type): The ShardingSpec for which we need to add this custom op. + func(Callable): The op to override (ex: core.bmm) + """ + class_name = sharding_spec_class.__qualname__ + if class_name not in _CUSTOM_SHARDING_SPEC_OPS: + _CUSTOM_SHARDING_SPEC_OPS[class_name] = {} + return functools.partial( + _decorator_func, op=func, op_table=_CUSTOM_SHARDING_SPEC_OPS[class_name] + ) + + +@dataclass +class EnumerableShardingSpec(ShardingSpec): + """ + This is a type of PlacementSpec that allows users to specify a generic + sharding scheme by enumerating exactly how each shard is laid out. + + Args: + shards(List[ShardMetadata]): List of :class:`ShardMetadata` objects representing + each shard. Note that none of the shards should overlap. + """ + + shards: List[ShardMetadata] + + def __post_init__(self): + if len(self.shards) == 0: + raise ValueError(f"Empty shard list provided: {self.shards}") + + # Validate each shard has same rank. + rank = -1 + for shard in self.shards: + if rank != -1 and rank != len(shard.shard_offsets): + raise ValueError( + f"Found inconsistent ranks for shards: {rank} and {len(shard.shard_offsets)}" + ) + rank = len(shard.shard_offsets) + + validate_non_overlapping_shards_metadata(self.shards) + + def build_metadata( + self, + tensor_sizes: core.Size, + tensor_properties: sharded_tensor_meta.TensorProperties, + ) -> sharded_tensor_meta.ShardedTensorMetadata: + # check if shards form a valid tensor + check_tensor(self.shards, tensor_sizes) + return sharded_tensor_meta.ShardedTensorMetadata( + self.shards, tensor_sizes, tensor_properties + ) + + def shard( + self, tensor: core.Tensor, src_rank: int = 0, process_group=None + ) -> "ShardedTensor": + # TODO: figure out a generic and efficient way to scatter the shards for EnumerableShardingSpec + raise NotImplementedError("EnumerableShardingSpec.shard not implemented yet!") + + +def _infer_sharding_spec_from_shards_metadata(shards_metadata): + """ + Infer the sharding spec from the metadata of each shard of a ShardedTensor. + If the tensor is sharded only on one dimension, we can then verify whether it's + a ChunkShardingSpec or not. The way to verify it is to first get the total length + and perform a chunk sharding with the given placements to see if we can have the + same chunk size as the given shards_metadata. If not, we assume it's enum sharded. + + Args: + shards_metadata (List[ShardMetadata]): List of Metadata of local shards. + + Returns: + A :class:`core.distributed._shard.sharding_spec.ShardingSpec` object of sharding + spec for one sharded tensor. + """ + placements = [] + chunk_sharding_dim = None + chunk_offset_list = [] + shard_size_list = [] + shard_offset_list = [] + # collect local shard metadatas from the global sharded_tensor_metadata + for shard_metadata in shards_metadata: # type: ignore[attr-defined] + placements.append(shard_metadata.placement) + local_offsets = shard_metadata.shard_offsets + chunk_offset_list.append(sum(local_offsets)) + shard_size_list.append(shard_metadata.shard_sizes) + shard_offset_list.append(shard_metadata.shard_offsets) + shard_dims = [idx for idx, e in enumerate(local_offsets) if e != 0] + # If the offset is [0, 0, ..., 0] (all zeros), + # we cannot decide whether how the tensor is sharded. + if len(shard_dims) == 0: + continue + # If the offset is [0, N, .,0, M, 0, .., 0], + # we are sure it's sharded by more than one dimension. + if len(shard_dims) != 1: + chunk_sharding_dim = None + break + # If the offset is [0, 0, .,0, M, 0, .., 0], aka, it's sharded by just + # one dimension, we need to make sure all ranks share the same dimension. + if not chunk_sharding_dim: + chunk_sharding_dim = shard_dims[0] + elif chunk_sharding_dim != shard_dims[0]: + chunk_sharding_dim = None + break + + if chunk_sharding_dim is not None: + # Ensure we infer the correct placement order from offsets + placements = [ + x + for _, x in sorted( + zip(chunk_offset_list, placements), key=operator.itemgetter(0) + ) + ] + + from .chunk_sharding_spec import ChunkShardingSpec + + chunk_spec = ChunkShardingSpec( + dim=chunk_sharding_dim, + placements=placements, + ) + + shard_sizes = sorted([x[chunk_sharding_dim] for x in shard_size_list]) + shard_total_length = sum(shard_sizes) + shard_offsets = sorted([x[chunk_sharding_dim] for x in shard_offset_list]) + + chunks = len(placements) + split_size = get_split_size(shard_total_length, chunks) + chunk_shard_sizes = sorted( + [ + get_chunked_dim_size(shard_total_length, split_size, idx) + for idx in range(chunks) + ] + ) + # Should match ChunkShardingSpec offsets calculation + chunk_shard_offsets = [split_size * idx for idx in range(chunks)] + if shard_sizes == chunk_shard_sizes and shard_offsets == chunk_shard_offsets: + return chunk_spec + return EnumerableShardingSpec(shards_metadata) diff --git a/mindnlp/core/distributed/_shard/sharding_spec/chunk_sharding_spec.py b/mindnlp/core/distributed/_shard/sharding_spec/chunk_sharding_spec.py new file mode 100644 index 000000000..15d92d285 --- /dev/null +++ b/mindnlp/core/distributed/_shard/sharding_spec/chunk_sharding_spec.py @@ -0,0 +1,226 @@ +# mypy: allow-untyped-defs +from dataclasses import dataclass +from typing import cast, List, Optional, TYPE_CHECKING, Union + +from mindnlp import core +from mindnlp import core.distributed as dist +from mindnlp import core.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta +from mindnlp import core.distributed.distributed_c10d as distributed_c10d +from core.distributed._shard._utils import narrow_tensor +from core.distributed._shard.metadata import ShardMetadata +from core.distributed._shard.sharded_tensor.shard import Shard +from core.distributed._shard.sharded_tensor.utils import ( + _parse_and_validate_remote_device, +) + +from ._internals import get_chunked_dim_size, get_split_size +from .api import ShardingSpec + + +if TYPE_CHECKING: + # Only include ShardedTensor when do type checking, exclude it + # from run-time to resolve circular dependency. + from core.distributed._shard.sharded_tensor import ShardedTensor + + +@dataclass +class ChunkShardingSpec(ShardingSpec): + """ + This is a type of PlacementSpec that defines the placement as being sharded + across multiple devices. In particular, it represents sharding a Tensor + along a single dimension into equal chunks (similar to :meth:`core.chunk`). + + The semantics of how a tensor is partitioned is inline with + :meth:`core.chunk`, where ``dim`` in core.chunk corresponds to the + specified ``dim`` and ``chunks`` in core.chunk is the number of elements + in the placement specified. + + Args: + dim (int or str): + The dimension to shard on, could be an integer representing the + dimension or a string in case of named tensors where dimensions are + named. Note that named tensor support is not added yet. + placement(List[Union[_remote_device, str]]): + Specifies the placement of each shard of the Tensor. The size of + the list represents the number of shards to be created. This could + be a list of + :class:`core.distributed._remote_device`'s. This list + could also contain a string which represents remote + device as accepted by + :class:`core.distributed._remote_device` + """ + + ShardingDim = Union[int, str] + + dim: ShardingDim + placements: List[Union[core.distributed._remote_device, str]] + + def __post_init__(self): + self._verify_dim(self.dim) + for i, remote_device in enumerate(self.placements): + if not isinstance(remote_device, core.distributed._remote_device): + self.placements[i] = core.distributed._remote_device(remote_device) + + @staticmethod + def _verify_dim(dim): + # Validate the sharding spec. + # TODO: support named dimension + if isinstance(dim, str): + raise NotImplementedError( + "ChunkShardingSpec does not support named dimension yet!" + ) + + if not isinstance(dim, int): + raise ValueError(f"Sharding dim needs to be an integer, found: {dim}") + + def build_metadata( + self, + tensor_sizes: core.Size, + tensor_properties: sharded_tensor_meta.TensorProperties, + ) -> sharded_tensor_meta.ShardedTensorMetadata: + tensor_num_dim = len(tensor_sizes) + + self._verify_dim(self.dim) + if self.dim >= tensor_num_dim or self.dim < -tensor_num_dim: # type: ignore[operator] + raise ValueError(f"Invalid sharding dim: {self.dim}") + + shards_metadata = [] + sharding_dim_size = tensor_sizes[self.dim] # type: ignore[index] + chunks = len(self.placements) + split_size = get_split_size(sharding_dim_size, chunks) + for idx, placement in enumerate(self.placements): + # generate ShardMetadata for each placement device + chunked_dim_size = get_chunked_dim_size(sharding_dim_size, split_size, idx) + shard_size = list(tensor_sizes) + current_offsets = [0] * tensor_num_dim + current_offsets[self.dim] = split_size * idx # type: ignore[index] + shard_size[self.dim] = chunked_dim_size # type: ignore[index] + + shard_metadata = ShardMetadata( + shard_offsets=current_offsets, + shard_sizes=shard_size, + placement=placement, + ) + shards_metadata.append(shard_metadata) + + return sharded_tensor_meta.ShardedTensorMetadata( + shards_metadata, tensor_sizes, tensor_properties + ) + + def shard( + self, tensor: core.Tensor, src_rank: int = 0, process_group=None + ) -> "ShardedTensor": + """ + Args: + src_rank: group rank relative to ``process_group`` + + N.B. If ``process_group`` is None, ``src_rank`` is a global rank. + """ + # relative imports to avoid circular dependency + from core.distributed._shard.sharded_tensor import ShardedTensor + + tensor_properties = sharded_tensor_meta.TensorProperties( + dtype=tensor.dtype, + layout=tensor.layout, + requires_grad=tensor.requires_grad, + memory_format=core.contiguous_format, + pin_memory=tensor.is_pinned(), + ) + current_rank = dist.get_rank(process_group) + current_global_rank = dist.get_rank() + tensor_meta = self.build_metadata(tensor.size(), tensor_properties) + local_shards = [] + local_tensor = None + local_metadata = None + + tensors_to_scatter = cast( + List[Optional[core.Tensor]], + [None] * dist.get_world_size(process_group), + ) + + sharding_dim_size = tensor.size()[self.dim] # type: ignore[index] + chunks = len(self.placements) + split_size = get_split_size(sharding_dim_size, chunks) + scatter_shape = list(tensor.size()) + scatter_shape[self.dim] = split_size # type: ignore[index] + + for shard_meta in tensor_meta.shards_metadata: + remote_global_rank, device = _parse_and_validate_remote_device( + process_group, shard_meta.placement + ) + if current_rank == src_rank: + # Reshape to get shard for this rank and we don't want autograd + # recording here for the narrow op and 'local_shard' should be a + # leaf variable in the autograd graph. + narrowed_tensor = narrow_tensor(tensor, shard_meta) + if shard_meta.shard_sizes[self.dim] < split_size: # type: ignore[index] + # for the last shard that might be smaller to other shards + # resize the narrowed tensor to the same size and use it for + # the scatter collective as dist.scatter requires same size + # inputs on every rank + tensor_to_scatter = ( + narrowed_tensor.detach().clone().resize_(scatter_shape) + ) + else: + tensor_to_scatter = narrowed_tensor.detach().clone().contiguous() + + tensors_to_scatter[ + dist.get_group_rank(process_group, remote_global_rank) + ] = tensor_to_scatter + + if current_global_rank == remote_global_rank: + local_tensor = core.empty( + scatter_shape, + dtype=tensor.dtype, + layout=tensor.layout, + device=device, + ) + local_metadata = shard_meta + + # each rank should have local_tensor and local_metadata initialized if we build + # the metadata list in a correct way. + assert local_tensor is not None + assert local_metadata is not None + + # Scatter the shards to all ranks in the pg + # scatter takes the global rank as ``src`` + src_for_scatter = src_rank + if ( + process_group is not None + and process_group is not distributed_c10d._get_default_group() + ): + src_for_scatter = distributed_c10d.get_global_rank( + process_group, src_for_scatter + ) + + tensors_to_scatter_: Optional[List[core.Tensor]] = None + if current_rank == src_rank: + tensors_to_scatter_ = [] + for t in tensors_to_scatter: + assert isinstance(t, core.Tensor) + tensors_to_scatter_.append(t) + + dist.scatter( + local_tensor, + scatter_list=tensors_to_scatter_, + src=src_for_scatter, + group=process_group, + ) + + if list(local_tensor.size()) != local_metadata.shard_sizes: + # detach again after receiving to ensure local shards remain a leaf node + local_tensor = local_tensor.resize_(local_metadata.shard_sizes).detach() + + # Sync requires_grad to local_shard. + local_tensor.requires_grad = tensor.requires_grad + + local_shards.append(Shard(tensor=local_tensor, metadata=local_metadata)) + + st = ShardedTensor._init_from_local_shards_and_global_metadata( + local_shards, tensor_meta, process_group=process_group + ) + + # Manually set sharding_spec + st._sharding_spec = self + + return st diff --git a/mindnlp/core/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__init__.py b/mindnlp/core/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/mindnlp/core/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py b/mindnlp/core/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py new file mode 100644 index 000000000..c3039e2b5 --- /dev/null +++ b/mindnlp/core/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py @@ -0,0 +1,348 @@ +# mypy: allow-untyped-defs + +from mindnlp import core +from mindnlp import core.distributed as dist +from core.distributed._shard.sharded_tensor import ShardedTensor +from core.distributed._shard.sharded_tensor._ops._common import _sharded_op_common +from core.distributed._shard.sharding_spec import ChunkShardingSpec +from core.distributed._shard.sharding_spec._internals import ( + get_chunk_sharding_params, + get_chunked_dim_size, + get_split_size, +) +from core.distributed._shard.sharding_spec.api import custom_sharding_spec_op +from core.distributed.nn.functional import ( + _all_gather_base, + all_reduce, + all_to_all_single, +) + + +def _chunk_sharding_spec_check(spec, op): + """ + For the given op implementation check if the sharding spec is ChunkShardingSpec. + """ + if not isinstance(spec, ChunkShardingSpec): + raise NotImplementedError( + f"Only ChunkShardingSpec supported for '{op.__name__}'." + ) + + +def _register_sharded_op_on_local_tensor( + op, early_stop_func=None, extra_check=None, customized_func=None +): + """ + Handles ``__torch_function__`` dispatch for ops which are performed on + the single local tensor of the sharded tensor such as op like + ``core.nn.functional.softmax`` or ``core.Tensor.view``. + + For more complicated ops, a customized func can be used to generate + the new local tensor, sharding spec and sharded tensor size. + + Args: + op: The op to be registered and applied to all shards of the st. + early_stop_func (Callable, optional): the func for early stop. + Default: if ``None``, no early stop. + extra_check (Callable, optional): the func for extra condition check. + Default: if ``None``, no extra check. + customized_func (Callable, optional): the func for customized logic + to generate the new local tensor, sharding spec and sharded tensor size. + Default: if ``None``, we simply lower to the real op call with + the single local tensor of the st. + + Return: + func (Callable): registered implementation for sharded op for + ``__torch_function__`` dispatch. + """ + + @custom_sharding_spec_op(ChunkShardingSpec, op) + @_sharded_op_common(op, early_stop_func, extra_check) + def sharded_tensor_op_on_local_tensor(types, args=(), kwargs=None, pg=None): + st = args[0] + sharding_spec = st.sharding_spec() + if len(st.local_shards()) != 1: + raise TypeError( + f"torch function '{op.__name__}', with args: {args} and " + f"kwargs: {kwargs} only supported for single local tensor!" + ) + st_size = st.size() + if customized_func: + local_tensor, sharding_spec, st_size = customized_func(args, kwargs, pg) + else: + args = (st.local_tensor(), *args[1:]) + local_tensor = op(*args, **kwargs) + return ShardedTensor._init_from_local_tensor( + local_tensor.contiguous(), + sharding_spec, + st_size, # type: ignore[arg-type] + process_group=pg, + init_rrefs=st._init_rrefs, + ) + + +def _handle_col_wise_sharding_base( + op_func, + col_dim, + input, + world_size, + weight, + local_shard, + pg, + gathered_inputs, + mode=None, + gathered_per_sample_weights=None, + gathered_offsets=None, + padding_idx=None, +): + """ + For col-wise sharding of weight, lots of logic are common. + So we extract the common logic and put in this function: + Step 1. To get input from each rank and + Step 2. To perform the op on the concatenated tensor. + Step 3. To distribute results to each rank with col rearrangement. + Step 4. To concatenate all results from all ranks. + + Args: + op_func: operator which is applied to the input tensor. + col_dim: dim of result tensor after the operation. + input: tensor to be applied op on. + world_size: number of ranks. + weight: sharded weight tensor. + local_shard: col-wise sharded weight tensor. + pg: process group. + gathered_inputs: list of inputs from all ranks. If specified, we + don't need to communicate with each rank any more. + mode: aggregation mode of EmbeddingBag. + gathered_per_sample_weights: per_sample_weights across all ranks. + gathered_offsets: offsets across all ranks. + padding_idx: If specified, the entries at padding_idx do + not contribute to the gradient; therefore, the embedding + vector at padding_idx is not updated during training, + i.e. it remains as a fixed "pad". + Note that the embedding vector at padding_idx is + excluded from the reduction. + + Return: final result of input being applied with the op. + """ + # run the operator's function for all the inputs. + results = [] + for i, inp in enumerate(gathered_inputs): + if op_func == core.nn.functional.embedding_bag: + result = op_func( + inp, + local_shard, + offsets=gathered_offsets[i] if gathered_offsets is not None else None, + mode=mode, + per_sample_weights=gathered_per_sample_weights[i] + if gathered_per_sample_weights is not None + else None, + padding_idx=padding_idx, + ) + elif op_func == core.nn.functional.embedding: + result = op_func( + inp, + local_shard, + padding_idx=padding_idx, + ) + else: + result = op_func(inp, local_shard) + results.append(core.transpose(result, 0, col_dim)) + + # Distribute results to each rank with col rearrangement. + output = _result_distribute_with_col_rearrange( + results, input, world_size, weight, pg + ) + + # transpose the output and return result. + return core.transpose(output, 0, col_dim) + + +def _result_distribute_with_col_rearrange(results, input, world_size, weight, pg): + """ + For col-wise sharding of weight, we need to distribute + results to each rank. We do them in this function. + Note that, if the index in the Sharding Spec is not equal to + the rank number, we need to do the rearrangement based on the + order given by the Sharding Spec (placement). + + Args: + results: results from ops applied to inputs from all ranks. + We need to distribute them back to their original ranks. + input: tensor to be applied op to. + world_size: number of ranks. + weight: sharded weight tensor. + pg: process group. + + Return: column rearranged result. + """ + # Process results and outputs for all2all. + sharding_dim = weight._sharding_spec.dim + sharding_dim_size = weight.size(sharding_dim) + dims = list(results[0].size()) + dims[0] = sharding_dim_size + combined_results = core.cat(results) + output = core.empty( + *dims, device=combined_results.device, dtype=combined_results.dtype + ) + + # Compute output splits + split_size = get_split_size(sharding_dim_size, world_size) + output_split_sizes = [0] * world_size + for idx, placement in enumerate(weight._sharding_spec.placements): + output_split_sizes[placement.rank()] = get_chunked_dim_size( + sharding_dim_size, split_size, idx + ) + + # distribute the outputs using all2all. + output = all_to_all_single( + output, combined_results, output_split_sizes=output_split_sizes, group=pg + ) + + # Check if we need to rearrange columns appropriately for output. + rearrange_columns = any( + idx != placement.rank() + for idx, placement in enumerate(weight._sharding_spec.placements) + ) + if not rearrange_columns: + return output + + indices = [] + for placement in weight._sharding_spec.placements: + dim_size = output_split_sizes[placement.rank()] + start = sum( + split_size if i < placement.rank() else 0 + for i, split_size in enumerate(output_split_sizes) + ) + indices += list(range(start, start + dim_size)) + + return output.index_select(0, core.tensor(indices, device=output.device)) + + +def _handle_max_norm_col_wise( + max_norm, + norm_type, + local_shard, + input, + world_size, + gathered_inputs, + pg, +): + """ + For col-wise sharding of weight, we need to aggregate the + norm across all ranks before we can perform the proper re-norm. + Note that, the max_norm logic is only applied to the embedding + indices that are looked up and not the whole shard. + + Args: + max_norm: If given, each embedding vector with norm larger + than max_norm is renormalized to have norm max_norm. + Note: this will modify weight in-place. + norm_type: The p in the p-norm to compute for the max_norm option. + local_shard: col-wise shared local weight used for lookup. + input: tensor to be applied op to. + world_size: number of ranks. + gathered_inputs: list of inputs from all ranks. + pg: process group. + + Return: + local_shard_norm_renormed: local_shard re-normed to max_norm if the norm is larger + than it. + + """ + norm_type = norm_type if norm_type is not None else 2.0 + unique_inp = core.unique(core.cat(gathered_inputs)) + local_shard_sum = core.sum( + core.pow(core.abs(local_shard), norm_type), dim=1, dtype=local_shard.dtype + ) + # For col-wise sharding, we need to first aggregate the powered sum + # from each rank first and then calculate the norm. + local_shard_sum = all_reduce(local_shard_sum, group=pg) + local_shard_norm = core.pow(local_shard_sum, 1.0 / norm_type) + max_norm_tensor = core.full( + (local_shard.size(0),), + float("inf"), + dtype=local_shard.dtype, + device=input.device, + ) + max_norm_tensor[unique_inp] = max_norm + local_shard_t = local_shard.t().contiguous() + normalized_tensor = core.where( + local_shard_norm > max_norm_tensor, max_norm_tensor, local_shard_norm + ) + # Make sure divisor is not zero. + local_shard_norm[local_shard_norm == 0.0] = 1.0 + local_shard_norm_renormed = ( + core.div(core.mul(local_shard_t, normalized_tensor), local_shard_norm) + .t() + .contiguous() + ) + return local_shard_norm_renormed + + +def _all_gather_base_input(input, pg): + """ + Use _all_gather_base to get a concatenated input from each rank. + + Args: + input: tensor to be applied op on. + pg: process group. + + Returns: + gathered_inputs: input gathered from each rank and concat by dim 0. + """ + # allgather the inputs first. + gather_inp_size = list(input.size()) + gather_inp_size[0] = input.size(0) * dist.get_world_size(pg) + gather_inp = core.empty(gather_inp_size, device=input.device, dtype=input.dtype) + return _all_gather_base(gather_inp, input, group=pg) + + +def _handle_row_wise_mask(gather_inp, padding_idx, weight, world_size, rank): + """ + Mask the input for embedding look-up for IDs which are not stored + on the current rank. This function also adjust the ``padding_idx`` + so that it is only used on the rank where the corresponding row is + stored. + + Note that, with ``max_norm`` flag on, only weights of rows being + looked up will be re-normed. So we need an extra row for masked ID + so that it does not affect the final result and ``max_norm``. + + Args: + gather_inp: tensor to be applied op on gathered from all ranks. + padding_idx: If specified, the entries at padding_idx do + not contribute to the gradient; therefore, the embedding + vector at padding_idx is not updated during training, + i.e. it remains as a fixed "pad". + Note that the embedding vector at padding_idx is + excluded from the reduction. + weight: weight tensor of Embedding look-up table. + world_size: number of ranks. + rank: # of cuda process. + + Returns: + lookup_input: Tensor of masked input. + padding_idx: adjusted padding_idx. + padding_row: The extra row we used during lookup so that + looking up does not affect ``max_norm``. + """ + (start_pos, chunk_size) = get_chunk_sharding_params( + weight.size(0), world_size, weight._sharding_spec, rank + ) + mask = (gather_inp < start_pos) | (gather_inp >= start_pos + chunk_size) + lookup_input = gather_inp.clone() - start_pos + lookup_input[mask] = chunk_size + if ( + padding_idx is not None + and padding_idx >= start_pos + and padding_idx < (start_pos + chunk_size) + ): + padding_idx = padding_idx - start_pos + else: + padding_idx = None + + # When max_norm is set, it will only re-norm the row being looked up. + padding_row = core.zeros( + 1, weight.size(1), device=gather_inp.device, dtype=weight.dtype + ) + return lookup_input, padding_idx, padding_row diff --git a/mindnlp/core/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py b/mindnlp/core/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py new file mode 100644 index 000000000..c5d53e47d --- /dev/null +++ b/mindnlp/core/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py @@ -0,0 +1,294 @@ +# mypy: allow-untyped-defs + +from mindnlp import core +from mindnlp import core.distributed as dist +from core.distributed._shard.sharded_tensor import ShardedTensor +from core.distributed._shard.sharding_spec import ChunkShardingSpec +from core.distributed._shard.sharding_spec.api import custom_sharding_spec_op +from core.distributed.nn.functional import all_gather, reduce_scatter + +from ._common import ( + _all_gather_base_input, + _handle_col_wise_sharding_base, + _handle_max_norm_col_wise, + _handle_row_wise_mask, +) + + +@custom_sharding_spec_op(ChunkShardingSpec, core.nn.functional.embedding) +def sharded_embedding(types, args, kwargs, pg): + """ + Handles ``__torch_function__`` dispatch for ``core.nn.functional.embedding``. + This method computes a sharded embedding lookup and has the following limitations: + + 1. Supports only sharding of ``weight``. + 2. Supports only ``ChunkShardingSpec``. + 3. Supports only a single local shard per rank. + 4. Supports all specs except for scale_grad_by_freq, sparse, etc. + + Based on the dimension that the weight is sharded on, there are two + algorithms: + + ROWWISE SHARDING + ================ + For row-wise sharding the weight is sharded on dimension 0. + + The overall algorithm can be best explained with an example. Let's assume + the dims for input are (4 x 6) and W are (10 x 17) and W is sharded across + 4 GPUs creating 3 shard of (3 x 17) and 1 shard of (1 x 17). + The algorithm is as follows: + + 1. First the input is all gathered to all ranks, since this is SPMD and + input is actually sharded across all ranks. The inputs then become a + 4 (4 x 6) tensor on each rank. For example if the given input is + tensor([[6, 5, 2, 9, 6, 3], + [3, 1, 2, 4, 7, 6], + [4, 0, 4, 9, 8, 9], + [8, 6, 6, 4, 6, 1]]) + on rank 0. + Then on every rank, we will have this tensor. + If input itself is already replicated, no all-gather will be done. + 2. Next, we mask the ID which are not stored on that rank. + For example on rank 0, we store ID [0, 1, 2]. We only keep the ID + inside the set of numbers. The rest of them will be masked to an extra row. + The masked matrix will be used for embedding look up and is like: + tensor([[4, 4, 2, 4, 4, 4], + [4, 1, 2, 4, 4, 4], + [4, 0, 4, 4, 4, 4], + [4, 4, 4, 4, 4, 1]]) + The reason of having an extra row (aka, number 4 in the example) is + because when max_norm is specified only weight which has looked will + be re-normed so mask IDs whose embeddings are not stored in current + rank will to an extra row will ensure max_norm still works as expected. + 3. If max_norm is specified, the extra row guarantees that the mask ID will + not affect the behavior of weigh re-norm. + + COLWISE SHARDING + ================ + For col-wise sharding the weight is sharded on dimension 1. + + The overall algorithm can be best explained with an example. Let's assume + the dims for input are (4 x 6) and W are (16 x 17) and W is sharded across + 4 GPUs creating 3 shards of (16 x 5) and 1 shard of (16 x 2). + The algorithm is as follows: + + 1. First the input is broadcasted to all ranks, since this is SPMD we + actually do an all_gather for all the inputs resulting in 4 (4 x 6) + inputs on each rank. + 2. Next we perform local embedding lookup operation by apply each + input (4 x 6) with the local shard (16 x 5) ((16 x 2) for the last). + This results in 4 (5 x 6 x 4) ((2 x 6 x 4) for the last) matrices + on each rank. We transpose dim 0 and dim 2. + 3. Next, we concat these 4 matrices and perform an all2all to share the + appropriate (5 x 6 x 4) or (2 x 6 x 4) matrices to each rank. + 4. Now, each rank receives a (17 x 6 x 4) matrix which is basically the + size of the result we need. + 5. If placements are not in order any appropriate rearrangement of columns + are done for the (17 x 6 x 4) matrix and finally we transpose the + dim 0 and dim 2 again. + 6. If max_norm is specified, we manually sum up the norm and renorm. Because + the renorm must be in place, we need to override the local_shard to mimic + this behavior. + """ + # Validate input params + _validate_embedding_param(args, kwargs) + + input = args[0] + weight = args[1] + max_norm = kwargs.get("max_norm") + norm_type = kwargs.get("norm_type") + padding_idx = kwargs.get("padding_idx") + + local_shard = weight.local_tensor().contiguous() + sharding_dim = weight._sharding_spec.dim + world_size = dist.get_world_size(pg) + rank = dist.get_rank(pg) + + if sharding_dim == 1: + output, local_shard = _handle_col_wise_sharding( + input, world_size, weight, local_shard, max_norm, norm_type, padding_idx, pg + ) + weight.local_shards()[0].tensor = local_shard + return output + elif sharding_dim == 0: + return _handle_row_wise_sharding( + input, + world_size, + weight, + local_shard, + max_norm, + norm_type, + padding_idx, + rank, + pg, + ) + else: + raise RuntimeError( + f"nn.Embedding weight sharded on dim {sharding_dim} not supported!" + ) + + +def _validate_embedding_param(args, kwargs): + """ + Validate input params of sharded embedding op. + + Args: + input: list of ID used for lookup. + weight: sharded weight tensor. + kwargs: same as normal Embedding. + + Return: None. + """ + + input = args[0] + weight = args[1] + max_norm = kwargs.get("max_norm") + scale_grad_by_freq = kwargs.get("scale_grad_by_freq") + sparse = kwargs.get("sparse") + + # Validate types + if not isinstance(input, core.Tensor): + raise TypeError("input need to be core.Tensor") + if not isinstance(weight, ShardedTensor): + raise TypeError("weight needs to be ShardedTensor") + weight_size = weight.size() + if len(weight_size) != 2: + raise ValueError("Weight needs to have exactly 2 dims") + if int(core.min(input).item()) < 0: + raise ValueError( + "Index out of range in Input %d %d", + int(core.min(input).item()), + weight_size[1], + ) + if int(core.max(input).item()) >= weight_size[0]: + raise ValueError( + "Index out of range in Input %d %d", + int(core.max(input).item()), + weight_size[1], + ) + if scale_grad_by_freq: + raise RuntimeError( + 'nn.Embedding weight sharded with flag on "scale_grad_by_freq" not supported!' + ) + if sparse: + raise RuntimeError( + 'nn.Embedding weight sharded with flag on "sparse" not supported!' + ) + if max_norm and max_norm <= 0.0: + raise ValueError('"max_norm" must be larger than zero!') + + if not isinstance(weight._sharding_spec, ChunkShardingSpec): + raise ValueError("Only ChunkShardingSpec supported for ShardedTensor ops!") + if len(weight.local_shards()) != 1: + raise ValueError("Only one local shard supported!") + + +def _handle_col_wise_sharding( + input, world_size, weight, local_shard, max_norm, norm_type, padding_idx, pg +): + """ + Entry-point function to handle the logic of col-wise sharding of weight + for embedding. (Detailed explanations of the logic can be found in + the comment for sharded_embedding.) + + Args: + input: list of ID used for lookup and aggregation. + world_size: number of ranks. + weight: sharded weight tensor. + local_shard: col-wise shared local weight used for lookup. + max_norm: If given, each embedding vector with norm larger + than max_norm is renormalized to have norm max_norm. + Note: this will modify weight in-place. + norm_type: The p in the p-norm to compute for the max_norm option. + padding_idx: If specified, the entries at padding_idx do + not contribute to the gradient; therefore, the embedding + vector at padding_idx is not updated during training, + i.e. it remains as a fixed "pad". + pg: process group. + + Returns: final result of lookup. + """ + # allgather the inputs first for non Replicated Tensor. + gathered_inputs = all_gather(input, group=pg) + + if max_norm is not None: + # max_norm changes the weight in-place + local_shard = _handle_max_norm_col_wise( + max_norm, norm_type, local_shard, input, world_size, gathered_inputs, pg + ) + + output = _handle_col_wise_sharding_base( + core.nn.functional.embedding, + len(input.size()), + input, + world_size, + weight, + local_shard, + pg, + gathered_inputs, + padding_idx=padding_idx, + ) + return (output, local_shard) + + +def _handle_row_wise_sharding( + input, world_size, weight, local_shard, max_norm, norm_type, padding_idx, rank, pg +): + """ + Entry-point function to handle the logic of row-wise sharding of weight + for embedding. (Detailed explanations of the logic can be found in + the comment for sharded_embedding.) + + Args: + input: list of ID used for lookup and aggregation. + world_size: number of ranks. + weight: sharded weight tensor. + local_shard: row-wise shared local weight used for lookup. + max_norm: If given, each embedding vector with norm larger + than max_norm is renormalized to have norm max_norm. + Note: this will modify weight in-place. + norm_type: The p in the p-norm to compute for the max_norm option. + padding_idx: If specified, the entries at padding_idx do + not contribute to the gradient; therefore, the embedding + vector at padding_idx is not updated during training, + i.e. it remains as a fixed "pad". + rank: # of cuda process. + pg: process group. + + Returns: final result of lookup. + """ + # allgather the inputs first for non Replicated Tensor. + gather_inp = _all_gather_base_input(input, pg) + + # Mask the input according to sharding spec. + lookup_input, padding_idx, padding_row = _handle_row_wise_mask( + gather_inp, padding_idx, weight, world_size, rank + ) + + # When input is a large tensor, the value of weight is changed. + # This is a walk-around for now. GH issue: #81717 + if max_norm is not None: + core.nn.functional.embedding( + core.unique(lookup_input)[:-1], + local_shard, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + ) + max_norm = None + + local_input_embeddings = core.nn.functional.embedding( + lookup_input, + core.cat([local_shard, padding_row]), + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + ) + + # TODO: Make the result a PartialTensor. + local_shards = local_input_embeddings.chunk(pg.size()) + return reduce_scatter( + core.empty_like(local_shards[0]), + list(local_shards), + group=pg, + ) diff --git a/mindnlp/core/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py b/mindnlp/core/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py new file mode 100644 index 000000000..3032fb240 --- /dev/null +++ b/mindnlp/core/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py @@ -0,0 +1,477 @@ +# mypy: allow-untyped-defs + +from typing import cast, List + +from mindnlp import core +from mindnlp import core.distributed as dist +from core.distributed._shard.sharded_tensor import ShardedTensor +from core.distributed._shard.sharding_spec import ChunkShardingSpec +from core.distributed._shard.sharding_spec.api import custom_sharding_spec_op +from core.distributed.nn.functional import all_gather, reduce_scatter +from ....c10d import ReduceOp + +from ._common import ( + _all_gather_base_input, + _handle_col_wise_sharding_base, + _handle_max_norm_col_wise, + _handle_row_wise_mask, +) + + +# @custom_sharding_spec_op(ChunkShardingSpec, core.nn.functional.embedding_bag) +def sharded_embedding_bag(types, args, kwargs, pg): + """ + Handles ``__torch_function__`` dispatch for ``core.nn.functional.embedding_bag``. + This method computes a sharded embedding bag aggregation and has the following limitations: + + 1. Supports only sharding of ``weight``. + 2. Supports only ``ChunkShardingSpec``. + 3. Supports only a single local shard per rank. + 4. Supports all specs except for scale_grad_by_freq, sparse, etc. + + Based on the dimension that the weight is sharded on, there are two + algorithms: + + ROWWISE SHARDING + ================ + For row-wise sharding the weight is sharded on dimension 0. + + The overall algorithm can be best explained with an example. Let's assume + the dims for input are (4 x 6) and W are (16 x 17) and W is sharded across + 4 GPUs creating 4 shard of (4 x 17). + The algorithm is as follows: + + 1. First the input is all gathered to all ranks, since this is SPMD and + input is actually sharded across all ranks. The inputs then become a + 4 (4 x 6) tensor on each rank. For example if the given input is + tensor([[6, 5, 2, 9, 6, 3], + [3, 1, 2, 4, 7, 6], + [4, 0, 4, 9, 8, 9], + [8, 6, 6, 4, 6, 1]]) + on rank 0. + Then on every rank, we will have this tensor. + If input itself is already replicated, no all-gather will be done. + 2. Next, we mask the ID which are not stored on that rank. + For example on rank 0, we store ID [0, 1, 2]. We only keep the ID + inside the set of numbers. The rest of them will be masked to an extra row. + The masked matrix will be used for embedding look up and is like: + tensor([[4, 4, 2, 4, 4, 4], + [4, 1, 2, 4, 4, 4], + [4, 0, 4, 4, 4, 4], + [4, 4, 4, 4, 4, 1]]) + 3. If ``max_norm`` is specified, the extra row guarantees that the mask ID will + not affect the behavior of weigh re-norm. + 4. The example above only happens in one rank and each rank does a very similar thing. + For "Mean" mode we need to divide by either column size (2D) or the interval length + defined by the offset (excluding the row specified in ``padding_idx``). + We also need to mask the unexisting row to neg Inf so that negative value does not + gets wiped out in the "Max" mode. + + COLWISE SHARDING + ================ + For col-wise sharding the weight is sharded on dimension 1. + + The overall algorithm can be best explained with an example. Let's assume + the dims for input are (4 x 6) and W are (16 x 17) and W is sharded across + 4 GPUs creating 3 shards of (16 x 5) and 1 shard of (16 x 2). + The algorithm is as follows: + + 1. First the input is broadcasted to all ranks, since this is SPMD we + actually do an all_gather for all the inputs resulting in 4 (4 x 6) + inputs on each rank. + 2. Next we perform local embedding bag operation under the given mode by + apply each input (4 x 6) with the local shard (16 x 5) ((16 x 2) for the last). + This results in 4 (5 x 4) ((2 x 4) for the last) matrices on each rank. + We transpose the aggregation result. + 3. Next, we concatenate these 4 matrices and perform an all2all to share the + appropriate (5 x 4) or (2 x 4) matrices to each rank. + 4. Now, each rank receives a (17 x 4) matrix which is basically the + size of the result we need. + 5. If placements are not in order any appropriate rearrangement of columns + are done for the (17 x 4) matrix and finally we transpose the output again. + 6. If max_norm is specified, we manually sum up the norm and renorm. Because + the renorm must be in place, we need to override the local_shard to mimic + this behavior. + """ + # Validate input params + _validate_embedding_bag_param(args, kwargs) + + input = args[0] + weight = args[1] + offsets = kwargs.get("offsets") + per_sample_weights = kwargs.get("per_sample_weights") + mode = kwargs.get("mode") + max_norm = kwargs.get("max_norm") + norm_type = kwargs.get("norm_type") + include_last_offset = kwargs.get("include_last_offset") + padding_idx = kwargs.get("padding_idx") + + local_shard = weight.local_tensor().contiguous() + sharding_dim = weight._sharding_spec.dim + world_size = dist.get_world_size(pg) + rank = dist.get_rank(pg) + if include_last_offset: + offsets = offsets[:-1] + + if sharding_dim == 1: + output, local_shard = _handle_col_wise_sharding( + input, + world_size, + weight, + local_shard, + offsets, + per_sample_weights, + mode, + max_norm, + norm_type, + padding_idx, + pg, + ) + weight.local_shards()[0].tensor = local_shard + return output + elif sharding_dim == 0: + return _handle_row_wise_sharding( + input, + world_size, + weight, + local_shard, + offsets, + per_sample_weights, + mode, + max_norm, + norm_type, + padding_idx, + rank, + pg, + ) + else: + raise RuntimeError( + f"nn.EmbeddingBag weight sharded on dim {sharding_dim} not supported!" + ) + + +def _validate_embedding_bag_param(args, kwargs): + """ + Validate input params of sharded embeddingBag op. + + Args: + input: list of ID used for lookup and aggregation. + weight: sharded weight tensor. + kwargs: same as normal EmbeddingBag. + + Return: None. + """ + + input = args[0] + weight = args[1] + offsets = kwargs.get("offsets") + per_sample_weights = kwargs.get("per_sample_weights") + mode = kwargs.get("mode") + max_norm = kwargs.get("max_norm") + scale_grad_by_freq = kwargs.get("scale_grad_by_freq") + sparse = kwargs.get("sparse") + include_last_offset = kwargs.get("include_last_offset") + + # Validate types + if not isinstance(input, core.Tensor): + raise TypeError("input need to be core.Tensor") + if offsets is not None and not isinstance(offsets, core.Tensor): + raise TypeError("offsets need to be core.Tensor") + if per_sample_weights is not None and not isinstance( + per_sample_weights, core.Tensor + ): + raise TypeError("per_sample_weights need to be core.Tensor") + if not isinstance(weight, ShardedTensor): + raise TypeError("weight needs to be ShardedTensor") + if len(input.size()) > 2: + raise ValueError("Input more than 2 dims not supported") + weight_size = weight.size() + if len(weight_size) != 2: + raise ValueError("Weight needs to have exactly 2 dims") + if int(core.min(input).item()) < 0: + raise ValueError( + "Index out of range in Input %d %d", + int(core.min(input).item()), + weight_size[1], + ) + if int(core.max(input).item()) >= weight_size[0]: + raise ValueError( + "Index out of range in Input %d %d", + int(core.max(input).item()), + weight_size[1], + ) + if offsets is not None and len(input.size()) != 1: + raise ValueError("Input dimension needs to be exactly 1 dim") + if len(input.size()) == 1 and offsets is None: + raise ValueError("offsets is required for 1D input") + if per_sample_weights is not None and per_sample_weights.size() != input.size(): + raise ValueError( + f"per_sample_weights size {per_sample_weights.size()} not equal to input size {input.size()}" + ) + if mode is None: + mode = "mean" + if mode not in ["sum", "mean", "max"]: + raise ValueError(f"mode '{mode}' is not supported") + if scale_grad_by_freq: + raise RuntimeError( + 'nn.Embedding weight sharded with flag on "scale_grad_by_freq" not supported!' + ) + if sparse: + raise RuntimeError( + 'nn.Embedding weight sharded with flag on "sparse" not supported!' + ) + if include_last_offset and offsets is None: + raise ValueError('offsets is required for flag "include_last_offset"!') + if include_last_offset and cast(List[int], offsets)[-1] != input.size(0): + raise ValueError( + 'offsets need to have the input size in the end when the flag "include_last_offset" is on!' + ) + + if max_norm and max_norm <= 0.0: + raise ValueError('"max_norm" must be larger than zero!') + + if not isinstance(weight._sharding_spec, ChunkShardingSpec): + raise ValueError("Only ChunkShardingSpec supported for ShardedTensor ops!") + if len(weight.local_shards()) != 1: + raise ValueError("Only one local shard supported!") + + +def _handle_col_wise_sharding( + input, + world_size, + weight, + local_shard, + offsets, + per_sample_weights, + mode, + max_norm, + norm_type, + padding_idx, + pg, +): + """ + Entry-point function to handle the logic of col-wise sharding of weight + for embeddingBag. (Detailed explanations of the logic can be found in + the comment for sharded_embedding_bag.) + + Args: + input: list of ID used for lookup and aggregation. + world_size: number of ranks. + weight: sharded weight tensor. + local_shard: col-wise shared local weight used for lookup. + offsets: list of start positions of each bag for 1D input. + per_sample_weights: weights for weighted sum mode. + mode: aggregation method of each bag. + max_norm: If given, each embedding vector with norm larger + than max_norm is renormalized to have norm max_norm. + Note: this will modify weight in-place. + norm_type: The p in the p-norm to compute for the max_norm option. + padding_idx: If specified, the entries at padding_idx do + not contribute to the gradient; therefore, the embedding + vector at padding_idx is not updated during training, + i.e. it remains as a fixed "pad". + Note that the embedding vector at padding_idx is + excluded from the reduction. + pg: process group. + + Return: + output: final result of lookup and aggregation. + local_shard: col-wise shared local weight used for lookup. + If max_norm, this will be the renormed weight. + """ + # allgather the special input of embedding bag first. + ( + gathered_inputs, + gathered_per_sample_weights, + gathered_offsets, + ) = _all_gather_embedding_bag_input(input, per_sample_weights, offsets, pg) + + if max_norm is not None: + # max_norm changes the weight in-place + local_shard = _handle_max_norm_col_wise( + max_norm, norm_type, local_shard, input, world_size, gathered_inputs, pg + ) + + output = _handle_col_wise_sharding_base( + core.nn.functional.embedding_bag, + 1, + input, + world_size, + weight, + local_shard, + pg, + gathered_inputs, + mode=mode, + gathered_per_sample_weights=gathered_per_sample_weights, + gathered_offsets=gathered_offsets, + padding_idx=padding_idx, + ) + return (output, local_shard) + + +def _handle_row_wise_sharding( + input, + world_size, + weight, + local_shard, + offsets, + per_sample_weights, + mode, + max_norm, + norm_type, + padding_idx, + rank, + pg, +): + """ + Entry-point function to handle the logic of row-wise sharding of weight + for embeddingBag. (Detailed explanations of the logic can be found in + the comment for sharded_embedding_bag.) + + Args: + input: list of ID used for lookup and aggregation. + world_size: number of ranks. + weight: sharded weight tensor. + local_shard: row-wise shared local weight used for lookup. + offsets: list of start positions of each bag for 1D input. + per_sample_weights: weights for weighted sum mode. + mode: aggregation method of each bag. + max_norm: If given, each embedding vector with norm larger + than max_norm is renormalized to have norm max_norm. + Note: this will modify weight in-place. + norm_type: The p in the p-norm to compute for the max_norm option. + padding_idx: If specified, the entries at padding_idx do + not contribute to the gradient; therefore, the embedding + vector at padding_idx is not updated during training, + i.e. it remains as a fixed "pad". + Note that the embedding vector at padding_idx is + excluded from the reduction. + rank: # of cuda process. + pg: process group. + + Returns: + gathered_output: final result of lookup and aggregation. + """ + if input.dim() > 1 and per_sample_weights is None: + # allgather the inputs first for non Replicated Tensor. + gather_inp = _all_gather_base_input(input, pg) + else: + ( + gathered_inputs, + gathered_per_sample_weights, + gathered_offsets, + ) = _all_gather_embedding_bag_input(input, per_sample_weights, offsets, pg) + cat_dim = 0 if input.dim() != 1 else -1 + gather_inp = core.cat(gathered_inputs, dim=cat_dim) + if per_sample_weights is not None: + per_sample_weights = core.cat(gathered_per_sample_weights, dim=cat_dim) + offset_add = 0 if input.dim() > 1 else input.size(0) + if offsets is not None: + offsets_list = core.cat( + [gathered_offsets[i] + (offset_add * i) for i in range(pg.size())], + dim=cat_dim, + ) + + # Mask the input according to sharding spec. + lookup_input, padding_local, padding_row = _handle_row_wise_mask( + gather_inp, padding_idx, weight, world_size, rank + ) + if mode == "max": + padding_row[:] = -float("Inf") + + # When input is a large tensor, the value of weight is changed. + # This is a walk-around for now. GH issue: #81717. + if max_norm is not None: + core.nn.functional.embedding_bag( + core.unique(lookup_input)[:-1], + local_shard, + offsets=core.tensor([0], device=local_shard.device, dtype=core.long), + mode=mode, + per_sample_weights=None, + max_norm=max_norm, + norm_type=norm_type, + padding_idx=padding_local, + ) + max_norm = None + result = core.nn.functional.embedding_bag( + lookup_input, + core.cat([local_shard, padding_row]), + offsets=offsets_list if offsets is not None else offsets, # type: ignore[possibly-undefined] + mode=mode if mode != "mean" else "sum", + per_sample_weights=per_sample_weights, + max_norm=max_norm, + norm_type=norm_type, + padding_idx=padding_local, + ) + + op = ReduceOp.SUM if mode != "max" else ReduceOp.MAX + # TODO: Make the result a PartialTensor and move the logic below there. + local_shards = result.chunk(pg.size()) + result = reduce_scatter( + core.empty_like(local_shards[0]), + list(local_shards), + op=op, + group=pg, + ) + + # For Mean, we cannot do the division until very end because the sum of means + # not equal to the mean of sum. (Divisor is different) + if mode == "mean": + if input.dim() > 1: + padding_idx = padding_idx if padding_idx is not None else -1 + split_sizes = core.sum( + core.ne(input, padding_idx), dim=-1, dtype=local_shard.dtype + ) + else: + split_sizes = core.cat( + ( + offsets[1 : offsets.size(0)] - offsets[0:-1], + (input.size(0) - offsets[-1]).unsqueeze(0), + ), + dim=-1, + ) + return core.div(result, split_sizes.unsqueeze(1)) + + # Return the appropriate local result. + return result + + +def _all_gather_embedding_bag_input(input, per_sample_weights, offsets, pg): + """ + In case we need to gather input and all other parameters of embeddingBag + ops, we need to stack all input together to perform ``all_gather`` + collective communication just once. + + Note that since offsets does not share the same size as input and + is always smaller than input, we resize it during the communication. + + Args: + input: tensor to be applied op on. + per_sample_weights: weights for weighted sum mode. + offsets: when input is 1D. offsets determines the starting + index position of each bag (sequence) in input. + pg: process group. + + Returns: + gathered_inputs: list of input tensor gathered from each rank. + gathered_per_sample_weights: list of per_sample_weights from each rank. + gathered_offsets: list of offsets from each rank. + """ + input_to_gather = [input] + if per_sample_weights is not None: + input_to_gather.append(per_sample_weights) + if offsets is not None: + input_to_gather.append(offsets.clone().resize_(input.size())) + gathered_inputs = all_gather(core.stack(input_to_gather), group=pg) + + gathered_per_sample_weights = None + if per_sample_weights is not None: + gathered_per_sample_weights = [t[1] for t in gathered_inputs] + gathered_offsets = None + if offsets is not None: + idx = 2 if per_sample_weights is not None else 1 + gathered_offsets = [ + t[idx].resize_(offsets.size()).to(offsets.dtype) for t in gathered_inputs + ] + gathered_inputs = [t[0].to(input.dtype) for t in gathered_inputs] + return gathered_inputs, gathered_per_sample_weights, gathered_offsets diff --git a/mindnlp/core/distributed/_sharded_tensor/__init__.py b/mindnlp/core/distributed/_sharded_tensor/__init__.py new file mode 100644 index 000000000..bf4c967a5 --- /dev/null +++ b/mindnlp/core/distributed/_sharded_tensor/__init__.py @@ -0,0 +1,21 @@ +# Keep old package for BC purposes, this file should be removed once +# everything moves to the `core.distributed._shard` package. +import sys +import warnings + +from mindnlp import core +from core.distributed._shard.sharded_tensor import * # noqa: F403 + + +with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "`core.distributed._sharded_tensor` will be deprecated, " + "use `core.distributed._shard.sharded_tensor` instead", + DeprecationWarning, + stacklevel=2, + ) + +sys.modules[ + "core.distributed._sharded_tensor" +] = core.distributed._shard.sharded_tensor diff --git a/mindnlp/core/distributed/_sharding_spec/__init__.py b/mindnlp/core/distributed/_sharding_spec/__init__.py new file mode 100644 index 000000000..4ab86dc32 --- /dev/null +++ b/mindnlp/core/distributed/_sharding_spec/__init__.py @@ -0,0 +1,22 @@ +# Keep old package for BC purposes, this file should be removed once +# everything moves to the `core.distributed._shard` package. +import sys +import warnings + +from mindnlp import core +from core.distributed._shard.sharding_spec import * # noqa: F403 + + +with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "`core.distributed._sharding_spec` will be deprecated, " + "use `core.distributed._shard.sharding_spec` instead", + DeprecationWarning, + stacklevel=2, + ) + +from mindnlp import core.distributed._shard.sharding_spec as _sharding_spec + + +sys.modules["core.distributed._sharding_spec"] = _sharding_spec diff --git a/mindnlp/core/distributed/_state_dict_utils.py b/mindnlp/core/distributed/_state_dict_utils.py new file mode 100644 index 000000000..8fd5279c8 --- /dev/null +++ b/mindnlp/core/distributed/_state_dict_utils.py @@ -0,0 +1,756 @@ +# mypy: allow-untyped-defs +import copy +import io +import math +import weakref +from typing import ( + Any, + Callable, + cast, + Dict, + List, + Mapping, + MutableMapping, + NamedTuple, + Optional, + Tuple, + TYPE_CHECKING, + Union, +) + +from mindnlp import core +from mindnlp import core.distributed as dist +from mindnlp import core.nn.functional as F +from core.distributed._functional_collectives import AsyncCollectiveTensor + + +if dist.is_available() or TYPE_CHECKING: + from core.distributed import distributed_c10d + from core.distributed._shard.sharded_tensor import ShardedTensor + # from core.distributed.tensor import distribute_tensor, DTensor, Replicate + # from core.distributed.tensor._utils import compute_local_shape_and_global_offset + + +def _identity_func( + obj: core.Tensor, + pg: Optional[dist.ProcessGroup], + device: Optional[core.device], + companion_obj: Any, +) -> core.Tensor: + return obj + + +def _all_gather_sharded_tensor( + sharded_tensor: "ShardedTensor", + pg: Optional[dist.ProcessGroup] = None, + device: Optional[core.device] = None, +) -> core.Tensor: + if pg is None: + pg = distributed_c10d._get_default_group() + world_size = dist.get_world_size(pg) + shards = sharded_tensor.local_shards() + dim_0_size = sharded_tensor.size()[0] # type: ignore[index] + tensor_numel = sharded_tensor.size().numel() # type: ignore[union-attr] + chunk_size = math.ceil(dim_0_size / world_size) * tensor_numel // dim_0_size + pg_device = ( + distributed_c10d._get_pg_default_device(pg) if device is None else device + ) + if shards: + local_tensor = shards[0].tensor.flatten() + if local_tensor.device.type != pg_device.type: + local_tensor = local_tensor.to(pg_device) + num_padding = chunk_size - local_tensor.numel() + if num_padding > 0: + local_tensor = F.pad(local_tensor, [0, num_padding]) + else: + local_tensor = core.zeros( + chunk_size, dtype=sharded_tensor.dtype, device=pg_device + ) + + tensor = core.empty( + chunk_size * world_size, + dtype=local_tensor.dtype, + device=pg_device, + ) + dist.all_gather_into_tensor(tensor, local_tensor, group=pg) + + tensor = tensor.narrow(0, 0, tensor_numel).reshape(sharded_tensor.size()) + return tensor + + +class CompanionMismatch(Exception): + ... + + +def _iterate_state_dict( + iter_object: Any, + sharded_tensor_func: Callable, + dtensor_func: Callable, + tensor_func: Callable, + *, + pg: Optional[dist.ProcessGroup] = None, + device: Optional[core.device] = None, + cpu_offload: bool = False, + companion_obj: Any = None, + ranks_only: Tuple[int, ...] = (), + type_check: bool = True, + non_blocking: bool = True, +) -> Dict[str, Any]: + """Iterate through the state dict, applying the given functions to each tensor type. + + Args: + iter_object (Any): the target state_dict. + sharded_tensor_func (Callable): the function to apply to ShardedTensor + dtensor_func (Callable): the function to apply to DTensor + tensor_func (Callable): the function to apply to Tensor + pg (Optional[dist.ProcessGroup]): process group passed to tensor functions + device (Optional[core.device]): device passed to tensor functions + cpu_offload (bool): whether to offload the tensors to CPU memory. This option is ignored + if a companion_obj is supplied. + companion_obj (Any): A companion object to the state dict. If this object + is supplied, we attempt to copy the tensor to the companion object. + ranks_only (Tuple[int, ...]): if this tuple is empty, all ranks will + have the same state_dicts. Otherwise only ranks that in ``ranks_only`` + have the same state_dicts. Other ranks will get empty state_dicts. + type_check (bool): check if the instance data type is a supported type + that can be saved by DCP. The current supported data types are + core.Tensor, DTensor, int, float, str, list, dict, None. + non_blocking (bool): whether to use non-blocking copy when copying to the companion object. + """ + # TODO: should we use pytree? + cpu_device = core.device("cpu") + if isinstance(iter_object, ShardedTensor): + ret = sharded_tensor_func(iter_object, pg, device, companion_obj) + elif isinstance(iter_object, DTensor): + ret = dtensor_func(iter_object, pg, device, companion_obj) + elif isinstance(iter_object, core.Tensor): + ret = tensor_func(iter_object, pg, device, companion_obj) + elif ( + isinstance(iter_object, (int, float, str, bytes, io.BytesIO)) + or iter_object is None + ): + ret = iter_object + elif isinstance(iter_object, dict): + if companion_obj is not None and ( + not isinstance(companion_obj, dict) + or set(companion_obj.keys()) != set(iter_object.keys()) + ): + msg = ( + "" + if isinstance(companion_obj, dict) + else f"{set(companion_obj.keys())=} {set(iter_object.keys())=}" + ) + raise CompanionMismatch(msg) + + ret = { + key: _iterate_state_dict( + value, + sharded_tensor_func, + dtensor_func, + tensor_func, + pg=pg, + device=device, + cpu_offload=cpu_offload, + companion_obj=companion_obj[key] if companion_obj is not None else None, + ranks_only=ranks_only, + type_check=type_check, + non_blocking=non_blocking, + ) + for key, value in iter_object.items() + } + elif isinstance(iter_object, (list, tuple)): + if companion_obj is not None and ( + not isinstance(companion_obj, (list, tuple)) + or len(companion_obj) != len(iter_object) + ): + raise CompanionMismatch + + ret = [ + _iterate_state_dict( + v, + sharded_tensor_func, + dtensor_func, + tensor_func, + pg=pg, + device=device, + cpu_offload=cpu_offload, + companion_obj=companion_obj[idx] if companion_obj is not None else None, + ranks_only=ranks_only, + type_check=type_check, + non_blocking=non_blocking, + ) + for idx, v in enumerate(iter_object) + ] + if isinstance(iter_object, tuple): + ret = tuple(ret) + elif not type_check: + ret = copy.deepcopy(iter_object) + else: + raise ValueError(f"Unexpected value type {type(iter_object)}") + + if not ranks_only or dist.get_rank(pg) in ranks_only: + if isinstance(ret, core.Tensor): + if cpu_offload and companion_obj is None: + ret = ret.to(cpu_device) + + if companion_obj is not None: + # TODO: support DTensor + companion_obj.copy_(ret, non_blocking=non_blocking) + ret = companion_obj + else: + ret = {} if isinstance(ret, dict) else None + + return ret + + +def _gather_state_dict( + state_dict: Dict[str, Any], + *, + pg: Optional[dist.ProcessGroup] = None, + device: Optional[core.device] = None, + cpu_offload: bool = False, + ranks_only: Tuple[int, ...] = (), + type_check: bool = True, +) -> Dict[str, Any]: + """ + Given a state_dict, this API gathers all the ShardedTensors or DTensors in + the state_dict. + + + Args: + state_dict (Dict[str, Any]): the target sharded state_dict. + pg (Optional[dist.ProcessGroup]): the process group that is used to + gather ShardedTensor. Note that gathering a DTensor will use + the DeviceMesh. So this argument will be ignored when gathering a + DTensor. + device: (Optional[core.device]): the device that is used to + perform allgather for ShardedTensor. Note that gathering a DTensor + will use the DeviceMesh. So this argument will be ignored when + gathering a DTensor. + cpu_offload (bool): whether to offload the tensors to CPU memory. The + default value is False. + ranks_only: (Tuple[int, ...]): if this tuple is empty, all ranks will + have the same state_dicts. Otherwise only ranks that in ``ranks_only`` + have the same state_dicts. Other ranks will get empty state_dicts. + type_check: (bool): check if the instance data type is a supported type + that can be saved by DCP. The current supported data types are + core.Tensor, DTensor, int, float, str, list, dict, None. + + Returns: + The gathered state dictionary. + """ + + def sharded_tensor_func(value, pg, device, companion_obj): + # ShardedTensor does not seem to record the original device type. + # So if the tensor is moved to CPU, we won't know the original type. + # As a result, we have to rely on the user to tell us the correct one. + cpu_device = core.device("cpu") + output_tensor = _all_gather_sharded_tensor(value, pg, device) + local_shard_device = ( + value.local_shards()[0].tensor.device + if value.local_shards() + else cpu_device + ) + if output_tensor.device != local_shard_device: + value = output_tensor.to(local_shard_device) + else: + value = output_tensor + return value + + def dtensor_func(value, pg, device, companion_obj): + if value.device != value.device_mesh.device_type: + value = value.to(value.device_mesh.device_type) + # FSDP all_gather: [Shard(0)] -> [Replicate()] + # HSDP all_gather: [Replicate(), Shard(0)] -> [Replicate(), Replicate()] + # 2D FSDP + TP all_gather: + # - [Shard(0), Shard(n)] -> [Replicate(), Replicate()] + # - [Shard(0), Replicate()] -> [Replicate(), Replicate()] + placements = [Replicate() for _ in value.placements] + value = value.redistribute( + device_mesh=value.device_mesh, + placements=placements, + ) + # Call `wait()` to force the tensor to be synchronous with respect + # to the main stream. + # See the discussion in https://github.com/pytorch/pytorch/pull/117799. + value = value.to_local() + if isinstance(value, AsyncCollectiveTensor): + value = value.wait() + return value + + return _iterate_state_dict( + state_dict, + sharded_tensor_func, + dtensor_func, + _identity_func, + pg=pg, + device=device, + cpu_offload=cpu_offload, + ranks_only=ranks_only, + type_check=type_check, + ) + + +def _offload_state_dict_to_cpu( + state_dict: Dict[str, Any], + *, + ranks_only: Tuple[int, ...] = (), + type_check: bool = True, +) -> Dict[str, Any]: + """ + Given a state_dict, this API offload all the tensors to CPU memory. + + Args: + state_dict (Dict[str, Any]): the target state_dict. + pg (Optional[dist.ProcessGroup]): the process group that is used to + gather ShardedTensor. Note that gathering a DTensor will use + the DeviceMesh. So this argument will be ignored when gathering a + DTensor. + ranks_only: (Tuple[int, ...]): if this tuple is empty, all ranks will + have the same state_dicts. Otherwise only ranks that in ``ranks_only`` + have the same state_dicts. Other ranks will get empty state_dicts. + type_check: (bool): check if the instance data type is a supported type + that can be saved by DCP. The current supported data types are + core.Tensor, DTensor, int, float, str, list, dict, None. + + Returns: + The gathered state dictionary. + """ + + ret = _iterate_state_dict( + state_dict, + _identity_func, + _identity_func, + _identity_func, + pg=None, + device=None, + cpu_offload=True, + ranks_only=ranks_only, + type_check=type_check, + ) + return ret + + +def _copy_state_dict( + state_dict: Dict[str, Any], + copy_state_dict: Dict[str, Any], + non_blocking: bool = False, + type_check: bool = True, +) -> Dict[str, Any]: + """ + Copies all tensors in a given state dict into a different state_dict with the + same structure. Additionally, a copied state dict with the same value references + is returned. Editing the keys on this state dict will not affect the + passed in copy_state_dict (but the value references are the same). + + .. warning:: + It is expected by this function that state_dict and copy_state_dict share + the same structure and data types. + + .. warning:: + The current supported data types are + core.Tensor, DTensor, int, float, str, list, dict, None. + + Args: + state_dict (Dict[str, Any]): the target state_dict. + copy_state_dict (Dict[str, Any]): + The state dict we are copying into. This state_dict must have exactly + the same structure as the source `state_dict`. + non_blocking: (bool): Whether copy ops should be performed asynchronously + type_check (bool): check if the instance data type is a supported type + that can be saved by DCP. The current supported data types are + core.Tensor, DTensor, int, float, str, list, dict, None. + + Returns: + State Dict copy + """ + + return _iterate_state_dict( + state_dict, + _identity_func, + _identity_func, + _identity_func, + pg=None, + device=None, + cpu_offload=False, + ranks_only=(), + companion_obj=copy_state_dict, + type_check=type_check, + non_blocking=non_blocking, + ) + + +def _create_cpu_state_dict( + state_dict: Dict[str, Any], pin_memory: bool = False, share_memory: bool = False +) -> Dict[str, Any]: + """ + Given a state_dict, create another state_dict with the same structure and elements. + However, all tensors in the returned state_dict are new tensors on CPU. These + tensors can be placed on pin_memory or share_memory based on the provided arguments. + + .. warning:: + Setting both `pin_memory` and `share_memory` to True significantly increases the + latency of this method because of the nuances which require us to register memory + as pinned directly as opposed to relying on the pin_memory cache allocator. This + option should only be used for long lived tensors which are required to be shared. + This is not the case as long as at least one of `pin_memory` or `share_memory` is + set to False. + + """ + + def tensor_func( + obj: core.Tensor, + pg: Optional[dist.ProcessGroup], + device: Optional[core.device], + _: Any, + ) -> core.Tensor: + if len(obj.size()) == 0: + return core.tensor(0, dtype=obj.dtype) + + if share_memory: + t = core.empty(*tuple(obj.size()), dtype=obj.dtype) + t = t.share_memory_() + if pin_memory: + + def unpin_memory(t): + succ = int(core.cuda.cudart().cudaHostUnregister(t.data_ptr())) + assert ( + succ == 0 + ), f"Unpinning shared memory failed with error-code: {succ}" + + weakref.finalize(t, unpin_memory, t) + succ = int( + core.cuda.cudart().cudaHostRegister( + t.data_ptr(), + t.numel() * t.element_size(), + 1, # lines up with 'cudaHostRegisterPortable' + ) + ) + assert ( + succ == 0 + ), f"Pinning shared memory failed with error-code: {succ}" + return t + elif pin_memory: + return core.empty(*tuple(obj.size()), dtype=obj.dtype).pin_memory() + else: + return core.empty(*tuple(obj.size()), dtype=obj.dtype) + + ret = _iterate_state_dict( + state_dict, + _identity_func, + _identity_func, + tensor_func, + pg=None, + device=None, + cpu_offload=False, + ranks_only=(), + type_check=False, + ) + return ret + + +def _check_state_dict_similarity( + state_dict: Dict[str, Any], + compared_state_dict: Dict[str, Any], +) -> bool: + """ + Given two state_dicts, check if the structures are the same. And + if a [key, tensor] pair exist in one state_dict there must be + the a corresponding pait, [key, other_tensor], in the other state_dict, + where tensor and other_tensor have the same size and dtype. + + Return the check result. + """ + + def tensor_func( + obj: core.Tensor, + pg: Optional[dist.ProcessGroup], + device: Optional[core.device], + companion_obj: Any, + ) -> core.Tensor: + if companion_obj.dtype != obj.dtype or companion_obj.size() != obj.size(): + raise CompanionMismatch + return obj + + try: + _iterate_state_dict( + state_dict, + _identity_func, + _identity_func, + tensor_func, + pg=None, + device=None, + cpu_offload=False, + ranks_only=(), + companion_obj=compared_state_dict, + type_check=False, + ) + except CompanionMismatch: + return False + + return True + + +class _TensorInfo(NamedTuple): + size: core.Size + dtype: core.dtype + + +def _broadcast_tensors( + full_state_dict: Dict[str, Any], + local_state_dict: Dict[str, Any], + keys: List[str], + device: core.device, + pg: Optional[dist.ProcessGroup] = None, +) -> None: + tensors = [] + for key in keys: + if dist.get_rank() == 0: + full_state = full_state_dict[key] + assert isinstance(full_state, core.Tensor) + full_tensor = full_state.detach().to(device) + else: + tensor_info = full_state_dict[key] + full_tensor = core.empty( + size=tensor_info.size, + device=device, + dtype=tensor_info.dtype, + ) + + tensors.append(full_tensor) + local_state = local_state_dict.get(key, None) + if local_state is None: + continue + elif isinstance(local_state, DTensor): + local_state_dict[key] = (local_state, full_tensor) + else: + local_state_dict[key] = full_tensor + + if pg is None: + pg = dist.distributed_c10d._get_default_group() + + if len(tensors) > 1: + dist._broadcast_coalesced(pg, tensors, 500, 0) + else: + dist.broadcast(tensors[0], src=0, group=pg) + + _distribute_tensors(local_state_dict, keys, device, pg) + + +def _distribute_tensors( + local_state_dict: Dict[str, Any], + keys: List[str], + device: core.device, + pg: Optional[dist.ProcessGroup] = None, +) -> None: + if pg is None: + pg = dist.distributed_c10d._get_default_group() + for key in keys: + _local_state = local_state_dict.get(key, None) + if _local_state is None or core.is_tensor(_local_state): + continue + + local_state = _local_state[0] + full_tensor = _local_state[1] + + shape, offset = compute_local_shape_and_global_offset( + full_tensor.shape, local_state.device_mesh, local_state.placements + ) + slices = [ + slice(cur_offset, cur_offset + cur_shape) + for cur_shape, cur_offset in zip(shape, offset) + ] + local_tensor = full_tensor[slices] + # TODO: currently, we cannot handle strided sharding if the dp dimension is not even. For example, + # one of the case that is not yet supported is when placements = (Shard(0), _StridedShard(0, sf=2)). + local_state_dict[key] = DTensor.from_local( + local_tensor, + local_state.device_mesh, + local_state.placements, + shape=local_state.shape, + stride=local_state.stride(), + ) + + +def _broadcast_state_dict( + full_state_dict: Dict[str, Any], + local_state_dict: Dict[str, Any], + device: core.device, + pg: Optional[dist.ProcessGroup] = None, + strict: bool = False, +) -> None: + # Broadcast from rank0's `full_state_dict` to all ranks' `local_state_dict`. + # If strict is True, any keys in `local_state_dict` but not in `full_state_dict` + # will be removed from `local_state_dict`. + ret = {} + if dist.get_rank() == 0: + for key, value in full_state_dict.items(): + if not core.is_tensor(value): + ret[key] = value + elif value.dim() == 0: + ret[key] = value.cpu() + else: + ret[key] = _TensorInfo(value.size(), value.dtype) + + broadcast_list = [ret] + dist.broadcast_object_list(broadcast_list, src=0, group=pg) + ret = broadcast_list[0] + + # Gather values + keys = [] + local_state_dict_keys = set(local_state_dict.keys()) + global_keys = set() + for key, value in ret.items(): + global_keys.add(key) + if not isinstance(value, _TensorInfo): + if key in local_state_dict: + local_state_dict[key] = value + continue + + if dist.get_rank() == 0: + ret[key] = full_state_dict[key] + + keys.append(key) + # Broadcast every tensor to avoid OOM for now. + if len(keys) >= 1: + _broadcast_tensors(ret, local_state_dict, keys, device, pg) + keys.clear() + + if strict: + if missing_keys := (local_state_dict_keys - global_keys): + for key in missing_keys: + local_state_dict.pop(key) + + if keys: + _broadcast_tensors(ret, local_state_dict, keys, device, pg) + + +def _distribute_state_dict( + full_state_dict: Dict[str, Any], + local_state_dict: Dict[str, Any], + device: core.device, + pg: Optional[dist.ProcessGroup] = None, +) -> None: + # Full_state_dict = True, broadcast_from_rank0 = False here. Each rank has + # full_state_dict. Skip the broadcast in ``_broadcast_state_dict`` and + # distribute tensors in each rank + for key, value in full_state_dict.items(): + if key not in full_state_dict: + continue + if not core.is_tensor(value): + local_state_dict[key] = value + elif value.dim() == 0: + local_state_dict[key] = value.cpu() + else: + assert isinstance(value, core.Tensor) + local_state = local_state_dict.get(key, None) + if local_state is None: + continue + elif isinstance(local_state, DTensor): + local_state_dict[key] = distribute_tensor( + value.detach().to(device), + local_state.device_mesh, + local_state.placements, + ) + else: + local_state_dict[key] = value.detach().to(device) + + +# These APIs are from core.distributed.checkpoint. +# TODO: We should consolidate the code here as some not all modules can depend on +# DCP. +PATH_ITEM = Union[str, int] +OBJ_PATH = Tuple[PATH_ITEM, ...] +FLATTEN_MAPPING = Dict[str, OBJ_PATH] +STATE_DICT_TYPE = Dict[str, Any] +CONTAINER_TYPE = MutableMapping[PATH_ITEM, Any] + + +def _traverse_state_dict( + state_dict: STATE_DICT_TYPE, + visitor: Callable[[OBJ_PATH, Any], None], +) -> None: + """ + Invoke ``visitor`` for each value recursively in ``state_dict``. + Mapping, list, and tuple will be flattened and other value types are treated + as the terminal values and will invoke ``visitor``. + """ + + def _traverse_obj(path: OBJ_PATH, value: Any) -> None: + if isinstance(value, Mapping): + for k, v in value.items(): + _traverse_obj(path + (str(k),), v) + elif isinstance(value, (list, tuple)): + for i, v in enumerate(value): + _traverse_obj(path + (i,), v) + else: + visitor(path, value) + + for key, value in state_dict.items(): + _traverse_obj((str(key),), value) + + +def _flatten_state_dict( + state_dict: STATE_DICT_TYPE, +) -> Tuple[STATE_DICT_TYPE, FLATTEN_MAPPING]: + """ + Flatten ``state_dict`` made of nested dicts and lists into a top level dictionary. + + Use ``unflatten_state_dict`` to revert this process. + Returns: + A tuple with the flatten state_dict and a mapping from original to new state_dict. + N.B. The new keys are derived from the object paths, joined by dot. + For example: ``{ 'a': {'b':...}}`` results in the key `a.b`. + """ + flattened: STATE_DICT_TYPE = {} + mappings: FLATTEN_MAPPING = {} + + def flat_copy(path: OBJ_PATH, value: Any) -> None: + new_fqn = ".".join(map(str, path)) + if new_fqn in flattened: + raise ValueError(f"duplicated flatten key {new_fqn}") + flattened[new_fqn] = value + mappings[new_fqn] = path + + _traverse_state_dict(state_dict, flat_copy) + return flattened, mappings + + +def _set_element(root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: Any) -> None: + """Set ``value`` in ``root_dict`` along the ``path`` object path.""" + cur_container = cast(CONTAINER_TYPE, root_dict) + + def extend_list(lst: List[Any], idx: int) -> None: + while len(lst) <= idx: + lst.append(None) + + for i in range(1, len(path)): + prev_key = path[i - 1] + key = path[i] + def_val: Union[CONTAINER_TYPE, List[Any]] = {} if type(key) == str else [] + + if isinstance(cur_container, Mapping): + cur_container = cast( + CONTAINER_TYPE, cur_container.setdefault(prev_key, def_val) + ) + else: + extend_list(cur_container, prev_key) + if cur_container[prev_key] is None: + cur_container[prev_key] = def_val + cur_container = cur_container[prev_key] + + key = path[-1] + if type(key) == int: + extend_list(cast(List[Any], cur_container), key) + + cur_container[key] = value + + +def _unflatten_state_dict( + state_dict: STATE_DICT_TYPE, mapping: FLATTEN_MAPPING +) -> STATE_DICT_TYPE: + """Restore the original nested state_dict according to ``mapping`` and the flattened ``state_dict``.""" + nested: STATE_DICT_TYPE = {} + for key, value in state_dict.items(): + _set_element(nested, mapping[key], value) + return nested diff --git a/mindnlp/core/distributed/_symmetric_memory/__init__.py b/mindnlp/core/distributed/_symmetric_memory/__init__.py new file mode 100644 index 000000000..c1b3cfbe1 --- /dev/null +++ b/mindnlp/core/distributed/_symmetric_memory/__init__.py @@ -0,0 +1,1496 @@ +# mypy: allow-untyped-decorators +import socket +import uuid +from contextlib import contextmanager +from datetime import timedelta +from enum import Enum +from functools import partial +from typing import Any, Callable, Dict, Generator, List, Optional, Tuple + +from mindnlp import core +from mindnlp import core.distributed._functional_collectives as funcol +from mindnlp import core.distributed.distributed_c10d as c10d +from core._C._distributed_c10d import _SymmetricMemory, Work as _Work + + +_group_name_to_store: Dict[str, c10d.Store] = {} + + +def enable_symm_mem_for_group(group_name: str) -> None: + """ + Enables symmetric memory for a process group. + + Args: + group_name (str): the name of the process group. + """ + if group_name in _group_name_to_store: + return + + group = c10d._resolve_process_group(group_name) + global_ranks = sorted(c10d._world.pg_group_ranks[group].keys()) + # Different subgroups with the same name should use different stores + global_ranks_str = "_".join(map(str, global_ranks)) + store = c10d.PrefixStore( + f"symmetric_memory-{global_ranks_str}", + c10d._get_process_group_store(group), + ) + # Use one store-based broadcast to bootstrap a file store from the process + # and simultaneously verify that all ranks are on the same host. + hostname = socket.gethostname() + if group.rank() == 0: + uid = str(uuid.uuid4()) + msg = f"{hostname}/{uid}" + store.set("init", msg) + else: + msg = store.get("init").decode("utf-8") + tokens = msg.split("/") + assert len(tokens) == 2, tokens + rank_0_hostname, uid = tokens + if hostname != rank_0_hostname: + raise RuntimeError( + "init_symmetric_memory_for_process_group() failed for " + f'group "{group_name}". Rank 0 and rank {group.rank()} ' + f"are on different hosts ({rank_0_hostname} and {hostname})" + ) + store = core._C._distributed_c10d.FileStore(f"/tmp/{uid}", group.size()) + # TODO: check device connectiivity + _group_name_to_store[group_name] = store + _SymmetricMemory.set_group_info( + group_name, + group.rank(), + group.size(), + store, + ) + + +_is_test_mode: bool = False + + +@contextmanager +def _test_mode() -> Generator[None, None, None]: + """ + Forces ``is_symm_mem_enabled_for_group()`` to return ``True`` and the ops + defined in the ``symm_mem`` namespace to use fallback implementations. + + The context manager is not thread safe. + """ + global _is_test_mode + prev = _is_test_mode + try: + _is_test_mode = True + yield + finally: + _is_test_mode = prev + + +def is_symm_mem_enabled_for_group(group_name: str) -> bool: + """ + Check if symmetric memory is enabled for a process group. + + Args: + group_name (str): the name of the process group. + """ + return _is_test_mode or group_name in _group_name_to_store + + +_group_name_to_workspace_tensor: Dict[str, Optional[core.Tensor]] = {} + + +def get_symm_mem_workspace(group_name: str, min_size: int) -> _SymmetricMemory: + """ + Get the symmetric memory workspace associated with the process group. If + ``min_size`` is greater than the workspace associated with ``group_name``, + the workspace will be re-allocated and re-rendezvous'd. + + Args: + group_name (str): the name of the process group. + min_size (int): the size requirement for the workspace in bytes. + + Returns: + _SymmetricMemory: the symmetric memory workspace associated with the + group. + """ + enable_symm_mem_for_group(group_name) + + tensor = _group_name_to_workspace_tensor.get(group_name) + size = tensor.numel() * tensor.element_size() if tensor is not None else 0 + if tensor is None or size < min_size: + if core.cuda.is_current_stream_capturing(): + curr_size = 0 if tensor is None else tensor.numel() * tensor.element_size() + raise RuntimeError( + f"get_symm_mem_workspace(): the requested size ({min_size} bytes) " + "is greater than the size of the currently allocated workspace " + f"({curr_size} bytes). It's currently not possible to expand the " + "workspace size during graph capture. Please invoke " + f'`get_symm_mem_workspace(group_name="{group_name}", ' + f'min_size="{min_size}")` before initiating the graph capture ' + "and try again." + ) + tensor = _SymmetricMemory.empty_strided_p2p( + (max(size, min_size),), + [1], + core.uint8, + core.device(f"cuda:{core.cuda.current_device()}"), + group_name, + ) + _group_name_to_workspace_tensor[group_name] = tensor + return _SymmetricMemory.rendezvous(tensor) + + +_backend_streams: Dict[int, core.cuda.Stream] = {} + + +def _get_backend_stream(priority: int = 0) -> core.cuda.Stream: + if priority not in _backend_streams: + _backend_streams[priority] = core.cuda.Stream(priority=priority) + return _backend_streams[priority] + + +def _pipelined_multi_all_gather_and_consume( + shard: List[core.Tensor], + shard_consumer: Callable[[List[core.Tensor], int], None], + ag_out: List[core.Tensor], + group_name: str, +) -> None: + """ + Perform the following logic with micro-pipelined computation and + communication: + + gathered = [ + all_gather_tensor(x, gather_dim=0, group=group) + for x in shard + ] + + shards = [[] for _ in range(group_size)] + for x in ag_out: + for i, y in enumerate(x.chunk(group_size)): + shards[i].append(y) + + for src_rank, shard in enumerate(shards): + shard_consumer(shard, src_rank) + """ + p2p_workspace_size_req = 0 + for x in shard: + p2p_workspace_size_req += x.numel() * x.element_size() + symm_mem = get_symm_mem_workspace(group_name, min_size=p2p_workspace_size_req) + group_size = symm_mem.world_size + rank = symm_mem.rank + + symm_mem.barrier(channel=0) + backend_stream = _get_backend_stream() + backend_stream.wait_stream(core.cuda.current_stream()) + + for x, y in zip(shard, ag_out): + assert x.is_contiguous(), ( + "_pipelined_all_gather_and_consume: all tensors " + "in `shard` must be contiguous" + ) + assert y.is_contiguous(), ( + "_pipelined_all_gather_and_consume: all tensors " + "in `ag_out` must be contiguous" + ) + assert x.shape[0] * group_size == y.shape[0] + assert x.shape[1:] == y.shape[1:] + + def copy_shard(dst: List[core.Tensor], src: List[core.Tensor]) -> None: + for d, s in zip(dst, src): + d.copy_(s) + + def get_p2p_bufs(remote_rank: int) -> List[core.Tensor]: + offset_bytes = 0 + bufs = [] + for x in shard: + buf = symm_mem.get_buffer( + remote_rank, + x.shape, + x.dtype, + storage_offset=offset_bytes // x.element_size(), + ) + bufs.append(buf) + offset_bytes += buf.numel() * buf.element_size() + return bufs + + local_p2p_bufs = get_p2p_bufs(rank) + + # shards[i] => shard from rank i + shards: List[List[core.Tensor]] = [[] for _ in range(group_size)] + for x in ag_out: + for i, y in enumerate(x.chunk(group_size)): + shards[i].append(y) + + # Parallelization strategy: after each rank copies its shard into its local + # p2p buffer, every rank issues independent p2p copy -> shard_consumer + # sequences to two streams. In addition to computation/communication + # overlapping, the strategy allows for computation/computation overlapping, + # greatly reducing quantization inefficiency. + # + # Notation: + # - "mv" for the copy to local buffer + # - "cp" for p2p copies + # - "b" for barriers + # + # Constraints: + # - The GPU scheduler may or may not overlap "mv" with the first shard_consumer. + # - "cp" from different streams cannot overlap. + # + # Ideal scenario 0 - "mv" overlaps with the first shard_consumer: + # + # stream 0: [ shard_consumer ][ cp ][ shard_consumer ] + # stream 1: [ mv ][b][ cp ][ shard_consumer ] + # + # Ideal scenario 1 - "mv" is scheduled before the first shard_consumer: + # + # stream 0: [ shard_consumer ][ cp ][ shard_consumer ] + # stream 1: [ mv ][b][ cp ][ shard_consumer ] + # + # Suboptimal scenario 0 - "mv" is scheduled after the first shard_consumer: + # + # stream 0: [ shard_consumer ] [ cp ][ shard_consumer ] + # stream 1: [ mv ][b][ cp ][ shard_consumer ] + # + # Suboptimal scenario 0 - "b" is scheduled after the first shard_consumer: + # + # stream 0: [ shard_consumer ] [ cp ][ shard_consumer ] + # stream 1: [ mv ] [b][ cp ][ shard_consumer ] + # + # We haven't yet figured out a way to ensure "mv" and "b" are either + # overlapped with or scheduled before the first shard_consumer. Thus, to + # prevent suboptimal scenarios, we are giving up the chance to overlap "mv" + # and "b" with the first shard_consumer for now. + copy_shard(dst=local_p2p_bufs, src=shard) + symm_mem.barrier(channel=1) + backend_stream.wait_stream(core.cuda.current_stream()) + + # At this point, all ranks have copied their local shard to + # their local p2p buffer. Each rank can now copy and consume + # remote shards. + shard_consumer(shard, rank) + + for step in range(1, group_size): + if step % 2 == 0: + stream = core.cuda.current_stream() + else: + stream = backend_stream + remote_rank = (step + rank) % group_size + remote_p2p_bufs = get_p2p_bufs(remote_rank) + with core.cuda.stream(stream): + copy_shard(dst=shards[remote_rank], src=remote_p2p_bufs) + shard_consumer(shards[remote_rank], remote_rank) + + # Copy from input to the all-gather output. Opportunistically overlap it + # with the last shard_consumer. + if group_size % 2 == 0: + stream = core.cuda.current_stream() + else: + stream = backend_stream + with core.cuda.stream(stream): + copy_shard(dst=shards[rank], src=shard) + + core.cuda.current_stream().wait_stream(backend_stream) + symm_mem.barrier(channel=0) + + +def _pipelined_all_gather_and_consume( + shard: core.Tensor, + shard_consumer: Callable[[core.Tensor, int], None], + ag_out: core.Tensor, + group_name: str, +) -> None: + """ + Perform the following logic with micro-pipelined computation and + communication: + + ag_out = all_gather_tensor(shard, gather_dim=0, group=group) + shards = ag_out.chunk(group.size()) + for src_rank, shard in enumerate(shards): + shard_consumer(shard, src_rank) + """ + + def adapter(shard: List[core.Tensor], rank: int) -> None: + shard_consumer(shard[0], rank) + + _pipelined_multi_all_gather_and_consume( + [shard], + adapter, + [ag_out], + group_name, + ) + + +def _pipelined_produce_and_all2all( + chunk_producer: Callable[[int, core.Tensor], None], + output: core.Tensor, + group_name: str, +) -> None: + """ + Perform the following logic with micro-pipelined computation and + communication: + + chunks = [ + chunk_producer(dst_rank, chunks[dst_rank]) + for dst_rank in range(group_size): + ] + dist.all_to_all_single(output=output, input=core.cat(chunks)) + """ + out_chunks = output.chunk(c10d._get_group_size_by_name(group_name)) + p2p_workspace_size_req = out_chunks[0].numel() * out_chunks[0].element_size() * 2 + symm_mem = get_symm_mem_workspace(group_name, min_size=p2p_workspace_size_req) + group_size = symm_mem.world_size + rank = symm_mem.rank + + symm_mem.barrier(channel=0) + backend_stream = _get_backend_stream() + backend_stream.wait_stream(core.cuda.current_stream()) + + def get_p2p_buf(rank: int, idx: int) -> core.Tensor: + assert idx in (0, 1) + offset = 0 if idx == 0 else out_chunks[0].numel() + return symm_mem.get_buffer( + rank, out_chunks[0].shape, out_chunks[0].dtype, offset + ) + + # Prepare two local p2p buffers, so that a remote rank can pull the result + # of step [i] in one p2p buffer while the local rank can compute the + # result of step [i+1] and write it directly the other p2p buffer. + local_p2p_buf_0 = get_p2p_buf(rank, 0) + local_p2p_buf_1 = get_p2p_buf(rank, 1) + + for step in range(1, group_size): + remote_rank = (rank - step) % group_size + if step % 2 == 0: + stream = core.cuda.current_stream() + p2p_buf = local_p2p_buf_1 + remote_p2p_buf = get_p2p_buf(remote_rank, 1) + else: + stream = backend_stream + p2p_buf = local_p2p_buf_0 + remote_p2p_buf = get_p2p_buf(remote_rank, 0) + with core.cuda.stream(stream): + # Parallelization strategy: every rank issues independent compute + # -> barrier -> p2p copy sequences on two streams. In addition to + # computation/communication overlapping, the strategy allows for + # computation/computation overlapping, greatly reducing + # quantization inefficiency. + # + # Ideally, stream activities would look like this ("b" for + # barriers, "cp" for p2p copies): + # + # [rank 0] + # stream 0: [ chunk_producer ][b][ cp ][ chunk_producer ][b][ cp ] + # stream 1: [ chunk_producer ][b][ cp ][ chunk_producer ][b][ cp ] + # + # [rank 1] + # stream 0: [ chunk_producer ][b][ cp ][ chunk_producer ][b][ cp ] + # stream 1: [ chunk_producer ][b][ cp ][ chunk_producer ][b][ cp ] + # + # Note that the barriers synchronize streams with the same ID + # across ranks. They don't synchronize streams on the same rank. + # + # Since the work on both streams is independent, there's no + # guarantee that the chunk_producer from stream 0 or stream 1 will + # be scheduled first. If there is a scheduling mismatch across + # ranks, the barrier forces all ranks to wait for the slowest. + # + # When scheduling mismatches occur among ranks, the stream + # activities might look like this (note that p2p copies from + # different streams cannot overlap with each other): + # + # [rank 0] + # stream 0: [ chunk_producer ][b ][ cp ][ chunk_producer ][b ][ cp ] + # stream 1: [ chunk_producer ][b] [ cp ][ chunk_producer ][b] [ cp ] + # + # [rank 1] + # stream 0: [ chunk_producer ][b] [ cp ][ chunk_producer ][b] [ cp ] + # stream 1: [ chunk_producer ][b ][ cp ][ chunk_producer ][b ][ cp ] + # + # To prevent this, we need to ensure that the chunk_producer on + # stream 1 gets scheduled first on every rank. Without access to + # the underlying kernels, CUDA offers no API to control the + # scheduling order of two independent, overlapping kernels. Our + # solution is to issue a small sleep kernel in stream 0. The sleep + # duration is insignificant, but having an extra task in stream 0 + # will almost guarantee that the chunk_producer on stream 1 gets + # scheduled first. Once the first chunk_producer is scheduled in + # the correct order, there's very little room for the scheduling + # order of subsequent kernels to be inconsistent across ranks. + if step == 2: + core.cuda._sleep(100) + chunk_producer((rank + step) % group_size, p2p_buf) + symm_mem.barrier(channel=step % 2) + out_chunks[remote_rank].copy_(remote_p2p_buf) + # The local P2P buffer can only be overwritten by the next + # chunk_producer after all peers have finished reading from it. + symm_mem.barrier(channel=step % 2) + + chunk_producer(rank, out_chunks[rank]) + core.cuda.current_stream().wait_stream(backend_stream) + symm_mem.barrier(channel=0) + + +lib = core.library.Library("symm_mem", "DEF") # noqa: TOR901 +lib.define( + "fused_all_gather_matmul(Tensor A, Tensor[] Bs, int gather_dim, str group_name) -> (Tensor, Tensor[])" +) +lib.define( + "fused_all_gather_scaled_matmul(" + "Tensor A, Tensor[] Bs, Tensor A_scale, Tensor[] B_scales, " + "int gather_dim, str group_name, " + "Tensor?[] biases, " + "Tensor?[] result_scales, " + "ScalarType?[] out_dtypes, " + "bool[] use_fast_accum) -> (Tensor, Tensor[])" +) +lib.define( + "fused_matmul_reduce_scatter(Tensor A, Tensor B, str reduce_op, int scatter_dim, str group_name) -> Tensor" +) +lib.define( + "fused_scaled_matmul_reduce_scatter(" + "Tensor A, Tensor B, Tensor A_scale, Tensor B_scale, " + "str reduce_op, int scatter_dim, str group_name, " + "Tensor? bias = None, " + "Tensor? result_scale = None, " + "ScalarType? out_dtype = None, " + "bool use_fast_accum = False) -> Tensor" +) +lib.define("_low_contention_all_gather(Tensor tensor, str group_name) -> Tensor") +lib.define( + "_low_contention_reduce_scatter(Tensor tensor, str reduce_op, str group_name) -> Tensor" +) + + +class _ScaleMode(Enum): + UNSCALED = "unscaled" + TENSOR_WISE = "tensor-wise" + ROW_WISE_SHARDED = "row-wise-sharded" + ROW_WISE_REPLICATED = "row-wise-replicated" + + +def _check_and_verify_fp8_all_gather_scale_mode( + shard: core.Tensor, scale: Optional[core.Tensor], gather_dim: int, group_size: int +) -> _ScaleMode: + full_shape = list(shard.shape) + full_shape[gather_dim] *= group_size + + if scale is None: + return _ScaleMode.UNSCALED + elif scale.shape[:-1] == shard.shape[:-1] and scale.shape[-1] == 1: + # Row-wise scaling + # + # NOTE: when the last dim of both A_shard and A_scale is one, we can't + # tell if A_scale is replicated tensor-wise scale or sharded row-wise + # scale. Treating it as row-wise scaling for safety. + return _ScaleMode.ROW_WISE_SHARDED + elif scale.numel() == 1: + return _ScaleMode.TENSOR_WISE + elif list(scale.shape[:-1]) == full_shape[:-1]: + return _ScaleMode.ROW_WISE_REPLICATED + else: + raise ValueError( + "Invalid scale shape for fp8 all-gather " + f"(shard shape: {shard.shape}, scale shape: {scale.shape})" + ) + + +def _fused_all_gather_matmul_impl( + mm_out_op: core._ops.OpOverload, + A_shard: core.Tensor, + Bs: List[core.Tensor], + A_scale: Optional[core.Tensor], + kwargs_list: List[Dict[str, Any]], + out_dtypes: List[Optional[core.dtype]], + gather_dim: int, + group_name: str, +) -> Tuple[core.Tensor, List[core.Tensor]]: + if A_shard.dim() < 2: + raise ValueError("A_shard must be a matrix") + for B in Bs: + if B.dim() != 2: + raise ValueError("B must be a matrix") + if len(out_dtypes) != len(Bs): + raise ValueError("len(out_types) must be the same as len(Bs)") + if len(kwargs_list) != len(Bs): + raise ValueError("len(kwargs_list) must be the same as len(Bs)") + if gather_dim < 0 or gather_dim >= A_shard.dim(): + raise ValueError("Invalid gather_dim") + + group = c10d._resolve_process_group(group_name) + + # Move the gather_dim to the front and flatten the tensor into a 2D matrix. + # The flattened tensor doesn't need to be contiguous (for computation + # efficiency), as _pipelined_all_gather_and_consume guarantees that shards + # passed to shard_consumer are contiguous. + A_shard_flat = A_shard.movedim(gather_dim, 0) + leading_dims = [group.size()] + list(A_shard_flat.shape[:-1]) + A_shard_flat = A_shard_flat.flatten(0, -2) + + # Helper function for reverting the above transformation + def unflatten(t: core.Tensor) -> core.Tensor: + return t.view(*leading_dims, -1).flatten(0, 1).movedim(0, gather_dim) + + A_flat = A_shard_flat.new_empty( + A_shard_flat.shape[0] * group.size(), + A_shard_flat.shape[1], + ) + + outputs = [ + A_flat.new_empty(A_flat.shape[0], B.shape[1], dtype=out_dtype or B.dtype) + for B, out_dtype in zip(Bs, out_dtypes) + ] + output_shards = [output.chunk(group.size()) for output in outputs] + + scale_mode = _check_and_verify_fp8_all_gather_scale_mode( + shard=A_shard, scale=A_scale, gather_dim=gather_dim, group_size=group.size() + ) + + # Computing block-wise matmul along the first dim of A + if scale_mode == _ScaleMode.ROW_WISE_SHARDED: + assert A_scale is not None + A_scale_shard = A_scale.movedim(gather_dim, 0).flatten(0, -2) + A_scale_flat = A_scale_shard.new_empty( + A_scale_shard.shape[0] * group.size(), + A_scale_shard.shape[1], + ) + + def row_wise_sharded_consumer(shard: List[core.Tensor], rank: int) -> None: + for idx, (B, kwargs) in enumerate(zip(Bs, kwargs_list)): + mm_out_op( + shard[0], + B, + scale_a=shard[1], + **kwargs, + out=output_shards[idx][rank], + ) + + _pipelined_multi_all_gather_and_consume( + [A_shard_flat, A_scale_shard], + row_wise_sharded_consumer, + [A_flat, A_scale_flat], + group_name, + ) + elif scale_mode == _ScaleMode.ROW_WISE_REPLICATED: + assert A_scale is not None + A_scale_shards = ( + A_scale.movedim(gather_dim, 0).flatten(0, -2).chunk(group.size()) + ) + + def row_wise_replicated_consumer(shard: core.Tensor, rank: int) -> None: + for idx, (B, kwargs) in enumerate(zip(Bs, kwargs_list)): + mm_out_op( + shard, + B, + scale_a=A_scale_shards[rank], + **kwargs, + out=output_shards[idx][rank], + ) + + _pipelined_all_gather_and_consume( + A_shard_flat, + row_wise_replicated_consumer, + A_flat, + group_name, + ) + else: + if scale_mode == _ScaleMode.TENSOR_WISE: + assert A_scale is not None + for kwargs in kwargs_list: + kwargs["scale_a"] = A_scale + else: + assert scale_mode == _ScaleMode.UNSCALED + + def default_consumer(shard: core.Tensor, rank: int) -> None: + for idx, (B, kwargs) in enumerate(zip(Bs, kwargs_list)): + mm_out_op(shard, B, **kwargs, out=output_shards[idx][rank]) + + _pipelined_all_gather_and_consume( + A_shard_flat, + default_consumer, + A_flat, + group_name, + ) + + return unflatten(A_flat), [unflatten(output) for output in outputs] + + +@core.library.impl(lib, "fused_all_gather_matmul", "Meta") +def _fused_all_gather_matmul_fallback( + A_shard: core.Tensor, + Bs: List[core.Tensor], + gather_dim: int, + group_name: str, +) -> Tuple[core.Tensor, List[core.Tensor]]: + group_size = c10d._get_group_size_by_name(group_name) + A = core.ops._c10d_functional.all_gather_into_tensor( + A_shard.contiguous(), group_size, group_name + ) + A = core.ops._c10d_functional.wait_tensor(A) + A = A.view(group_size, *A_shard.shape).movedim(gather_dim + 1, 1).flatten(0, 1) + return A.movedim(0, gather_dim), [ + core.matmul(A, B).movedim(0, gather_dim) for B in Bs + ] + + +@core.library.impl(lib, "fused_all_gather_matmul", "CUDA") +def _fused_all_gather_matmul( + A_shard: core.Tensor, + Bs: List[core.Tensor], + gather_dim: int, + group_name: str, +) -> Tuple[core.Tensor, List[core.Tensor]]: + """ + Perform the following logic with micro-pipelined computation and + communication: + + all_gather_tensor(A_shard, gather_dim, group_name) @ B + + Optimal stride order for A_shard - if A_shard.movedim(gather_dim, 0) is + contiguous, no extra copy is required for input layout transformation. + Otherwise A_shard needs to be copied once. + """ + if _is_test_mode: + return _fused_all_gather_matmul_fallback(A_shard, Bs, gather_dim, group_name) + + with core.profiler.record_function("fused_all_gather_matmul"): + return _fused_all_gather_matmul_impl( + core.ops.aten.mm.out, + A_shard, + Bs, + None, + [{} for B in Bs], + [B.dtype for B in Bs], + gather_dim, + group_name, + ) + + +def _fused_all_gather_matmul_native( + A_shard: core.Tensor, + B: core.Tensor, + group_name: str, +) -> Tuple[core.Tensor, core.Tensor]: + symm_mem = _SymmetricMemory.rendezvous(A_shard) + if symm_mem is None: + symm_mem = get_symm_mem_workspace( + group_name, A_shard.numel() * A_shard.element_size() + ) + symm_mem.barrier() + buf = symm_mem.get_buffer(symm_mem.rank, A_shard.shape, A_shard.dtype) + buf.copy_(A_shard) + A_shard = buf + + rank = symm_mem.rank + world_size = symm_mem.world_size + + current_stream = core.cuda.current_stream() + backend_stream = _get_backend_stream(priority=-1) + + symm_mem.barrier() + current_stream.wait_stream(backend_stream) + backend_stream.wait_stream(current_stream) + + A = A_shard.new_empty(A_shard.shape[0] * world_size, A_shard.shape[1]) + A_signals = core.zeros(world_size, dtype=core.uint32, device=A_shard.device) + A_shards = A.chunk(world_size) + + A_shards[rank].copy_(A_shard) + _SymmetricMemory.stream_write_value32(A_signals, rank, 1) + + out = core.ops.symm_mem._async_input_mm(A, B, A_signals, rank) + for step in range(1, world_size): + src_rank = (rank + step) % world_size + src_buf = symm_mem.get_buffer(src_rank, A_shard.shape, A_shard.dtype) + with core.cuda.stream(backend_stream): + A_shards[src_rank].copy_(src_buf) + # cuStreamWriteValue32 issues a system level fence before the write + _SymmetricMemory.stream_write_value32(A_signals, src_rank, 1) + + current_stream.wait_stream(backend_stream) + backend_stream.wait_stream(current_stream) + + symm_mem.barrier() + return A, out + + +@core.library.impl(lib, "fused_all_gather_scaled_matmul", "Meta") +def _fused_all_gather_scaled_matmul_fallback( + A_shard: core.Tensor, + Bs: List[core.Tensor], + A_scale: core.Tensor, + B_scales: List[core.Tensor], + gather_dim: int, + group_name: str, + biases: List[Optional[core.Tensor]], + result_scales: List[Optional[core.Tensor]], + out_dtypes: List[Optional[core.dtype]], + use_fast_accum: List[bool], +) -> Tuple[core.Tensor, List[core.Tensor]]: + out_dtypes = _maybe_convert_scalar_types_to_dtypes(out_dtypes) + + group_size = c10d._get_group_size_by_name(group_name) + A = core.ops._c10d_functional.all_gather_into_tensor( + A_shard.contiguous(), group_size, group_name + ) + A = core.ops._c10d_functional.wait_tensor(A) + A = A.view(group_size, *A_shard.shape).movedim(gather_dim + 1, 1).flatten(0, 1) + + scale_mode = _check_and_verify_fp8_all_gather_scale_mode( + shard=A_shard, scale=A_scale, gather_dim=gather_dim, group_size=group_size + ) + if scale_mode == _ScaleMode.ROW_WISE_SHARDED: + A_scale_shard = A_scale + A_scale = core.ops._c10d_functional.all_gather_into_tensor( + A_scale.contiguous(), group_size, group_name + ) + A_scale = core.ops._c10d_functional.wait_tensor(A_scale) + A_scale = ( + A_scale.view(group_size, *A_scale_shard.shape) + .movedim(gather_dim + 1, 1) + .flatten(0, -2) + ) + elif scale_mode == _ScaleMode.ROW_WISE_REPLICATED: + A_scale = A_scale.movedim(gather_dim, 0).flatten(0, -2) + else: + assert scale_mode == _ScaleMode.TENSOR_WISE + + def scaled_matmul( + A: core.Tensor, + B: core.Tensor, + A_scale: core.Tensor, + B_scale: core.Tensor, + bias: Optional[core.Tensor], + result_scale: Optional[core.Tensor], + out_dtype: Optional[core.dtype], + use_fast_accum: bool, + ) -> core.Tensor: + leading_dims = A.shape[:-1] + res = core.ops.aten._scaled_mm( + A.flatten(0, -2), + B, + A_scale, + B_scale, + bias, + result_scale, + out_dtype=out_dtype, + use_fast_accum=use_fast_accum, + ) + return res.unflatten(0, leading_dims) + + return A.movedim(0, gather_dim), [ + scaled_matmul( + A, B, A_scale, B_scale, bias, result_scale, out_dtype, fast_accum + ).movedim(0, gather_dim) + for B, B_scale, bias, result_scale, out_dtype, fast_accum in zip( + Bs, B_scales, biases, result_scales, out_dtypes, use_fast_accum + ) + ] + + +@core.library.impl(lib, "fused_all_gather_scaled_matmul", "CUDA") +def _fused_all_gather_scaled_matmul( + A_shard: core.Tensor, + Bs: List[core.Tensor], + A_scale: core.Tensor, + B_scales: List[core.Tensor], + gather_dim: int, + group_name: str, + biases: List[Optional[core.Tensor]], + result_scales: List[Optional[core.Tensor]], + out_dtypes: List[Optional[core.dtype]], + use_fast_accum: List[bool], +) -> Tuple[core.Tensor, List[core.Tensor]]: + """ + Perform the following logic with micro-pipelined computation and + communication: + + A = all_gather_tensor(A_shard, gather_dim, group_name) + leading_dims = A.shape[:-1] + res = core.ops.aten._scaled_mm(A.flatten(0, -2), B, A_scale, B_scale) + res = res.unflatten(0, leading_dims) + + The input `A_scale` can be tensor-wise, row-wise-sharded or + row-wise-replicated. + + Optimal stride order for `A_shard` - if `A_shard.movedim(gather_dim, 0)` is + contiguous, no extra copy is required for input layout transformation. + Otherwise A_shard needs to be copied once. + """ + out_dtypes = _maybe_convert_scalar_types_to_dtypes(out_dtypes) + + if len(biases) != len(Bs): + raise ValueError("len(biases) must be the same as len(Bs)") + if len(result_scales) != len(Bs): + raise ValueError("len(result_scales) must be the same as len(Bs)") + if len(out_dtypes) != len(Bs): + raise ValueError("len(out_dtypes) must be the same as len(Bs)") + if len(use_fast_accum) != len(Bs): + raise ValueError("len(use_gast_accum_list) must be the same as len(Bs)") + + if _is_test_mode: + return _fused_all_gather_scaled_matmul_fallback( + A_shard, + Bs, + A_scale, + B_scales, + gather_dim, + group_name, + biases, + result_scales, + out_dtypes, + use_fast_accum, + ) + + with core.profiler.record_function("fused_all_gather_scaled_matmul"): + return _fused_all_gather_matmul_impl( + core.ops.aten._scaled_mm.out, + A_shard, + Bs, + A_scale, + [ + { + "scale_b": B_scale, + "bias": bias, + "scale_result": result_scale, + "out_dtype": out_dtype, + "use_fast_accum": fast_accum, + } + for B_scale, bias, result_scale, out_dtype, fast_accum in zip( + B_scales, biases, result_scales, out_dtypes, use_fast_accum + ) + ], + out_dtypes, + gather_dim, + group_name, + ) + + +def make_contiguous_for_perm( + t: core.Tensor, + perm: List[int], +) -> core.Tensor: + """ + Restride `t` such that `t.permute(perm)` is contiguous. + """ + inv_perm = [0] * len(perm) + for i, p in enumerate(perm): + inv_perm[p] = i + return t.permute(perm).contiguous().permute(inv_perm) + + +def restride_A_shard_for_fused_all_gather_matmul( + t: core.Tensor, + gather_dim: int, +) -> core.Tensor: + """ + Restride the `A_shard` arg of `fused_all_gather_matmul` for optimal perf. + See the doc for `fused_all_gather_matmul` for detail. + """ + perm = list(range(len(t.shape))) + perm.insert(0, perm.pop(gather_dim)) + return make_contiguous_for_perm(t, perm) + + +def _fused_matmul_reduce_scatter_impl( + mm_out_op: core._ops.OpOverload, + A: core.Tensor, + B: core.Tensor, + A_scale: Optional[core.Tensor], + kwargs: Dict[str, Any], + out_dtype: Optional[core.dtype], + reduce_op: str, + scatter_dim: int, + group_name: str, +) -> core.Tensor: + if A.dim() < 2: + raise ValueError("A_shard must be a matrix") + if scatter_dim < 0 or scatter_dim >= A.dim(): + raise ValueError("Invalid gather_dim") + if B.dim() != 2: + raise ValueError("B must be a matrix") + if reduce_op == "sum": + reduce_fn = partial(core.sum, dim=0) + elif reduce_op == "avg": + reduce_fn = partial(core.mean, dim=0) + else: + raise ValueError("reduce_op must be sum or avg") + + group = c10d._resolve_process_group(group_name) + out_shape = [*A.shape[:-1], B.shape[1]] + out_shape[scatter_dim] //= group.size() + + # Move the scatter_dim to the front and flatten the tensor into a 2D matrix + x = A.movedim(scatter_dim, 0) + leading_dims = [group.size()] + list(x.shape[:-1]) + leading_dims[1] //= group.size() + x = x.flatten(0, -2) + A_shards = x.chunk(group.size()) + + A_scale_shards = None + if A_scale is None: + pass + elif A_scale.numel() == 1: + A_scale_shards = [A_scale] * group.size() + else: + if A_scale.shape[:-1] != A.shape[:-1]: + raise ValueError( + "For row-wise scaling, the leading dims of A_scale " + "must match the leading dims of A " + f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})" + ) + A_scale = A_scale.movedim(scatter_dim, 0).contiguous().flatten(0, -2) + A_scale_shards = list(A_scale.chunk(group.size())) + + # Computing block-wise matmul along the first dim of A + def chunk_producer(rank: int, out: core.Tensor) -> None: + if A_scale_shards is not None: + mm_out_op( + A_shards[rank], B, scale_a=A_scale_shards[rank], **kwargs, out=out + ) + else: + mm_out_op(A_shards[rank], B, **kwargs, out=out) + + stacked_partials = x.new_empty(x.shape[0], B.shape[1], dtype=out_dtype or A.dtype) + + _pipelined_produce_and_all2all( + chunk_producer, + stacked_partials, + group_name, + ) + # Ensures that the transpose and reduction produce contiguous result + # in a single reduction kernel. + return reduce_fn( + stacked_partials.view(*leading_dims, -1) + .movedim(1, scatter_dim + 1) + .movedim(0, scatter_dim), + dim=scatter_dim, + ) + + +@core.library.impl(lib, "fused_matmul_reduce_scatter", "Meta") +def _fused_matmul_reduce_scatter_fallback( + A: core.Tensor, + B: core.Tensor, + reduce_op: str, + scatter_dim: int, + group_name: str, +) -> core.Tensor: + res = funcol.reduce_scatter_tensor(A @ B, reduce_op, scatter_dim, group_name) + res = funcol.wait_tensor(res) + return res + + +@core.library.impl(lib, "fused_matmul_reduce_scatter", "CUDA") +def _fused_matmul_reduce_scatter( + A: core.Tensor, + B: core.Tensor, + reduce_op: str, + scatter_dim: int, + group_name: str, +) -> core.Tensor: + """ + Perform the following logic with micro-pipelined computation and + communication: + + reduce_scatter_tensor(A @ B, reduce_op, scatter_dim, group_name) + + Optimal stride order for A - if A.movedim(scatter_dim, 0) is contiguous, no + extra copy is required for input layout transformation. Otherwise A needs + to be copied once. + """ + if _is_test_mode: + return _fused_matmul_reduce_scatter_fallback( + A, B, reduce_op, scatter_dim, group_name + ) + + with core.profiler.record_function("fused_matmul_reduce_scatter"): + return _fused_matmul_reduce_scatter_impl( + mm_out_op=core.ops.aten.mm.out, + A=A, + B=B, + A_scale=None, + kwargs={}, + out_dtype=A.dtype, + reduce_op=reduce_op, + scatter_dim=scatter_dim, + group_name=group_name, + ) + + +@core.library.impl(lib, "fused_scaled_matmul_reduce_scatter", "Meta") +def _fused_scaled_matmul_reduce_scatter_fallback( + A: core.Tensor, + B: core.Tensor, + A_scale: core.Tensor, + B_scale: core.Tensor, + reduce_op: str, + scatter_dim: int, + group_name: str, + bias: Optional[core.Tensor] = None, + result_scale: Optional[core.Tensor] = None, + out_dtype: Optional[core.dtype] = None, + use_fast_accum: bool = False, +) -> core.Tensor: + if A_scale.numel() > 1: + if A_scale.shape[:-1] != A.shape[:-1]: + raise ValueError( + "For row-wise scaling, the leading dims of A_scale " + "must match the leading dims of A " + f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})" + ) + A_scale = A_scale.flatten(0, -2).contiguous() + elif A_scale.numel() != 1: + raise ValueError( + "Invalid A_scale shape " + f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})" + ) + + C = core._scaled_mm( + A.flatten(0, -2).contiguous(), + B, + A_scale, + B_scale, + bias, + result_scale, + out_dtype, + use_fast_accum, + ) + C = C.view(*A.shape[:-1], B.shape[1]) + res = funcol.reduce_scatter_tensor( + C, + reduce_op, + scatter_dim, + group_name, + ) + res = funcol.wait_tensor(res) + return res + + +@core.library.impl(lib, "fused_scaled_matmul_reduce_scatter", "CUDA") +def _fused_scaled_matmul_reduce_scatter( + A: core.Tensor, + B: core.Tensor, + A_scale: core.Tensor, + B_scale: core.Tensor, + reduce_op: str, + scatter_dim: int, + group_name: str, + bias: Optional[core.Tensor] = None, + result_scale: Optional[core.Tensor] = None, + out_dtype: Optional[core.dtype] = None, + use_fast_accum: bool = False, +) -> core.Tensor: + if _is_test_mode: + return _fused_scaled_matmul_reduce_scatter_fallback( + A, + B, + A_scale, + B_scale, + reduce_op, + scatter_dim, + group_name, + bias, + result_scale, + out_dtype, + use_fast_accum, + ) + with core.profiler.record_function("fused_matmul_reduce_scatter"): + return _fused_matmul_reduce_scatter_impl( + mm_out_op=core.ops.aten._scaled_mm.out, + A=A, + B=B, + A_scale=A_scale, + kwargs={ + "scale_b": B_scale, + "bias": bias, + "scale_result": result_scale, + "out_dtype": out_dtype, + "use_fast_accum": use_fast_accum, + }, + out_dtype=out_dtype, + reduce_op=reduce_op, + scatter_dim=scatter_dim, + group_name=group_name, + ) + + +def restride_A_for_fused_matmul_reduce_scatter( + t: core.Tensor, + scatter_dim: int, +) -> core.Tensor: + """ + Restride the `A_shard` arg of `fused_matmul_reduce_scatter` for optimal + perf. See the doc for `fused_matmul_reduce_scatter` for detail. + """ + perm = list(range(len(t.shape))) + perm.insert(0, perm.pop(scatter_dim)) + return make_contiguous_for_perm(t, perm) + + +def _maybe_convert_scalar_types_to_dtypes( + scalar_types: List[Any], +) -> List[Optional[core.dtype]]: + """ + When a list of `core.dtype`s is passed through the dispatcher as + `ScalarType[]`, it is converted to a list of scalar type enum values. This + function converts it back to a list of `core.dtype`s. + """ + # Order defined in https://github.com/pytorch/pytorch/blob/344defc9733a45fee8d0c4d3f5530f631e823196/c10/core/ScalarType.h + _SCALAR_TYPE_TO_DTYPE = { + 0: core.uint8, + 1: core.int8, + 2: core.short, + 3: core.int, + 4: core.int64, + 5: core.half, + 6: core.float, + 7: core.double, + 8: core.complex32, + 9: core.complex64, + 10: core.complex128, + 11: core.bool, + 12: core.qint8, + 13: core.quint8, + 14: core.qint32, + 15: core.bfloat16, + 16: core.float8_e5m2, + 17: core.float8_e4m3fn, + 18: core.float8_e5m2fnuz, + 19: core.float8_e4m3fnuz, + } + if any(not isinstance(x, (type(None), int)) for x in scalar_types): + return scalar_types + + dtypes: List[Optional[core.dtype]] = [] + for scalar_type in scalar_types: + if scalar_type is None: + dtypes.append(scalar_type) + elif scalar_type not in _SCALAR_TYPE_TO_DTYPE: + raise ValueError("Unrecognized scalar type {scalar_type}") + else: + dtypes.append(_SCALAR_TYPE_TO_DTYPE[scalar_type]) + return dtypes + + +class Work(_Work): + def __init__(self) -> None: + super().__init__() + self.event = core.cuda.Event() + self.event.record() + + def wait(self, timeout: timedelta = timedelta(seconds=0)) -> bool: + self.event.wait() + return True + + +""" +NOTE [low-contention collectives] +When a collective is overlapped with abundant compute, it makes sense to +prioritize reducing the contention between the collective and the overlapped +compute, even at the cost of a slightly slower collective. + +Common collective implementations (e.g., NCCL without user buffer +registration) optimize for throughput with no ambient compute. However, such +implementations may not be optimal when they are overlapped with compute: +- These implementations typically fuse the entire collective into a single +kernel and reserve SM resources based on the most demanding portion of the +collective, even when a large portion of the collective does not require this +much resource. +- These implementations often use SM-based P2P copy as opposed to copy +engine-based P2P copy. Copy engine-based P2P copy may not have a significant +advantage when there's no ambient compute. However, it may significantly +improve overall resource utilization in the presence of ambient compute. + +When overlapped with intensive compute (e.g., persistent matmul kernels), the +SM-usage of a collective can lead to inefficient overlapping. + +Low-contention collectives achieve their goals with the following strategies: +- Use copy engine-based copy whenever possible. +- Break down portions of a collective with different resource requirements +into multiple kernels. This improves the overlapping efficiency at the cost +of additional launching overhead. +""" + + +@core.library.impl(lib, "_low_contention_all_gather", "Meta") +def _low_contention_all_gather_meta( + tensor: core.Tensor, + group_name: str, +) -> core.Tensor: + group_size = c10d._get_group_size_by_name(group_name) + return tensor.new_empty(tensor.shape[0] * group_size, *tensor.shape[1:]) + + +@core.library.impl(lib, "_low_contention_all_gather", "CUDA") +def _low_contention_all_gather( + tensor: core.Tensor, + group_name: str, +) -> core.Tensor: + """ + Performs all-gather with symmetric memory in a low-contention fashion. + + When `tensor` is already in symmetric memory: + - The collective is carried out without using SMs. + - No symmetric memory workspace is required. + + When `tensor` is not in symmetric memory: + - An extra SM-based copy is performed to copy the input data into the + symmetric memory workspace. + - Symmetric memory workspace size requirement: the size of `tensor`. + """ + symm_mem = _SymmetricMemory.rendezvous(tensor) + if symm_mem is not None: + input_is_symm_mem = True + else: + symm_mem = get_symm_mem_workspace( + group_name, tensor.numel() * tensor.element_size() + ) + input_is_symm_mem = False + + rank = symm_mem.rank + world_size = symm_mem.world_size + + output = tensor.new_empty(tensor.shape[0] * world_size, *tensor.shape[1:]) + chunks = output.chunk(world_size) + + _get_backend_stream().wait_stream(core.cuda.current_stream()) + with core.cuda.stream(_get_backend_stream()): + if not input_is_symm_mem: + local_buf = symm_mem.get_buffer(rank, tensor.shape, tensor.dtype) + local_buf.copy_(tensor) + # pull + symm_mem.barrier() + for step in range(0, world_size): + remote_rank = (rank - step) % world_size + src_buf = symm_mem.get_buffer(remote_rank, tensor.shape, tensor.dtype) + chunks[remote_rank].copy_(src_buf) + symm_mem.barrier() + core._C._distributed_c10d._register_work(output, Work()) + return output + + +@core.library.impl(lib, "_low_contention_reduce_scatter", "Meta") +def _low_contention_reduce_scatter_meta( + tensor: core.Tensor, + reduce_op: str, + group_name: str, +) -> core.Tensor: + group_size = c10d._get_group_size_by_name(group_name) + return tensor.unflatten(0, (group_size, -1)).mean(dim=0) + + +def _low_contention_reduce_scatter_with_symm_mem_input( + tensor: core.Tensor, + reduce_op: str, + symm_mem: _SymmetricMemory, +) -> core.Tensor: + rank = symm_mem.rank + world_size = symm_mem.world_size + + assert tensor.shape[0] % world_size == 0 + a2a_res = core.empty_like(tensor) + chunks = a2a_res.chunk(world_size) + + _get_backend_stream().wait_stream(core.cuda.current_stream()) + with core.cuda.stream(_get_backend_stream()): + # pull + offline reduction + symm_mem.barrier() + for step in range(0, world_size): + remote_rank = (rank - step) % world_size + src_buf = symm_mem.get_buffer( + remote_rank, + chunks[0].shape, + chunks[0].dtype, + chunks[0].numel() * rank, + ) + chunks[remote_rank].copy_(src_buf) + symm_mem.barrier() + + ret = a2a_res.unflatten(0, (world_size, -1)) + if reduce_op == "sum": + ret = ret.sum(dim=0) + elif reduce_op == "avg": + ret = ret.mean(dim=0) + else: + raise ValueError(f"reduce_op ({reduce_op}) is not supported") + core._C._distributed_c10d._register_work(ret, Work()) + return ret + + +def _low_contention_reduce_scatter_with_workspace( + tensor: core.Tensor, + reduce_op: str, + workspace: _SymmetricMemory, +) -> core.Tensor: + rank = workspace.rank + world_size = workspace.world_size + + assert tensor.shape[0] % world_size == 0 + chunks = tensor.chunk(world_size) + + _get_backend_stream().wait_stream(core.cuda.current_stream()) + with core.cuda.stream(_get_backend_stream()): + # push + offline reduction + workspace.barrier() + for step in range(0, world_size): + remote_rank = (rank - step) % world_size + dst_buf = workspace.get_buffer( + remote_rank, chunks[0].shape, chunks[0].dtype, chunks[0].numel() * rank + ) + dst_buf.copy_(chunks[remote_rank]) + workspace.barrier() + + buf = workspace.get_buffer(rank, tensor.shape, tensor.dtype) + ret = buf.unflatten(0, (world_size, -1)) + if reduce_op == "sum": + ret = ret.sum(dim=0) + elif reduce_op == "avg": + ret = ret.mean(dim=0) + else: + raise ValueError(f"reduce_op ({reduce_op}) is not supported") + core._C._distributed_c10d._register_work(ret, Work()) + return ret + + +@core.library.impl(lib, "_low_contention_reduce_scatter", "CUDA") +def _low_contention_reduce_scatter( + tensor: core.Tensor, + reduce_op: str, + group_name: str, +) -> core.Tensor: + """ + Performs reduce-scatter with symmetric memory in a low-contention fashion. + + This implementation performs a P2P-based all-to-all followed by an offline + reduction. + + When `tensor` is already in symmetric memory: + - Pull-based all-to-all is used. + - No symmetric memory workspace is required. + + When `tensor` is not in symmetric memory: + - Push-based all-to-all is used. + - Symmetric memory workspace size requirement: the size of `tensor`. + + SM-usage: + - SM-based copy of the rank's own chunk for the all-to-all. + - Reduction on the all-to-all result. + + TODO(yifu): the SM-based copy can be avoided with a list-based reduction + kernel. + """ + symm_mem = _SymmetricMemory.rendezvous(tensor) + if symm_mem is not None: + return _low_contention_reduce_scatter_with_symm_mem_input( + tensor, reduce_op, symm_mem + ) + else: + workspace = get_symm_mem_workspace( + group_name, tensor.numel() * tensor.element_size() + ) + return _low_contention_reduce_scatter_with_workspace( + tensor, reduce_op, workspace + ) + + +# ============================================================================= +# User-facing APIs +# ============================================================================= + + +from typing import Any, overload, Sequence, TYPE_CHECKING, Union + +from core.types import _device, _dtype, _int + + +if TYPE_CHECKING: + from ..c10d import ProcessGroup + + +@overload +def empty( + *size: _int, dtype: Optional[_dtype] = None, device: Optional[_device] = None +) -> core.Tensor: + ... + + +@overload +def empty( + size: Sequence[_int], + *, + dtype: Optional[_dtype] = None, + device: Optional[_device] = None, +) -> core.Tensor: + ... + + +def empty( # type: ignore[misc] + *size: Any, + dtype: Optional[_dtype] = None, + device: Optional[_device] = None, +) -> core.Tensor: + r""" + empty(*size, *, dtype=None, device=None) -> Tensor + + Similar to :func:`core.empty()`. The returned tensor can be used by + :func:`core._distributed._symmetric_memory.rendezvous()` to establish a + symmetric memory tensor among participating processes. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + dtype (:class:`core.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`core.set_default_dtype`). + device (:class:`core.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`core.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + """ + if len(size) == 1 and isinstance(size[0], Sequence): + size = tuple(size[0]) + else: + size = tuple(size) + + if dtype is None: + dtype = core.get_default_dtype() + + if device is None: + device = core.get_default_device() + + return _SymmetricMemory.empty_strided_p2p( + size=size, + stride=core._prims_common.make_contiguous_strides_for(size), + dtype=dtype, + device=core.device(device), + ) + + +def rendezvous( + tensor: core.Tensor, group: Union[str, "ProcessGroup"] +) -> _SymmetricMemory: + r""" + rendezvous(tensor, group) -> _SymmetricMemory + + Establish a symmetric memory tensor among participating processes. This is + a collective operation. + + Args: + tensor (:class:`core.Tensor`): the local tensor used to establish the symmetric memory tensor. + It must be allocated via :func:`core._distributed._symmetric_memory.empty()`. The shape, + dtype, and device type must be identical across all participating processes. + group (Union[str, :class:`core.distributed.ProcessGroup`]): The group identifying the + participating processes. This can be either a group name or a process group object. + """ + from ..c10d import ProcessGroup + + if isinstance(group, str): + group_name = group + elif isinstance(group, ProcessGroup): + group_name = group.group_name + else: + raise TypeError(f"rendezvous: unsupported group type: {type(group)}") + + enable_symm_mem_for_group(group_name) + return _SymmetricMemory.rendezvous(tensor, group_name) + + +__all__ = ["empty", "rendezvous"] diff --git a/mindnlp/core/distributed/_tensor/__init__.py b/mindnlp/core/distributed/_tensor/__init__.py new file mode 100644 index 000000000..c543f8624 --- /dev/null +++ b/mindnlp/core/distributed/_tensor/__init__.py @@ -0,0 +1,44 @@ +""" +NOTICE: DTensor has moved to core.distributed.tensor + +This file is a shim to redirect to the new location, and +we keep the old import path starts with `_tensor` for +backward compatibility. We will remove this folder once +we resolve all the BC issues. +""" +import sys +from importlib import import_module + + +submodules = [ + # TODO: _shards_wrapper/_utils here mainly for checkpoint BC, remove them + "_shards_wrapper", + "_utils", + "experimental", + "device_mesh", +] + +# Redirect imports +for submodule in submodules: + full_module_name = f"core.distributed.tensor.{submodule}" + sys.modules[f"core.distributed._tensor.{submodule}"] = import_module( + full_module_name + ) + +from core.distributed.tensor import ( # noqa: F401 + DeviceMesh, + distribute_module, + distribute_tensor, + DTensor, + empty, + full, + init_device_mesh, + ones, + Partial, + Placement, + rand, + randn, + Replicate, + Shard, + zeros, +) diff --git a/mindnlp/core/distributed/_tensor/api.py b/mindnlp/core/distributed/_tensor/api.py new file mode 100644 index 000000000..8028ebecc --- /dev/null +++ b/mindnlp/core/distributed/_tensor/api.py @@ -0,0 +1,9 @@ +""" +NOTE: core.distributed._tensor has been moved to core.distributed.tensor. +The imports here are purely for backward compatibility. We will remove these +imports in a few releases + +TODO: throw warnings when this module imported +""" + +from core.distributed.tensor._api import * # noqa: F401, F403 diff --git a/mindnlp/core/distributed/_tensor/placement_types.py b/mindnlp/core/distributed/_tensor/placement_types.py new file mode 100644 index 000000000..d4c09b9fd --- /dev/null +++ b/mindnlp/core/distributed/_tensor/placement_types.py @@ -0,0 +1,10 @@ +""" +NOTE: core.distributed._tensor has been moved to core.distributed.tensor. +The imports here are purely for backward compatibility. We will remove these +imports in a few releases + +TODO: throw warnings when this module imported +""" + +from core.distributed.tensor._dtensor_spec import * # noqa: F401, F403 +from core.distributed.tensor.placement_types import * # noqa: F401, F403 diff --git a/mindnlp/core/distributed/_tools/__init__.py b/mindnlp/core/distributed/_tools/__init__.py new file mode 100644 index 000000000..284b6180a --- /dev/null +++ b/mindnlp/core/distributed/_tools/__init__.py @@ -0,0 +1,12 @@ +from .fsdp2_mem_tracker import FSDPMemTracker +from .mem_tracker import MemTracker +from .memory_tracker import MemoryTracker +from .mod_tracker import ModTracker +from .runtime_estimator import RuntimeEstimator +from .sac_estimator import ( + MSPS, + SACEstimator, + SACGreedyOrderMeta, + SACStats, + SACTradeOffStats, +) diff --git a/mindnlp/core/distributed/_tools/fsdp2_mem_tracker.py b/mindnlp/core/distributed/_tools/fsdp2_mem_tracker.py new file mode 100644 index 000000000..1ee400c7f --- /dev/null +++ b/mindnlp/core/distributed/_tools/fsdp2_mem_tracker.py @@ -0,0 +1,610 @@ +from copy import deepcopy +from datetime import timedelta +from functools import partial, wraps +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Type, Union + +from mindnlp import core +from mindnlp import core.distributed as dist +from mindnlp.core import nn, optim +from core._guards import active_fake_mode +from core.distributed._composable.fsdp import FSDPModule +from core.distributed._composable.fsdp._fsdp_param_group import FSDPParamGroup +from core.distributed._tools.mem_tracker import _RefType, _State, MemTracker +from core.distributed.distributed_c10d import ( + _IllegalWork, + ProcessGroup, + ReduceOp, + Work, +) +from core.futures import Future +from core.utils._python_dispatch import TorchDispatchMode +from core.utils._pytree import tree_map_only +from core.utils.weak import WeakIdKeyDictionary, weakref + + +_TOTAL_KEY = "Total" + +__all__ = ["FSDPMemTracker"] + + +class _FSDPRefType(_RefType): + """ + Enumerates categories of memory usage in FSDP modules, including parameters, gradients, activations, + and optimizer states. + + Attributes: + SHARDED_PARAM (str): Memory usage of sharded parameters. + UNSHARDED_PARAM (str): Memory usage of unsharded parameters. + SHARDED_GRAD (str): Memory usage of sharded gradients corresponding to the sharded parameters. + UNSHARDED_GRAD (str): Memory usage of unsharded gradients corresponding to the unsharded parameters. + ACT (str): Memory usage of activations and tensors from forward and AC recomputation. + TEMP (str): Memory usage of temporary tensors during the backward pass including gradients of activations. + ALL_GATHER (str): Memory usage of all_gather output tensor. + REDUCE_SCATTER (str): Memory usage of reduce_scatter input tensor. + OPT (str): Memory usage of tensors storing optimizer states. + INP (str): Memory usage of input tensors. + """ + + SHARDED_PARAM = "Sharded Param" + UNSHARDED_PARAM = "Unsharded Param" + BUFFER = "Buffer" + SHARDED_GRAD = "Sharded Grad" + UNSHARDED_GRAD = "Unsharded Grad" + ACT = "Activation" + TEMP = "Temp" + ALL_GATHER = "All Gather" + REDUCE_SCATTER = "Reduce Scatter" + OPT = "OptState" + INP = "Inputs" + + +class _SavedFSDPMethods(NamedTuple): + pre_backward: Callable + post_backward: Callable + + +class _SavedCollectives(NamedTuple): + all_gather_into_tensor: Callable + reduce_scatter_tensor: Callable + all_reduce: Callable + barrier: Callable + + +class _FSDPModState(_State): + """ + Enumerates the states of FSDP modules during the forward and backward passes. + """ + + BEF_PRE_FW = "Before Pre-Forward" + AFT_PRE_FW = "After Pre-Forward" + BEF_POST_FW = "Before Post-Forward" + AFT_POST_FW = "After Post-Forward" + BEF_PRE_BW = "Before Pre-Backward" + AFT_PRE_BW = "After Pre-Backward" + BEF_POST_BW = "Before Post-Backward" + AFT_POST_BW = "After Post-Backward" + PRE_FW_AC = "Pre-Forward AC" + POST_FW_AC = "Post-Forward AC" + PEAK_FW = "Peak Forward" + PEAK_BW = "Peak Backward" + + +class _FSDPModMemStats: + """ + A class to store the memory statistics of an FSDP module. + + Args: + mod_fqn (str): The fully qualified name of the FSDP module. + + Attributes: + snapshots (Dict[_FSDPModState, Dict[core.device, Dict[str, int]]]): A dictionary of memory snapshots + of the module at different states as defined by ``_FSDPModState``. Each key is a device, and + each value is another dictionary with keys as memory reference types defined by ``_FSDPRefType`` and + values as the memory consumed in bytes. + + """ + + def __init__(self, mod_fqn: str) -> None: + self.mod_fqn = mod_fqn + self.local_peak: Dict[core.device, int] = {} + self.snapshots: Dict[ + _FSDPModState, List[Dict[core.device, Dict[str, int]]] + ] = {} + + +class FSDPMemTracker(MemTracker): + """ + A ``TorchDispatchMode`` based context manager that extends ``core.distributed._tools.mem_tracker.MemTracker`` to track + and categorize the peak memory and module-wise memory usage of FSDP modules. + + It tracks the peak memory usage across all the devices of all the FSDP modules in the module tree and categorizes + the tensor memory usage as defined by ``_FSDPRefType``. Further, it captures memory `snapshots` at different stages of + the module execution defined by ``_FSDPModState``. + + Attributes: + memory_tracking: A weakref key dictionary to store the memory statistics of each module. Each key is a reference + to a module, and each value is a ``_FSDPModMemStats`` object that stores the memory statistics of the module. + + Args: + mod (core.nn.Module): The root FSDP module to be tracked. + optm (core.optim.Optimizer, optional): The optimizer to be tracked. + + Note: Please refer to ``core.distributed._tools.mem_tracker.MemTracker`` to learn about the limitations. + + Example usage + + .. code-block:: python + + module = ... + optimizer = ... + inp = ... + fmt = FSDPMemTracker(module, optimizer) + fmt.track_inputs((inp,)) + with fmt: + optimizer.zero_grad() + loss = module(inp) + print("After Forward:") + fmt.display_snapshot("current") + loss.backward() + optimizer.step() + fmt.display_snapshot("peak") + fmt.display_modulewise_snapshots(depth = 3, units = "MB") + + """ + + def __init__( + self, + mod: core.nn.Module, + optm: Optional[core.optim.Optimizer] = None, + ) -> None: + super().__init__() + assert isinstance(mod, FSDPModule), "FSDPMemTracker only supports FSDP modules" + self._root_mod = mod + self._optm = optm + self._in_fake_mode: bool = False + self._fsdp_mod_to_saved_methods: WeakIdKeyDictionary = WeakIdKeyDictionary() + self._saved_collectives: _SavedCollectives + self._ref_class: Type[_RefType] = _FSDPRefType + + def _instrument_fsdp_sharded_params_grads( + self, fsdp_param_group: FSDPParamGroup + ) -> None: + # Track sharded params and grads after initilization + for fsdp_param in fsdp_param_group.fsdp_params: + self._update_and_maybe_create_winfos( + fsdp_param.sharded_param, + _FSDPRefType.SHARDED_PARAM, + ) + sharded_grad = fsdp_param.sharded_param.grad + if sharded_grad is not None: + self._update_and_maybe_create_winfos( + sharded_grad, + _FSDPRefType.SHARDED_GRAD, + ) + + def _fsdp_state_pre_forward( + self, + fsdp_mod: FSDPModule, + orig_fsdp_state_pre_fw: Callable, + ) -> Callable: + # We capture memory snapshots before and after ``FSDPState._pre_forward`` to attribute the `unsharded` params + # and `all_gather` buffers. There are three cases: + # Case 1: If the module is not in the ``memory_tracking`` dictionary, create a new ``_FSDPModMemStats`` + # instance for the module and add it to the ``memory_tracking`` dictionary. + # Case 2: If the module is already in the ``memory_tracking`` dictionary and we are in backward, this means + # we are in the AC region. We check if this is the top most module in the AC region. If it is, + # we store a weak reference and set the flag ``_in_ac`` to True. + # Case 3: If the module is already in the ``memory_tracking`` dictionary and we are in forward, this means + # this module is called for the second time. If it is a root module, that means we are in the next + # iteration and we error out. If it is not a root module, that means it's a submodule that is being + # used multiple times in the same iteration, which we allow and track. + # For Case 1 and 3, we also initialiaze the ``local_peak`` and ``PEAK_FW`` snapshot for the module. + # For Case 2 we only capture 1 snapshot after ``FSDPState._pre_forward`` runs because it is a no-op. + @wraps(orig_fsdp_state_pre_fw) + def inner(*args: Any, **kwargs: Any) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: + mod_fqn = self._mod_tracker.get_known_fqn(fsdp_mod) + assert mod_fqn is not None + if fsdp_mod not in self.memory_tracking: + mod_stat = _FSDPModMemStats(mod_fqn) + self.memory_tracking[fsdp_mod] = mod_stat + snapshot = self.get_tracker_snapshot() + mod_stat.local_peak = { + dev: dev_snap[_TOTAL_KEY] for dev, dev_snap in snapshot.items() + } + mod_stat.snapshots.setdefault(_FSDPModState.PEAK_FW, []).append( + snapshot + ) + mod_stat.snapshots.setdefault(_FSDPModState.BEF_PRE_FW, []).append( + deepcopy(snapshot) + ) + elif not self._mod_tracker.is_bw: + parents = self._mod_tracker.parents - {mod_fqn} + if len(parents) == 1 and "Global" in parents: + raise NotImplementedError( + "FSDPMemTracker does not support memory tracking for multiple iterative calls." + " Either use ``reset_mod_stats`` to clear module memory stats for the previous iteration" + " or file a github issue if you need this feature." + ) + + args, kwargs = orig_fsdp_state_pre_fw(*args, **kwargs) + + fsdp_state = fsdp_mod._get_fsdp_state() + if fsdp_param_group := fsdp_state._fsdp_param_group: + for fsdp_param in fsdp_param_group.fsdp_params: + self._update_and_maybe_create_winfos( + fsdp_param.unsharded_param, + _FSDPRefType.UNSHARDED_PARAM, + ) + mod_stat = self.memory_tracking[fsdp_mod] + if self._mod_tracker.is_bw: + state = _FSDPModState.PRE_FW_AC + if self._ac_mod is None: + self._ac_mod = weakref.ref(fsdp_mod) + self._in_ac = True + else: + state = _FSDPModState.AFT_PRE_FW + mod_stat.snapshots.setdefault(state, []).append(self.get_tracker_snapshot()) + return args, kwargs + + return inner + + def _fsdp_state_post_forward( + self, + fsdp_mod: FSDPModule, + orig_fsdp_state_post_fw: Callable, + ) -> Callable: + # We capture memory snapshots before and after ``FSDPState._post_forward`` to capture the resharded state + # if ``reshard_after_forward`` is not ``False``. There are two cases: + # Case 1: This is called in backward, which means we are in the AC region. If this is the top most module + # in the AC region, we set the flag ``_in_ac`` to False. + # Case 2: This is called in forward. + @wraps(orig_fsdp_state_post_fw) + def inner(*args: Any, **kwargs: Any) -> Any: + mod_stat = self.memory_tracking[fsdp_mod] + if self._mod_tracker.is_bw: + state = _FSDPModState.POST_FW_AC + if self._ac_mod is not None and self._ac_mod() is fsdp_mod: + self._ac_mod = None + self._in_ac = False + else: + state = _FSDPModState.BEF_POST_FW + mod_stat.snapshots.setdefault(state, []).append(self.get_tracker_snapshot()) + + output = orig_fsdp_state_post_fw(*args, **kwargs) + + if not self._mod_tracker.is_bw: + mod_stat.snapshots.setdefault(_FSDPModState.AFT_POST_FW, []).append( + self.get_tracker_snapshot() + ) + return output + + return inner + + def _fsdp_param_group_pre_backward( + self, + fsdp_mod: FSDPModule, + orig_fsdp_param_group_pre_backward: Callable, + ) -> Callable: + # We capture memory snapshots before and after ``FSDPParamGroup.pre_backward`` to capture the pre-fetching + # and unsharding of params. We also initialize ``local_peak`` and ``PEAK_BW`` snapshot for the module. + @wraps(orig_fsdp_param_group_pre_backward) + def inner(*args: Any, **kwargs: Any) -> None: + mod_stat = self.memory_tracking[fsdp_mod] + snapshot = self.get_tracker_snapshot() + mod_stat.local_peak = { + dev: dev_snap[_TOTAL_KEY] for dev, dev_snap in snapshot.items() + } + mod_stat.snapshots.setdefault(_FSDPModState.PEAK_BW, []).append(snapshot) + mod_stat.snapshots.setdefault(_FSDPModState.BEF_PRE_BW, []).append( + deepcopy(snapshot) + ) + orig_fsdp_param_group_pre_backward(*args, **kwargs) + + mod_stat.snapshots.setdefault(_FSDPModState.AFT_PRE_BW, []).append( + self.get_tracker_snapshot() + ) + + return inner + + def _fsdp_param_group_post_backward( + self, + fsdp_mod: FSDPModule, + orig_fsdp_param_group_post_backward: Callable, + ) -> Callable: + # We capture the memory snapshots before and after ``FSDPParamGroup.post_backward`` to track and attribute + # the `unsharded` grads before the post backward and then `sharded` grads and `reduce_scatter` buffers + # after the post backward. + @wraps(orig_fsdp_param_group_post_backward) + def inner(*args: Any, **kwargs: Any) -> None: + fsdp_state = fsdp_mod._get_fsdp_state() + if fsdp_param_group := fsdp_state._fsdp_param_group: + for fsdp_param in fsdp_param_group.fsdp_params: + unsharded_grad = fsdp_param._unsharded_param.grad + if unsharded_grad is not None: + self._update_and_maybe_create_winfos( + unsharded_grad, + _FSDPRefType.UNSHARDED_GRAD, + update_existing=True, + ) + + mod_stat = self.memory_tracking[fsdp_mod] + mod_stat.snapshots.setdefault(_FSDPModState.BEF_POST_BW, []).append( + self.get_tracker_snapshot() + ) + + orig_fsdp_param_group_post_backward(*args, **kwargs) + + if fsdp_param_group := fsdp_state._fsdp_param_group: + for fsdp_param in fsdp_param_group.fsdp_params: + sharded_grad = fsdp_param.sharded_param.grad + if sharded_grad is not None: + self._update_and_maybe_create_winfos( + sharded_grad, + _FSDPRefType.SHARDED_GRAD, + ) + + mod_stat.snapshots.setdefault(_FSDPModState.AFT_POST_BW, []).append( + self.get_tracker_snapshot() + ) + + return inner + + def _instrument_fsdp_module(self) -> None: + # We uninstall the existing `FSDPState._pre_forward` and `FSDPState._post_forward` hooks and install + # our own hooks that wrap them. We choose this over monkey-patching `FSDPParamGroup.pre_forward` and + # `FSDPParamGroup.post_forward` because during AC these won't be called. + # TODO(@sanketpurandare): This will need to be modified after this PR (https://github.com/pytorch/pytorch/pull/127786) + # lands. For backward we monkey-patch the `FSDPParamGroup.pre_backward` and `FSDPParamGroup.post_backward`. + for module in self._root_mod.modules(): + if isinstance(module, FSDPModule): + fsdp_state = module._get_fsdp_state() + if fsdp_param_group := fsdp_state._fsdp_param_group: + self._instrument_fsdp_sharded_params_grads(fsdp_param_group) + fsdp_state._pre_forward_hook_handle.remove() + fsdp_state._post_forward_hook_handle.remove() + fsdp_state._pre_forward_hook_handle = ( + module.register_forward_pre_hook( + self._fsdp_state_pre_forward( + module, fsdp_state._pre_forward + ), + prepend=True, + with_kwargs=True, + ) + ) + fsdp_state._post_forward_hook_handle = module.register_forward_hook( + self._fsdp_state_post_forward(module, fsdp_state._post_forward), + prepend=False, + always_call=True, + ) + self._fsdp_mod_to_saved_methods[module] = _SavedFSDPMethods( + fsdp_param_group.pre_backward, + fsdp_param_group.post_backward, + ) + fsdp_param_group.pre_backward = self._fsdp_param_group_pre_backward( # type: ignore[assignment] + module, fsdp_param_group.pre_backward + ) + fsdp_param_group.post_backward = ( # type: ignore[assignment] + self._fsdp_param_group_post_backward( + module, fsdp_param_group.post_backward + ) + ) + + for buffer in self._root_mod.buffers(): + self._update_and_maybe_create_winfos( + buffer, + _FSDPRefType.BUFFER, + ) + + def _instrument_optimizer(self) -> None: + # Register a hook on the optimizer step to track the optimizer states. + # The pre-hook is to set the flag ``_in_opt`` to True. The post-hook unsets the flag, + # and also tracks any optimizer states that are created during the optimizer step. + if self._optm is not None: + self._track_optimizer_states(_FSDPRefType.OPT, self._optm) + + def _opt_step_pre_hook( + optimizer: optim.Optimizer, args: Any, kwargs: Any + ) -> None: + self._in_opt = True + + def _opt_step_post_hook( + optimizer: optim.Optimizer, args: Any, kwargs: Any + ) -> None: + self._track_optimizer_states(_FSDPRefType.OPT, optimizer) + self._in_opt = False + + self._optimizer_hook_handles = ( + self._optm.register_step_pre_hook(_opt_step_pre_hook), + self._optm.register_step_post_hook(_opt_step_post_hook), + ) + + def _register_module_and_optimizer_hooks(self) -> None: + self._instrument_fsdp_module() + self._instrument_optimizer() + + def _deregister_module_and_optimizer_hooks(self) -> None: + for ( + fsdp_mod, + saved_methods, + ) in self._fsdp_mod_to_saved_methods.items(): + fsdp_state = fsdp_mod._get_fsdp_state() + fsdp_state._pre_forward_hook_handle.remove() + fsdp_state._post_forward_hook_handle.remove() + fsdp_state._pre_forward_hook_handle = fsdp_mod.register_forward_pre_hook( + fsdp_state._pre_forward, prepend=True, with_kwargs=True + ) + fsdp_state._post_forward_hook_handle = fsdp_mod.register_forward_hook( + fsdp_state._post_forward, prepend=False + ) + if fsdp_param_group := fsdp_state._fsdp_param_group: + fsdp_param_group.pre_backward = saved_methods.pre_backward + fsdp_param_group.post_backward = saved_methods.post_backward + self._fsdp_mod_to_saved_methods.clear() + + if self._optimizer_hook_handles is not None: + for handle in self._optimizer_hook_handles: + handle.remove() + self._optimizer_hook_handles = None + + def _instrument_and_maybe_bypass_collectives(self) -> None: + # Monkey-patching collectives is required because they do not work with `FakeTensorMode` + # It's also easier to track `all_gather` and `reduce_scatter` buffers faithfully. + self._saved_collectives = _SavedCollectives( + dist.all_gather_into_tensor, + dist.reduce_scatter_tensor, + dist.all_reduce, + dist.barrier, + ) + + class FakeWork(Work): + def __init__(self) -> None: + super().__init__() + + def get_future(self) -> Future: + future: Future = Future() + future.set_result(None) + return future + + def wait(self, timeout: Optional[timedelta] = None) -> bool: + return True + + @wraps(dist.all_gather_into_tensor) + def all_gather_into_tensor( + output_tensor: core.Tensor, + input_tensor: core.Tensor, + group: Union[ProcessGroup, None] = None, + async_op: bool = False, + ) -> Union[Work, _IllegalWork, None]: + self._update_and_maybe_create_winfos( + output_tensor, + _FSDPRefType.ALL_GATHER, + update_existing=True, + ) + + if self._in_fake_mode: + if async_op: + return FakeWork() + return None + else: + return self._saved_collectives.all_gather_into_tensor( + output_tensor, input_tensor, group, async_op + ) + + @wraps(dist.reduce_scatter_tensor) + def reduce_scatter_tensor( + output: core.Tensor, + input: core.Tensor, + op: ReduceOp.RedOpType = dist.ReduceOp.SUM, + group: Union[ProcessGroup, None] = None, + async_op: bool = False, + ) -> Union[Work, _IllegalWork, None]: + self._update_and_maybe_create_winfos( + input, + _FSDPRefType.REDUCE_SCATTER, + update_existing=True, + ) + + if self._in_fake_mode: + if async_op: + return FakeWork() + return None + else: + return self._saved_collectives.reduce_scatter_tensor( + output, input, op, group, async_op + ) + + @wraps(dist.all_reduce) + def all_reduce( + tensor: core.Tensor, + op: ReduceOp.RedOpType = dist.ReduceOp.SUM, + group: Union[ProcessGroup, None] = None, + async_op: bool = False, + ) -> Union[Work, _IllegalWork, None]: + if self._in_fake_mode: + if async_op: + return FakeWork() + return None + else: + return self._saved_collectives.all_reduce(tensor, op, group, async_op) + + @wraps(dist.barrier) + def barrier( + group: Union[ProcessGroup, None] = dist.GroupMember.WORLD, + async_op: bool = False, + device_ids: Union[List[int], None] = None, + ) -> Union[Work, None]: + if self._in_fake_mode: + return None + else: + return self._saved_collectives.barrier(group, async_op, device_ids) + + dist.all_gather_into_tensor = all_gather_into_tensor + dist.reduce_scatter_tensor = reduce_scatter_tensor + dist.all_reduce = all_reduce + dist.barrier = barrier + + def _restore_collectives(self) -> None: + dist.all_gather_into_tensor = self._saved_collectives.all_gather_into_tensor + dist.reduce_scatter_tensor = self._saved_collectives.reduce_scatter_tensor + dist.all_reduce = self._saved_collectives.all_reduce + dist.barrier = self._saved_collectives.barrier + del self._saved_collectives + + def track_inputs(self, inputs: Tuple[Any, ...]) -> None: + """ + This is used to track the input tensors to the model and annotate them as ``Inputs``. + Args: + inputs (Tuple[Any]): A tuple containing the input data. This can include tensors + as well as other data types. Only tensors will be tracked. + """ + + def _track_inputs(t: core.Tensor) -> None: + self._update_and_maybe_create_winfos( + t, + _FSDPRefType.INP, + ) + + tree_map_only(core.Tensor, _track_inputs, inputs) + + def track_external( + self, *external: Union[nn.Module, optim.Optimizer, core.Tensor] + ) -> None: + """This is no-op for ``FSDPMemTracker``""" + + def __enter__(self) -> "FSDPMemTracker": + self._in_fake_mode = True if active_fake_mode() else False + self._register_module_and_optimizer_hooks() + self._instrument_and_maybe_bypass_collectives() + self._track_resize() + self._peak_mem_snap = self.get_tracker_snapshot() + self._peak_mem = { + dev: dev_snap[_TOTAL_KEY] for dev, dev_snap in self._peak_mem_snap.items() + } + self._mod_tracker.__enter__() + TorchDispatchMode.__enter__(self) + return self + + def __exit__(self, *args: Any) -> None: + self._deregister_module_and_optimizer_hooks() + self._restore_collectives() + self._restore_resize() + TorchDispatchMode.__exit__(self, *args) + self._mod_tracker.__exit__(*args) + + def __torch_dispatch__(self, func, types, args=..., kwargs=None): # type: ignore[no-untyped-def] + res = func(*args, **kwargs or {}) + # If we are tracking an optimizer state, we use the optimizer reference type. + # If we are in backward region and not in AC region, we use the backward reference type. + # Else we use the forward reference type. + if self._in_opt: + reftype = _FSDPRefType.OPT + elif self._mod_tracker.is_bw and not self._in_ac: + reftype = _FSDPRefType.TEMP + else: + reftype = _FSDPRefType.ACT + tree_map_only(core.Tensor, partial(self._track, reftype), res) + peak_state = ( + _FSDPModState.PEAK_BW if self._mod_tracker.is_bw else _FSDPModState.PEAK_FW + ) + self._update_peak_stats(peak_state) + return res diff --git a/mindnlp/core/distributed/_tools/ilp_utils.py b/mindnlp/core/distributed/_tools/ilp_utils.py new file mode 100644 index 000000000..7fbfa8e31 --- /dev/null +++ b/mindnlp/core/distributed/_tools/ilp_utils.py @@ -0,0 +1,291 @@ +import copy +from typing import cast, Dict, List, OrderedDict, Tuple, TypedDict + +import numpy as np + +from mindnlp import core +from core.distributed._tools.mem_tracker import ( + _MemRefType, + _ModMemStats, + _ModState, + MemTracker, +) +from core.distributed._tools.runtime_estimator import RuntimeEstimator +from core.distributed._tools.sac_estimator import SACEstimator, SACTradeOffStats + + +class ModOrder(TypedDict): + fw_pre_order: List[str] + bw_pre_order: List[str] + fw_post_order: List[str] + bw_post_order: List[str] + + +class ModRuntime(TypedDict): + fw: float + bw: float + + +class ModStats(TypedDict): + fqn: str + # per-module params + param_per_module: int + # per-module grads + grad_per_module: int + # total accumulated gradients up to and including this module + grad_total: int + # per module fw activation size (excluding input and output) + act_fw_per_module: int + # per module bw activation size during peak_bw + act_bw_per_module: int + # per module activation grad size during peak_bw + act_grad_per_module: int + # total activation size up to but excluding the current module + # includes input of the current module (i.e., output of previous module) + act_total: int + # Inputs to the module + input_per_module: int + # Outputs of the module + output_per_module: int + # Total fw run-time of the module + fw_runtime_per_module: float + # Total bw run-time of the module + bw_runtime_per_module: float + # Is this module a leaf module + is_leaf: bool + # Total ac run-time of the module + sac_runtime: float + # Total ac_memory for the module + sac_memory: int + # Number of piecewise-linear functions used for approximating ac tradeoff curve + n_segments: int + # Slopes of the of piecewise-linear functions + slopes: List[float] + # Intercepts of the of piecewise-linear functions + intercepts: List[float] + # X breakpoints of the of piecewise-linear functions + breakpoints: List[float] + # Original trade-off curves + tradeoff_curve: OrderedDict[float, float] + + +class ModuleInfo(TypedDict): + mod_order: ModOrder + mod_stats: List[ModStats] + + +def aggregate_stats( + model: core.nn.Module, + mem_tracker: MemTracker, + runtime_estimator: RuntimeEstimator, + sac_estimator: SACEstimator, + dev: core.device, +) -> ModuleInfo: + """ + Collect modulewise stats for a given model, including memory, runtime, and AC tradeoff stats. + + Args: + model: nn.Module object + runtime_estimator: RuntimeEstimator object with runtime stats + mem_tracker: MemTracker object with memory stats + sac_estimator: SACEstimator object with AC tradeoff stats + dev: device the model was run on (used to extract memory stats from MemTracker) + + Returns: + ModuleInfo: A dictionary with module order and module stats. + """ + + # Memory stats + mod_mem_stats: Dict[core.nn.Module, _ModMemStats] = dict( + copy.deepcopy(mem_tracker.memory_tracking) + ) + + # Runtime stats + mod_runtime_stats: Dict[str, ModRuntime] = { + fqn: {"fw": v["fw"], "bw": v["bw"]} + for fqn, v in runtime_estimator.mod_runtimes.items() + } + + # Module order + mod_order: ModOrder = { + "fw_pre_order": list(runtime_estimator.mod_fw_pre_order), + "bw_pre_order": list(runtime_estimator.mod_bw_pre_order), + "fw_post_order": list(runtime_estimator.mod_fw_post_order), + "bw_post_order": list(runtime_estimator.mod_bw_post_order), + } + + # Selective Activation Checkpointing stats + sac_estimator.pwlf_sac_tradeoff_curve() + mod_sac_tradeoff_stats: Dict[str, SACTradeOffStats] = copy.deepcopy( + sac_estimator.sac_mod_tradeoff_stats + ) + + module_info: ModuleInfo = { + "mod_order": mod_order, + "mod_stats": [], + } + + for mod in model.modules(): + if mod_mem_stat := mod_mem_stats.get(mod, None): + if tradeoff_stats := mod_sac_tradeoff_stats.get(mod_mem_stat.mod_fqn, None): + sac_runtime = tradeoff_stats.sac_runtime + sac_memory = tradeoff_stats.sac_memory + n_segments = tradeoff_stats.n_segments + slopes = tradeoff_stats.slopes + intercepts = tradeoff_stats.intercepts + breakpoints = tradeoff_stats.fit_breaks + tradeoff_curve = tradeoff_stats.tradeoff_curve + is_leaf = False + else: + sac_runtime = sac_memory = n_segments = 0 + slopes = intercepts = breakpoints = [] + tradeoff_curve: OrderedDict[float, float] = OrderedDict() # type: ignore[no-redef] + is_leaf = True + mod_stat: ModStats = { + "fqn": mod_mem_stat.mod_fqn, + "param_per_module": mod_mem_stat.parameter_mem, + "grad_per_module": mod_mem_stat.parameter_mem, + "grad_total": mod_mem_stat.snapshots[_ModState.PRE_BW][-1][dev][ + _MemRefType.GRAD + ], + "act_fw_per_module": max( + 0, + mod_mem_stat.snapshots[_ModState.POST_FW][-1][dev][_MemRefType.ACT] + - mod_mem_stat.snapshots[_ModState.PRE_FW][-1][dev][_MemRefType.ACT] + - mod_mem_stat.output_mem, + ), + "act_bw_per_module": max( + 0, + mod_mem_stat.snapshots[_ModState.PEAK_BW][-1][dev][_MemRefType.ACT], + ), + "act_grad_per_module": ( + mod_mem_stat.snapshots[_ModState.PEAK_BW][-1][dev][_MemRefType.TEMP] + - mod_mem_stat.snapshots[_ModState.PRE_BW][-1][dev][ + _MemRefType.TEMP + ] + ), + "act_total": mod_mem_stat.snapshots[_ModState.POST_FW][-1][dev][ + _MemRefType.ACT + ], + "input_per_module": mod_mem_stat.input_mem, + "output_per_module": mod_mem_stat.output_mem, + "fw_runtime_per_module": mod_runtime_stats[mod_mem_stat.mod_fqn]["fw"], + "bw_runtime_per_module": mod_runtime_stats[mod_mem_stat.mod_fqn]["bw"], + "is_leaf": is_leaf, + "sac_runtime": sac_runtime, + "sac_memory": sac_memory, + "n_segments": n_segments, + "slopes": slopes, + "intercepts": intercepts, + "breakpoints": breakpoints, + "tradeoff_curve": tradeoff_curve, + } + module_info["mod_stats"].append(mod_stat) + + return module_info + + +class Node(ModStats): + index: int # index according to forward pre-order + pos_fw_post_order: int # index according to forward post-order + + +class Graph: + def __init__(self, n: int) -> None: + self.nodes: List[Node] = [] + self.name2node: Dict[str, Node] = {} + self.ad_matrix = np.zeros((n, n)) + self.fw_post_order: List[str] = [] + + def add_node(self, node: Node) -> None: + self.nodes.append(node) + self.name2node[node["fqn"]] = node + + +def parse_module_info(module_info: ModuleInfo) -> Graph: + """ + Parse module info and create a graph (tree) of modules. The graph will be + used by MILP solver to find optimal SAC and/or FSDP configurations. + """ + mod_stats = module_info["mod_stats"] + fw_pre_order = module_info["mod_order"]["fw_pre_order"] + # assertion and number of nodes + assert len(mod_stats) == len(fw_pre_order) + n_nodes = len(mod_stats) + + # create graph + g = Graph(n_nodes) + g.fw_post_order = module_info["mod_order"]["fw_post_order"] + + # sort the modules by pre-order and add them to the graph + module_info["mod_stats"] = sorted( + mod_stats, key=lambda x: fw_pre_order.index(x["fqn"]) + ) + for i, one_mod_stats in enumerate(mod_stats): + node: Node = cast(Node, one_mod_stats) + node["index"] = i + node["pos_fw_post_order"] = g.fw_post_order.index(node["fqn"]) + g.add_node(node) + + # set up ancestor-descendant matrix + for i in range(n_nodes): + for j in range(i, n_nodes): + if is_self_or_submodule(g.nodes[j]["fqn"], g.nodes[i]["fqn"]): + g.ad_matrix[i][j] = 1 + else: + break + + return g + + +def is_self_or_submodule(name_descendant: str, name_ancestor: str) -> bool: + """ + check if name_descendant is a submodule of name_ancestor, or if they are the same + """ + return name_descendant == name_ancestor or name_ancestor + "." in name_descendant + + +def is_submodule(name_descendant: str, name_ancestor: str) -> bool: + """ + if name_descendant is a submodule of name_ancestor, but not the same + """ + return name_ancestor + "." in name_descendant + + +def display_bytes(b: int, unit: str = "MiB") -> str: + """ + return a string that represent the number of bytes in a desired unit + """ + if unit == "KiB": + return f"{b/2**10:.2f} KiB" + if unit == "MiB": + return f"{b/2**20:.2f} MiB" + if unit == "GiB": + return f"{b/2**30:.2f} GiB" + return f"{b:.2f} bytes" + + +def get_peak_memory_runtime_baseline(graph: Graph) -> Tuple[int, float]: + """ + Get the baseline peak memory and runtime. + Baseline here means there is no FSDP or AC. + Memory includes the parameters, gradients, activations, and activation gradients. + Memory does not include e.g., optimizer states, embedding tables, etc. + + Returns: + int: peak memory in bytes + float: compute time in ms + """ + P_1 = graph.nodes[0]["param_per_module"] + num_nodes = len(graph.nodes) + peak_mem = 0 + for i in range(num_nodes): + TG_i = graph.nodes[i]["grad_total"] + AG_i = graph.nodes[i]["act_grad_per_module"] + TA_i = graph.nodes[i]["act_total"] + peak_mem = max(peak_mem, P_1 + TG_i + AG_i + TA_i) + compute_time = ( + graph.nodes[0]["fw_runtime_per_module"] + + graph.nodes[0]["bw_runtime_per_module"] + ) + return (peak_mem, compute_time) diff --git a/mindnlp/core/distributed/_tools/mem_tracker.py b/mindnlp/core/distributed/_tools/mem_tracker.py new file mode 100644 index 000000000..b0bc2bfdb --- /dev/null +++ b/mindnlp/core/distributed/_tools/mem_tracker.py @@ -0,0 +1,943 @@ +import math +import os +import re +import warnings +from copy import deepcopy +from enum import auto, Enum +from functools import partial, wraps +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Set, + Tuple, + Type, + TYPE_CHECKING, + Union, +) +from typing_extensions import Self + +from mindnlp import core +from mindnlp.core import nn, optim +from core.distributed._tools.mod_tracker import ModTracker +from core.optim.optimizer import ( + register_optimizer_step_post_hook, + register_optimizer_step_pre_hook, +) +from core.utils._python_dispatch import ( + is_traceable_wrapper_subclass, + TorchDispatchMode, +) +from core.utils._pytree import tree_flatten, tree_map_only +from core.utils.weak import WeakIdKeyDictionary, weakref + + +if TYPE_CHECKING: + from core.utils.hooks import RemovableHandle + +# This value is hard-coded here: +# https://github.com/pytorch/pytorch/blob/5fba5d83f0703ff8077ab65448a998e9ad6598fd/c10/cuda/CUDACachingAllocator.cpp#L117 +_PYTORCH_MIN_ALLOCATE = ( + 2**9 if int(os.environ.get("PYTORCH_NO_CUDA_MEMORY_CACHING", 0)) == 0 else 1 +) +_TOTAL_KEY = "Total" + +__all__ = ["MemTracker"] + + +class _RefType(str, Enum): + """Base Class for defining memory reference types, categorizing tensors based on their usage within a model.""" + + +class _State(str, Enum): + """Base Class for defining module state to capture snapshots .""" + + +class _MemRefType(_RefType): + """ + An enum to define memory reference types, categorizing tensors based on their usage within a model. + + - PARAM: Tensors registered as nn.Parameter within modules. + - BUFFER: Tensors registered as nn.Buffer within modules. + - GRAD: Gradients associated with parameters. + - ACT: Tensors produced during the forward pass and recomputation in activation checkpointing. + - TMP: Temporary memory used during the backward pass, including gradients of activations. + - OPT: Tensors holding optimizer states. + - OTH: Tensors registered via `track_external` that do not fit the above categories. + """ + + PARAM = "Parameter" + BUFFER = "Buffer" + GRAD = "Gradient" + ACT = "Activation" + TEMP = "Temp" + OPT = "Optstate" + OTH = "Other" + + +class _ModState(_State): + """ + An enum to define the state of a module. + + - PRE_FW: The module is about to run the forward pass. + - POST_FW: The module has finished running the forward pass. + - PEAK_FW: The module has reached the peak memory usage during the forward pass. + - PRE_BW: The module is about to run the backward pass. + - PRE_FW_AC: The module is about to run the forward pass with activation checkpointing. + - POST_FW_AC: The module has finished running the forward pass with activation checkpointing. + - POST_BW: The module has finished running the backward pass. + - PEAK_BW: The module has reached the peak memory usage during the backward pass. + """ + + PRE_FW = "Pre-Forward" + POST_FW = "Post-Forward" + PEAK_FW = "Peak-Forward" + PRE_BW = "Pre-Backward" + PRE_FW_AC = "Pre-Forward-AC" + POST_FW_AC = "Post-Forward-AC" + POST_BW = "Post-Backward" + PEAK_BW = "Peak-Backward" + + +class _ModMemStats: + """ + A class to store the memory statistics of a module. + + Args: + mod_fqn (str): The fully qualified name of the module. + Attributes: + mod_fqn (str): The fully qualified name of the module. + parameter_mem (int): The memory usage of the parameters of the module. + buffer_mem (int): The memory usage of the buffers of the module. + input_mem (int): The memory usage of the inputs to the module. + output_mem (int): The memory usage of the outputs from the module. + snapshots (Dict[_ModState, Dict[core.device, Dict[str, int]]]): A dictionary of memory snapshots + of the module at different states defined by ``_ModState``. + Note: + The memory snapshot is stored as a dictionary - Dict[core.device, Dict[str, int]], where each key is a device, + and each value is another dictionary with keys as memory reference types defined by `_MemRefType` and + values as the memory consumed in bytes. + """ + + def __init__(self, mod_fqn: str): + self.mod_fqn = mod_fqn + self.parameter_mem: int + self.buffer_mem: int + self.input_mem: int + self.output_mem: int + self.local_peak: Dict[core.device, int] = {} + self.snapshots: Dict[_ModState, List[Dict[core.device, Dict[str, int]]]] = {} + + +class _WeakRefInfo: + """ + Manages memory statistics and device attributes for tensor storages. + """ + + def __init__( + self, size: int, element_size: int, device: core.device, reftype: _RefType + ) -> None: + """ + Initializes the ``_WeakRefInfo`` object with tensor storage properties. + + Args: + size (int): The number of elements in the tensor storage. + element_size (int): The size of each element in the tensor storage. + device (core.device): The device on which the tensor is allocated. + reftype (_RefType): The reference type of the tensor. + """ + self.size = size + self.element_size = element_size + self.reftype = reftype + self.device = device + self.mem_consumed = self._calculate_mem_consumed() + + def _calculate_mem_consumed(self) -> int: + """ + Calculates the memory consumed by the tensor storage, considering device-specific allocation rules. + + Returns: + int: The memory consumed in bytes. + """ + mem = self.size * self.element_size + if self.device.type == "cuda": + return math.ceil((mem) / _PYTORCH_MIN_ALLOCATE) * _PYTORCH_MIN_ALLOCATE + return mem + + def update_mem_consumed(self, st: core.UntypedStorage) -> int: + """ + Updates and returns the memory consumed if the storage size has changed. + + Args: + st (core.UntypedStorage): The tensor storage to check for size updates. + + Returns: + int: The updated memory consumed in bytes. + """ + if st.size() != self.size: + self.size = st.size() + self.mem_consumed = self._calculate_mem_consumed() + return self.mem_consumed + + @staticmethod + def get_untyped_storages(t: core.Tensor) -> Set[core.UntypedStorage]: + """ + Recursively extracts untyped storages from a tensor or its subclasses. + + Args: + t (core.Tensor): The tensor to extract storages from. + + Returns: + Set[core.UntypedStorage]: A set of untyped storages. + """ + unflattened_tensors = [t] + flattened_tensor_storages = set() + while len(unflattened_tensors) > 0: + obj = unflattened_tensors.pop() + if is_traceable_wrapper_subclass(obj): + attrs, _ = obj.__tensor_flatten__() # type: ignore[attr-defined] + unflattened_tensors.extend([getattr(obj, attr) for attr in attrs]) + else: + if not hasattr(obj, "untyped_storage"): + warnings.warn( + f"Expected a tensor or a traceable wrapper-subclass of tensor, but got {type(obj)}", + category=UserWarning, + stacklevel=2, + ) + else: + flattened_tensor_storages.add(obj.untyped_storage()) + return flattened_tensor_storages + + @classmethod + def create_winfo( + cls, + st: core.UntypedStorage, + device: core.device, + reftype: _RefType, + callback: Optional[Callable[[Self, weakref.ref], Any]] = None, + ) -> Tuple[Self, weakref.ref]: + """ + Creates a new ``_WeakRefInfo`` instance and a weak reference to a ``core.UntypedStorage`` object, + optionally attaching a callback to the weak reference. + + Args: + st (core.UntypedStorage): The storage object for which to create the weak reference info. + device (core.device): The device associated with the storage object. + reftype (_RefType): The type of reference, used to categorize the storage. + callback (Optional[Callable[[Self, weakref.ref]]]): A callback function that is called when + the storage object is about to be finalized (garbage collected). The callback function + should accept two arguments: the ``_WeakRefInfo`` instance and the weak reference to the storage. + Returns: + Tuple[Self, weakref.ref]: A tuple containing the newly created ``_WeakRefInfo`` instance and the + weak reference to the storage object. The weak reference may have an attached callback if provided. + """ + + winfo = cls(st.size(), st.element_size(), device, reftype) + w_st = weakref.ref(st, partial(callback, winfo) if callback else None) + return winfo, w_st + + +def _get_mem_divisor(units: str) -> int: + unit_dict = {"B": 1, "KiB": 2**10, "MiB": 2**20, "GiB": 2**30} + if units in unit_dict: + return unit_dict[units] + else: + raise ValueError( + f"Unsupported unit: {units}. Supported units are: {', '.join(unit_dict.keys())}" + ) + + +def _rounding_fn(value: int, divisor: int, precision: int) -> Union[float, int]: + return value if divisor == 1 else round(value / divisor, precision) + + +def _print_snapshot(snapshot: Dict[core.device, Dict[str, int]], units: str) -> None: + if len(snapshot) == 0: + print("No memory tracked.") + return + divisor = _get_mem_divisor(units) + for dev, dev_snap in snapshot.items(): + if _rounding_fn(dev_snap[_TOTAL_KEY], divisor, 2) <= 0: + continue + print( + f"Device: {dev}", + *( + f"\t{k}: {_rounding_fn(v, divisor, 2)} {units}" + for k, v in dev_snap.items() + ), + sep="\n", + ) + + +def _print_snapshot_tabular( + snapshot: Dict[core.device, Dict[str, int]], units: str +) -> None: + if len(snapshot) == 0: + print("No memory tracked.") + return + try: + from tabulate import tabulate + except ImportError as err: + raise ImportError( + "Please install tabulate to use the tabulate option." + ) from err + divisor = _get_mem_divisor(units) + table_data = [] + key_list = list(next(iter(snapshot.values())).keys()) + headers = ["Device"] + [f"{key}" for key in key_list] + + for dev, dev_snap in snapshot.items(): + if _rounding_fn(dev_snap[_TOTAL_KEY], divisor, 2) <= 0: + continue + row = [str(dev)] + row.extend(f"{_rounding_fn(v, divisor, 2)} {units}" for v in dev_snap.values()) + table_data.append(row) + print(tabulate(table_data, headers=headers, tablefmt="rst")) + + +def _print_state_snapshots( + snapshots: Dict[_State, List[Dict[core.device, Dict[str, int]]]], units: str +) -> None: + for state, snapshot_list in snapshots.items(): + print(f"{state}") + for i, snapshot in enumerate(snapshot_list): + print(f"# {i + 1}:") + _print_snapshot(snapshot, units) + print() + + +def _print_state_snapshots_tabular( + snapshots: Dict[_State, List[Dict[core.device, Dict[str, int]]]], units: str +) -> None: + try: + from tabulate import tabulate + except ImportError as err: + raise ImportError( + "Please install tabulate to use the tabulate option." + ) from err + + table_data = [] + last_state_call = None + divisor = _get_mem_divisor(units) + for state, snapshot_list in snapshots.items(): + for i, snapshot in enumerate(snapshot_list): + state_call = f"{state} # {i + 1}" + for dev, dev_snap in snapshot.items(): + if _rounding_fn(dev_snap[_TOTAL_KEY], divisor, 2) <= 0: + continue + row = { + "State & Call": ( + state_call if state_call != last_state_call else "" + ), + "Device": str(dev), + } + last_state_call = state_call + for k, v in dev_snap.items(): + row[f"{k}"] = f"{_rounding_fn(v, divisor, 2)} {units}" + table_data.append(row) + print(tabulate(table_data, headers="keys", tablefmt="rst")) + + +class _UpdateType(Enum): + # These are used for tracking updates to the continuouly maintained memory snapshot. + # ADD - When a new tensor storage is tracked + # DEL - When a tensor storage is about to be finalized (garbage collected). + # REF - When a tensor reference is updated, for instance, the gradients are marked as + # generic backward reference types until the grad_hook categorizes them as gradients. + # SIZE - When a tensor's storage is resized. + ADD = auto() + DEL = auto() + REF = auto() + SIZE = auto() + + +class MemTracker(TorchDispatchMode): + """ + A TorchDispatchMode to track, categorize and attribute the tensor memory created or accessed within its context. + + It categorizes the tracked tensors as parameters, buffers, activations, gradients, temporary memory and optimizer states + as defined by ``_MemRefType`` within its context. It captures memory `snapshots` for the modules, called within its context, + at various states defined by ``_ModState``. + + Attributes: + memory_tracking: A weakref key dictionary to store the memory statistics of each module. Each key + is a reference to a module, and each value is a ``_ModMemStats`` object that stores the memory + statistics of the module. + + Note: + The MemTracker should be used as a context manager. The modules, optimizers, and any other tensors created within + the context of MemTracker will be tracked by default. Any tensors or stateful objects such as modules, optimizers etc. + that need to be tracked but are created outside the MemTracker should be registered using the `track_external` method. + The `track_external` method should be called before the MemTracker is used. Any tensors created outside the ``MemTracker`` + and not supplied to the `track_external` method will not be tracked by the ``MemTracker``. + + Example usage: + + .. code-block:: python + + module = ... + optimizer = ... + inp = ... + mem_tracker = MemTracker() + mem_tracker.track_external(module, optimizer, inp) + with mem_tracker as mt: + loss = module(inp) + print("After Forward:") + mt.display_snapshot("current") + loss.backward() + optimizer.step() + optimizer.zero_grad() + mt.display_snapshot("peak") + mt.display_modulewise_snapshots(depth = 3, units = "MiB") + + Known Limitations: + - The ``MemTracker`` does not track memory for tensors that bypass the ``TorchDispatchMode`` ex. under ``no_dispatch``. + - Resizing tensor storages directly by using non-Tensor methods other than using ``core.Untyped_Storage.resize_`` + is not tracked. File a Github issue if you have use-cases for this. + - If the tensors are not traceable or wrappable subclasses of ``core.Tensor``, then the tracker does not know how to + track their storages. File a Github issue if you have use-cases for this. + - During AC in the backward pass there might be misattribution between activation and temp memory, but the peak memory + will be tracked accurately. This will be fixed in the next update by hooking intricately with ``core.uitls.checkpoint``. + """ + + def __init__(self) -> None: + self.memory_tracking = WeakIdKeyDictionary() + self._curr_mem_snap: Dict[core.device, Dict[str, int]] = {} + self._peak_mem: Dict[core.device, int] = {} + self._peak_mem_snap: Dict[core.device, Dict[str, int]] = {} + self._param_to_grad_hook_handles = WeakIdKeyDictionary() + self._optimizer_hook_handles: Optional[ + Tuple[RemovableHandle, RemovableHandle] + ] = None + # Dictionary to store the ``_WeakRefInfo`` instances corresponding to each tensor's storage. + self._WINFO = WeakIdKeyDictionary() + self._mod_tracker = ModTracker() + # This is a general memory tracker which can be used with any ``_RefType`` subclass + self._ref_class: Type[_RefType] = _MemRefType + # Flags to track if we are in the AC region or optimizer step region + self._in_opt: bool = False + self._in_ac: bool = False + # Weak references to the topmost AC module currently active + self._ac_mod: Optional[weakref.ref] = None + self._orig_resize = core.UntypedStorage.resize_ + + def _update_snap( + self, + u_type: _UpdateType, + winfo: _WeakRefInfo, + old_mem_consumed: Optional[int] = None, + old_reftype: Optional[_RefType] = None, + ) -> None: + # Initialize a flag to track if the total memory might drop to zero after updates. + maybe_zero = False + # Ensure the device entry exists in the current memory snapshot, initializing if necessary. + dev_snap = self._curr_mem_snap.setdefault( + winfo.device, dict.fromkeys(self._ref_class, 0) + ) + dev_snap.setdefault(_TOTAL_KEY, 0) + # Handle different types of updates based on the update type (`u_type`). + if u_type == _UpdateType.ADD: + # Increase the memory consumed for the specific reference type and update the total. + dev_snap[winfo.reftype] += winfo.mem_consumed + dev_snap[_TOTAL_KEY] += winfo.mem_consumed + elif u_type == _UpdateType.DEL: + # Decrease the memory consumed for the specific reference type and reduce the total. + dev_snap[winfo.reftype] -= winfo.mem_consumed + dev_snap[_TOTAL_KEY] -= winfo.mem_consumed + maybe_zero = True + elif u_type == _UpdateType.REF: + assert old_reftype is not None + # Adjust memory consumption between two reference types within the same device. + dev_snap[old_reftype] -= winfo.mem_consumed + dev_snap[winfo.reftype] += winfo.mem_consumed + elif u_type == _UpdateType.SIZE: + assert old_mem_consumed is not None + # Adjust the memory consumed for a reference type due to a change in size. + change = winfo.mem_consumed - old_mem_consumed + dev_snap[winfo.reftype] += change + dev_snap[_TOTAL_KEY] += change + maybe_zero = True + else: + raise ValueError(f"Invalid update type: {u_type}") + # Check if the total memory for the device has dropped to zero. + if maybe_zero: + if self._curr_mem_snap[winfo.device][_TOTAL_KEY] == 0: + # Remove the device entry from the memory snapshot if the total memory is zero. + del self._curr_mem_snap[winfo.device] + + def _update_and_maybe_create_winfos( + self, + t: core.Tensor, + reftype: _RefType, + update_existing: bool = False, + ) -> Set[_WeakRefInfo]: + sts = _WeakRefInfo.get_untyped_storages(t) + winfos = set() + for st in sts: + # Attempt to retrieve existing ``_WeakRefInfo`` and its weak reference from the tracking dictionary. + winfo, _ = self._WINFO.get(st, (None, None)) + if winfo is not None: + # If ``_WeakRefInfo`` exists, check if the reference type needs to be updated. + old_reftype = winfo.reftype + if old_reftype != reftype: + # Update the reference type and apply changes via ``_update_snap``. + winfo.reftype = reftype + self._update_snap(_UpdateType.REF, winfo, old_reftype=old_reftype) + winfos.add(winfo) + elif update_existing: + # If no existing ``_WeakRefInfo`` is found and update_existing is True, raise an error. + raise KeyError("No existing winfo found") + else: + # If no existing _WeakRefInfo is found and update_existing is False, create a new ``_WeakRefInfo``. + winfo, w_st = _WeakRefInfo.create_winfo( + st, t.device, reftype, self._delete_callback + ) + # Store the new ``_WeakRefInfo`` and its weak reference in the tracking dictionary. + self._WINFO[st] = (winfo, w_st) + # Update the snapshot for the newly added ``_WeakRefInfo``. + if winfo.mem_consumed > 0: + self._update_snap(_UpdateType.ADD, winfo) + winfos.add(winfo) + return winfos + + def _delete_callback(self, winfo: _WeakRefInfo, w_st: weakref.ref) -> None: + # Callback to be called when the storage object corresponding to the ``_WeakRefInfo`` + # instance is about to be finalized. + if winfo.mem_consumed > 0: + self._update_snap(_UpdateType.DEL, winfo) + + def _track_resize(self) -> None: + # Need to monkey-patch this because ``core.UntypedStorage.resize_`` is not captured + # by ``TorchDispatchMode``. + @wraps(self._orig_resize) + def resize_(st: core.UntypedStorage, size: int) -> None: + self._orig_resize(st, size) + winfo, _ = self._WINFO.get(st, (None, None)) + if winfo is not None and winfo.size != st.size(): + old_mem_consumed = winfo.mem_consumed + winfo.update_mem_consumed(st) + self._update_snap( + _UpdateType.SIZE, winfo, old_mem_consumed=old_mem_consumed + ) + + core.UntypedStorage.resize_ = resize_ # type: ignore[method-assign, assignment] + + def _restore_resize(self) -> None: + core.UntypedStorage.resize_ = self._orig_resize # type: ignore[method-assign] + + def _update_peak_stats(self, peak_state: _State) -> None: + # We first capture the current memory snapshot of the current tracker state then, + # We step through each of the modules we have tracked so far in ``memory_tracking`` + # and check if it is currently active by querying ``_mod_tracker.parents`` + # If it is active, we update the per device peak memory usage for the module + # corresponding to the ``_State`` which can be ``PEAK_FW`` or ``PEAK_BW``. + curr_snap = self._curr_mem_snap + + for mod_stats in self.memory_tracking.values(): + if mod_stats.mod_fqn in self._mod_tracker.parents: + if peak_state in mod_stats.snapshots: + for dev, dev_snap in curr_snap.items(): + if mod_stats.local_peak.get(dev, 0) < dev_snap[_TOTAL_KEY]: + mod_stats.local_peak[dev] = dev_snap[_TOTAL_KEY] + mod_stats.snapshots[peak_state][-1][dev] = deepcopy( + dev_snap + ) + + for dev, dev_snap in curr_snap.items(): + if self._peak_mem.get(dev, 0) < dev_snap[_TOTAL_KEY]: + self._peak_mem[dev] = dev_snap[_TOTAL_KEY] + self._peak_mem_snap[dev] = deepcopy(dev_snap) + + def _track(self, reftype: _RefType, t: core.Tensor) -> None: + # Get the storages of the tensor and check if we have already tracked them. + # If yes, then check if the storage size has changed and update the current snapshot. + # Else create a new ``_WeakRefInfo`` instance and add it to the dictionary. + sts = _WeakRefInfo.get_untyped_storages(t) + for st in sts: + winfo, _ = self._WINFO.get(st, (None, None)) + if winfo is not None: + if winfo.size != st.size(): + old_mem_consumed = winfo.mem_consumed + winfo.update_mem_consumed(st) + self._update_snap( + _UpdateType.SIZE, winfo, old_mem_consumed=old_mem_consumed + ) + return + else: + winfo, w_st = _WeakRefInfo.create_winfo( + st, t.device, reftype, self._delete_callback + ) + self._WINFO[st] = (winfo, w_st) + # Update the current snapshot for the newly added ``_WeakRefInfo``. + if winfo.mem_consumed > 0: + self._update_snap(_UpdateType.ADD, winfo) + + def get_tracker_snapshot( + self, type: str = "current" + ) -> Dict[core.device, Dict[str, int]]: + """ + Capture a snapshot of the memory usage breakdown per device, based on the specified type. + + Args: + type (str): The type of snapshot to capture. Can be "current" for the current memory usage or "peak" for the + peak memory usage. Defaults to "current". + Returns: + Dict[core.device, Dict[str, int]]: A dictionary where each key is a core.device, and each value is another + dictionary. This inner dictionary has keys representing memory reference + types as defined in ``_MemRefType`` and values representing the amount of + memory consumed in bytes. + Raises: + ValueError: If an invalid type is specified. + """ + if type == "current": + return deepcopy(self._curr_mem_snap) + elif type == "peak": + return deepcopy(self._peak_mem_snap) + else: + raise ValueError(f"Invalid type {type}") + + def _track_module_params_and_buffers( + self, module: nn.Module, install_grad_hooks: bool = True + ) -> Tuple[int, int]: + # Track the parameters and buffers of the module if not already tracked. + # If the parameters have gradients, track the gradients as well. + # If install_grad_hooks is True, install a gradient hook on the parameters + # to track the gradients, if it has not already been installed. + # Return the total memory consumed by the parameters and buffers. + def _grad_hook(grad: core.Tensor) -> None: + self._update_and_maybe_create_winfos( + grad, + _MemRefType.GRAD, + ) + + param_memory = 0 + for param in module.parameters(): + winfos = self._update_and_maybe_create_winfos( + param, + _MemRefType.PARAM, + ) + param_memory += sum(winfo.mem_consumed for winfo in winfos) + if param.grad is not None: + self._update_and_maybe_create_winfos( + param.grad, + _MemRefType.GRAD, + ) + if ( + self._param_to_grad_hook_handles.get(param, None) is None + and install_grad_hooks + ): + grad_hook_handle = param.register_hook(_grad_hook) + post_acc_grad_hook_handle = param.register_post_accumulate_grad_hook( + lambda p: (_grad_hook(p.grad)) + ) + self._param_to_grad_hook_handles[param] = ( + grad_hook_handle, + post_acc_grad_hook_handle, + ) + buffer_memory = 0 + for buffer in module.buffers(): + winfos = self._update_and_maybe_create_winfos( + buffer, + _MemRefType.BUFFER, + ) + buffer_memory += sum(winfo.mem_consumed for winfo in winfos) + return (param_memory, buffer_memory) + + def _track_inputs_or_outputs(self, args: Any) -> int: + # Calculate the memory consumed by the inputs or outputs of the module. + input_or_output_memory = 0 + + def add_inps_or_outs(t: core.Tensor) -> None: + nonlocal input_or_output_memory + sts = _WeakRefInfo.get_untyped_storages(t) + for st in sts: + winfo, _ = self._WINFO.get(st, (None, None)) + if winfo is not None: + input_or_output_memory += winfo.mem_consumed + + tree_map_only(core.Tensor, add_inps_or_outs, args) + return input_or_output_memory + + def _pre_fw_hook(self, module: nn.Module, inputs: Any) -> None: + # This is installed as a pre-fwd user hook with ``ModTracker.`` Based on the following cases we + # set the state and capture the memory snapshot for the module. + # Case 1: If the module is not in the ``memory_tracking`` dictionary, we track the parameters, buffers, + # input and output memory of the module. Create a new ``_ModMemStats`` instance for the module + # and add it to the ``memory_tracking`` dictionary. + # Case 2: If the module is already in the ``memory_tracking`` dictionary and we are in backward, this means + # we are in the AC region. We check if this is the top most module in the AC region. If it is, + # we store a weak reference and set the flag ``_in_ac`` to True. + # Case 3: If the module is already in the ``memory_tracking`` dictionary and we are in forward, this means + # this module is called for the second time. If it is a root module, that means we are in the next + # iteration and we error out. If it is not a root module, that means it's a submodule that is being + # used multiple times in the same iteration, which we allow and track. + # For Case 1 and 3, we also initialiaze the ``local_peak`` and ``PEAK_FW`` snapshot for the module. + mod_name = self._mod_tracker.get_known_fqn(module) + assert mod_name is not None + if module not in self.memory_tracking: + mod_stats = _ModMemStats(mod_name) + param_mem, buffer_mem = self._track_module_params_and_buffers( + module, install_grad_hooks=True + ) + input_mem = self._track_inputs_or_outputs(inputs) + mod_stats.parameter_mem = param_mem + mod_stats.buffer_mem = buffer_mem + mod_stats.input_mem = input_mem + self.memory_tracking[module] = mod_stats + state = _ModState.PRE_FW + + elif self._mod_tracker.is_bw: + mod_stats = self.memory_tracking[module] + state = _ModState.PRE_FW_AC + if self._ac_mod is None: + self._ac_mod = weakref.ref(module) + self._in_ac = True + else: + parents = set(self._mod_tracker.parents) - {mod_name} + if len(parents) == 1 and "Global" in parents: + raise NotImplementedError( + "MemTracker does not support memory tracking for multiple iterative calls." + " Either use ``reset_mod_stats`` to clear module memory stats for the previous iteration" + " or file a github issue if you need this feature." + ) + mod_stats = self.memory_tracking[module] + state = _ModState.PRE_FW + input_mem = self._track_inputs_or_outputs(inputs) + mod_stats.input_mem = input_mem + + mem_snapshot = self.get_tracker_snapshot() + if state == _ModState.PRE_FW: + mod_stats.local_peak = { + dev: dev_snap[_TOTAL_KEY] for dev, dev_snap in mem_snapshot.items() + } + mod_stats.snapshots.setdefault(_ModState.PEAK_FW, []).append(mem_snapshot) + mod_stats.snapshots.setdefault(state, []).append(deepcopy(mem_snapshot)) + + def _post_fw_hook(self, module: nn.Module, inputs: Any, outputs: Any) -> None: + # This is installed as a post-fwd user hook with ``ModTracker``. Based on the following cases we + # set the state and capture the memory snapshot for the module. + # Case 1: This is called in backward, which means we are in the AC region. If this is the top most module + # in the AC region, we set the flag ``_in_ac`` to False. + # Case 2: This is called in forward so we calculate the output memory + # of the module and update its mod_stats. + mod_stats = self.memory_tracking[module] + if self._mod_tracker.is_bw: + state = _ModState.POST_FW_AC + if self._ac_mod is not None and self._ac_mod() is module: + self._ac_mod = None + self._in_ac = False + else: + state = _ModState.POST_FW + output_mem = self._track_inputs_or_outputs(outputs) + mod_stats.output_mem = output_mem + mod_stats.snapshots.setdefault(state, []).append(self.get_tracker_snapshot()) + + def _pre_bw_hook(self, module: nn.Module, args: Any) -> None: + # This is installed as a pre-bwd user hook with ``ModTracker``. We set the state and capture the + # snapshot for the module. We also initialize the ``local_peak`` and ``PEAK_BW`` snapshot for it. + # If the module is None, we skip the hook. + # This can happen since this installed inside a multi-grad hook on the module's output tensors + # and the module itself may not be alive during backward. + if module is None: + warnings.warn("Module is None. Skipping PRE_BW hook.", stacklevel=2) + return + mod_stats = self.memory_tracking[module] + mem_snapshot = self.get_tracker_snapshot() + mod_stats.local_peak = { + dev: dev_snap[_TOTAL_KEY] for dev, dev_snap in mem_snapshot.items() + } + mod_stats.snapshots.setdefault(_ModState.PEAK_BW, []).append(mem_snapshot) + mod_stats.snapshots.setdefault(_ModState.PRE_BW, []).append( + deepcopy(mem_snapshot) + ) + + def _post_bw_hook(self, module: nn.Module, args: Any) -> None: + # This is installed as a post-bwd user hook with ``ModTracker``. We set the state and capture the + # snapshot for the module if it is not None. + # This can happen since this installed inside a multi-grad hook on the module's input tensors + # and the module itself may not be alive during backward. + if module is None: + warnings.warn("Module is None. Skipping POST_BW hook.", stacklevel=2) + return + mod_stats = self.memory_tracking[module] + mod_stats.snapshots.setdefault(_ModState.POST_BW, []).append( + self.get_tracker_snapshot() + ) + + def _track_optimizer_states( + self, reftype: _RefType, optimizer: optim.Optimizer + ) -> None: + for states in optimizer.state.values(): + for val in states.values(): + if isinstance(val, core.Tensor): + self._update_and_maybe_create_winfos( + val, + reftype, + ) + + def _register_global_optimizer_hook(self) -> None: + # Register a hook on the optimizer step to track the optimizer states. + # The pre-hook is to set the flag ``_in_opt`` to True. The post-hook unsets the flag, + # and also tracks any optimizer states that are created during the optimizer step. + def _opt_step_pre_hook( + optimizer: optim.Optimizer, args: Any, kwargs: Any + ) -> None: + self._in_opt = True + + def _opt_step_post_hook( + optimizer: optim.Optimizer, args: Any, kwargs: Any + ) -> None: + self._track_optimizer_states(_MemRefType.OPT, optimizer) + self._in_opt = False + + self._optimizer_hook_handles = ( + register_optimizer_step_pre_hook(_opt_step_pre_hook), + register_optimizer_step_post_hook(_opt_step_post_hook), + ) + + def _deregister_param_and_optimizer_hooks(self) -> None: + for ( + grad_hook_handle, + post_acc_grad_hook_handle, + ) in self._param_to_grad_hook_handles.values(): + grad_hook_handle.remove() + post_acc_grad_hook_handle.remove() + self._param_to_grad_hook_handles.clear() + + if self._optimizer_hook_handles is not None: + for handle in self._optimizer_hook_handles: + handle.remove() + self._optimizer_hook_handles = None + + def track_external( + self, *external: Union[nn.Module, optim.Optimizer, core.Tensor] + ) -> None: + """ + Track tensors and stateful objects like modules, optimizers etc. that are created outside the MemTracker. + + This method should be called before the ``MemTracker`` is used. Any tensors that are not module parameters, buffers, + gradients activations, or optimizer states will be categorized as ``Other``. If you want them categorized with a + custom name, please file a GitHub issue. Any tensors created outside the MemTracker and not supplied to this + method will not be be tracked by ``MemTracker``. + + Args: + *external (Union[nn.Module, optim.Optimizer, core.Tensor]): The external modules, optimizers, and + tensors to be tracked. + """ + flat_external, _ = tree_flatten(external) + for obj in flat_external: + if isinstance(obj, core.Tensor): + self._update_and_maybe_create_winfos( + obj, + _MemRefType.OTH, + ) + elif isinstance(obj, core.nn.Module): + self._track_module_params_and_buffers(obj, install_grad_hooks=False) + elif isinstance(obj, optim.Optimizer): + self._track_optimizer_states(_MemRefType.OPT, obj) + else: + raise TypeError( + f"Object of type {type(obj)} is not supported for tracking. " + f"Only stateful objects like modules, optimizers, and tensors are supported." + ) + + def display_snapshot( + self, type: str = "current", units: str = "B", tabulate: bool = False + ) -> None: + """ + Display the memory usage breakdown snapshot of the tracker based on the specified type and units. + + Keyword args: + type (str): The type of snapshot to display. Can be "current" for the current memory usage or "peak" for the + peak memory usage. Defaults to "current". + units (str): The units to use for displaying memory usage. Defaults to "B". Supports ["B", "KiB", "MiB", "GiB"]. + tabulate (bool): Whether to display the snapshot in a tabular format. Defaults to False. + """ + snapshot = self.get_tracker_snapshot(type) + if tabulate: + _print_snapshot_tabular(snapshot, units) + else: + _print_snapshot(snapshot, units) + + def display_modulewise_snapshots( + self, depth: int = 2, units: str = "B", tabulate: bool = False + ) -> None: + """ + Print per device memory breakdown snapshot for each module called within MemTracker. + + Snapshots are displayed for the states defined by ``_ModState``. + The module hierarchy is displayed up to the specified depth. + + Keyword Args: + depth (int, optional): The depth of the module hierarchy to display. Defaults to 2. + units (str, optional): The units to use for memory tracking. Defaults to "B". Supports ["B", "KiB", "MiB", "GiB"]. + tabulate (bool, optional): Whether to display the snapshot in a tabular format. Defaults to False. + """ + + def natural_sort_key(s: str) -> List[Union[int, str]]: + return [ + int(text) if text.isdigit() else text.lower() + for text in re.split("([0-9]+)", s) + ] + + for mod_stats in sorted( + self.memory_tracking.values(), + key=lambda m_stats: natural_sort_key(m_stats.mod_fqn), + ): + mod_fqn = mod_stats.mod_fqn + mod_depth = mod_fqn.count(".") + 1 + if mod_depth > depth: + continue + print(f"Module: {mod_fqn}") + if tabulate: + _print_state_snapshots_tabular(mod_stats.snapshots, units) + else: + _print_state_snapshots(mod_stats.snapshots, units) + + def reset_mod_stats(self) -> None: + """ + Reset all the module memory stats. Clears ``memory_tracking`` dictionary. + """ + self.memory_tracking.clear() + + def __enter__(self) -> "MemTracker": + self._register_global_optimizer_hook() + self._mod_tracker.register_user_hooks( + self._pre_fw_hook, + self._post_fw_hook, + self._pre_bw_hook, + self._post_bw_hook, + ) + self._track_resize() + self._peak_mem_snap = self.get_tracker_snapshot() + self._peak_mem = { + dev: dev_snap[_TOTAL_KEY] for dev, dev_snap in self._peak_mem_snap.items() + } + self._mod_tracker.__enter__() + super().__enter__() + return self + + def __exit__(self, *args: Any) -> None: + self._deregister_param_and_optimizer_hooks() + self._mod_tracker.clear_user_hooks() + self._restore_resize() + super().__exit__(*args) + self._mod_tracker.__exit__(*args) + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): # type: ignore[no-untyped-def] + res = func(*args, **kwargs or {}) + # If we are tracking an optimizer state, we use the optimizer reference type. + # If we are in backward region and not in AC region, we use the backward reference type. + # Else we use the forward reference type. + if self._in_opt: + reftype = _MemRefType.OPT + elif self._mod_tracker.is_bw and not self._in_ac: + reftype = _MemRefType.TEMP + else: + reftype = _MemRefType.ACT + tree_map_only(core.Tensor, partial(self._track, reftype), res) + peak_state = _ModState.PEAK_BW if self._mod_tracker.is_bw else _ModState.PEAK_FW + self._update_peak_stats(peak_state) + return res diff --git a/mindnlp/core/distributed/_tools/memory_tracker.py b/mindnlp/core/distributed/_tools/memory_tracker.py new file mode 100644 index 000000000..c35952aab --- /dev/null +++ b/mindnlp/core/distributed/_tools/memory_tracker.py @@ -0,0 +1,295 @@ +# mypy: allow-untyped-defs +import operator +import pickle +from collections import defaultdict +from itertools import chain +from typing import Any, Callable, Dict, List, no_type_check, Sequence, TYPE_CHECKING + +from mindnlp import core +from mindnlp import core.nn as nn +from core.utils._python_dispatch import TorchDispatchMode + + +if TYPE_CHECKING: + from core.utils.hooks import RemovableHandle + + +BYTES_PER_MB = 1024 * 1024.0 + + +class MemoryProfileDispatchMode(TorchDispatchMode): + """Run in ``TorchDispatchMode`` to get memory stats at operator level.""" + + def __init__(self, memory_tracker) -> None: + self.memory_tracker = memory_tracker + + def __torch_dispatch__(self, func, types, args=..., kwargs=None): + rs = func(*args, **kwargs) + if func == core.ops.aten.detach.default: + return rs + func_name: str = ( + self.memory_tracker._cur_module_name + + "." + + func.__name__ + + "_" + + str(self.memory_tracker._operator_names[func.__name__]) + ) + self.memory_tracker._operator_names[func.__name__] = ( + self.memory_tracker._operator_names[func.__name__] + 1 + ) + self.memory_tracker._record_memory_stats(func_name) + + return rs + + +class MemoryTracker: + """ + Collect and plot the memory stats at operator level. + + Includes ``memories_allocated``, ``memories_active`` and ``memories_reserved``. + It also prints a summary for the top 20 operators that generate the most memories. + + Example usage: + + >>> # xdoctest: +SKIP(failing) + >>> net.cuda() + >>> input = input.cuda() + + >>> mem_tracker = MemoryTracker() + >>> mem_tracker.start_monitor(net) + + >>> net.zero_grad(True) + >>> loss = net(input) + >>> if isinstance(loss, dict): + >>> loss = loss['out'] + >>> loss.sum().backward() + >>> net.zero_grad(set_to_none=True) + + >>> mem_tracker.stop() + >>> mem_tracker.summary() + >>> mem_tracker.show_traces() + """ + + def __init__(self) -> None: + core._C._log_api_usage_once("core.distributed.memory_tracker") + self._hooks: List[RemovableHandle] = [] + self._operator_names: Dict[str, int] = defaultdict(int) + self.memories_allocated: Dict[int, Dict[str, float]] = defaultdict() + self.memories_active: Dict[int, Dict[str, float]] = defaultdict() + self.memories_reserved: Dict[int, Dict[str, float]] = defaultdict() + self._markers: Dict[str, int] = defaultdict(int) + self._cur_module_name: str = "" + self._op_index: int = 0 + self._num_cuda_retries: int = 0 + + @no_type_check + def start_monitor(self, root_module: nn.Module) -> None: + """ + Register module hooks and entering ``MemoryProfileDispatchMode``. + + This enables operator level memory stats can be tracked during module runtime. + """ + self._clear_state() + root_module.__setattr__("_memory_tracker_is_root", True) + for name, m in root_module.named_modules(): + if m is not root_module: + m.__setattr__("_memory_tracker_is_root", False) + # fused_proxy_group does not support hooks + if ".fused_proxy_grouped_embedding_bag" in name: + continue + # hook ordering with other hooks added by users is not managed, so + # the memory stats tracked here may not completely accurate. + h1 = m.register_forward_pre_hook(self._create_pre_forward_hook(name)) + h2 = m.register_forward_hook(self._create_post_forward_hook(name)) + # it does not work well with jagged tensor somehow, the root cause is not + # clear and remove it for now as it does not really capture important info. + # h3 = m.register_backward_hook(self._create_backward_hook(name)) + self._hooks.extend([h1, h2]) + core.cuda.empty_cache() + assert getattr(self, "profile_mode", None) is None + self.profile_mode = MemoryProfileDispatchMode(self) + self.profile_mode.__enter__() + + @no_type_check + def stop(self) -> None: + """ + Remove module hooks and exit ``MemoryProfileDispatchMode`` to stop tracking memory stats at operator level. + + Get some aggregated stats when the memory_tracker() is enabled, like cuda ``num_alloc_retries``. + """ + self._num_cuda_retries = core.cuda.memory_stats().get("num_alloc_retries", 0) + + for h in self._hooks: + h.remove() + self._hooks.clear() + assert getattr(self, "profile_mode", None) is not None + self.profile_mode.__exit__(None, None, None) + self.profile_mode = None + + @no_type_check + def summary(self, top: int = 20) -> None: + """ + Print out the top operators that generate the most memories. + + The number of the top operators can be configured. + """ + op_diff: Dict[str, float] = defaultdict(float) + op_name, previous_allocated_memory = self.memories_allocated[0] + for i in range(1, self._op_index): + op_name, current_allocated_memory = self.memories_allocated[i] + op_diff[op_name] = current_allocated_memory - previous_allocated_memory + previous_allocated_memory = current_allocated_memory + + print("------------------------------------------------") + print(f"The number of cuda retries are: {self._num_cuda_retries}") + print(f"Top {top} ops that generates memory are:") + for k, v in sorted(op_diff.items(), key=operator.itemgetter(1), reverse=True)[ + :top + ]: + print(f"{k}: {v}MB") + print("------------------------------------------------") + + @no_type_check + def show_traces(self, path: str = "") -> None: + import matplotlib.pyplot as plt + + def _plot_figure(x, y_values, labels): + min_val = min(list(chain(*y_values))) * 0.999 + max_val = max(list(chain(*y_values))) * 1.001 + plt.figure() + for y, label in zip(y_values, labels): + plt.plot(x, y, label=label) + plt.xlabel("# Operator Calls") + plt.ylabel("Memory (MB)") + plt.legend() + for marker_name, marker in self._markers.items(): + if marker_name == "fw_bw_boundary": + plt.plot( + [marker, marker], + [min_val, max_val], + "r", + lw=2, + label=marker_name, + ) + else: + plt.plot( + [marker, marker], + [min_val, max_val], + "k-", + lw=2, + label=marker_name, + ) + + if path != "": + self.load(path) + + y_1 = [gb for (name, gb) in self.memories_allocated.values()] + y_2 = [gb for (name, gb) in self.memories_active.values()] + y_3 = [gb for (name, gb) in self.memories_reserved.values()] + x = list(range(len(y_1))) + # Split figures when there is big difference between + # "reserved_memory" and "allocated_memory" or "active_memory". + _plot_figure( + x, + [list(y_1), list(y_2), list(y_3)], + ["allocated_memory", "active_memory", "reserved_memory"], + ) + _plot_figure(x, [list(y_1)], ["allocated_memory"]) + _plot_figure(x, [list(y_2)], ["active_memory"]) + _plot_figure(x, [list(y_3)], ["reserved_memory"]) + + def save_stats(self, path: str) -> None: + """Save the stats using pickle during runtime if users want to plot the traces in other places like notebook.""" + stats = { + "memories_allocated": self.memories_allocated, + "memories_active": self.memories_active, + "memories_reserved": self.memories_reserved, + "markers": self._markers, + "num_alloc_retries": self._num_cuda_retries, + } + + with open(path, "wb") as f: + pickle.dump(stats, f, pickle.HIGHEST_PROTOCOL) + + def load(self, path: str) -> None: + """Load the pickled memory stats to plot the traces or print the summary.""" + with open(path, "rb") as f: + stats = pickle.load(f) + + self.memories_allocated = stats["memories_allocated"] + self.memories_active = stats["memories_active"] + self.memories_reserved = stats["memories_reserved"] + self._markers = stats["markers"] + self._num_cuda_retries = stats["num_alloc_retries"] + + def _create_pre_forward_hook(self, name: str) -> Callable: + """Prefix operator name with current module and 'forward', and insert 'fw_start' marker at forward pass start.""" + + def _pre_forward_hook(module: nn.Module, inputs: Any) -> None: + self._cur_module_name = f"{name}.forward" + if ( + hasattr(module, "_memory_tracker_is_root") + and module._memory_tracker_is_root + ): + self._add_marker("fw_start") + + return _pre_forward_hook + + def _create_post_forward_hook(self, name: str) -> Callable: + """Insert the marker 'fw_bw_boundary' at the boundary of forward and backward pass.""" + + def _post_forward_hook( + module: nn.Module, + inputs: Sequence[core.Tensor], + outputs: Sequence[core.Tensor], + ) -> None: + if ( + hasattr(module, "_memory_tracker_is_root") + and module._memory_tracker_is_root + ): + self._add_marker("fw_bw_boundary") + + return _post_forward_hook + + def _create_backward_hook(self, name: str) -> Callable: + """Insert the current module name with backward prefix for the operator name.""" + + def _backward_hook( + module: nn.Module, grad_input: core.Tensor, grad_output: core.Tensor + ) -> None: + self._cur_module_name = f"{name}.backward" + + return _backward_hook + + @no_type_check + def _record_memory_stats(self, fn_name: str) -> None: + """ + Record current memory allocated, current memory active and current memory reserved. + + The memory stats dict is indexed with ``self._op_index``. + """ + memory_allocated: float = core.cuda.memory_allocated() / BYTES_PER_MB + memory_reserved: float = core.cuda.memory_reserved() / BYTES_PER_MB + memory_active: float = ( + core.cuda.memory_stats().get("active_bytes.all.current", 0) / BYTES_PER_MB + ) + self.memories_allocated[self._op_index] = (fn_name, memory_allocated) + self.memories_reserved[self._op_index] = (fn_name, memory_reserved) + self.memories_active[self._op_index] = (fn_name, memory_active) + self._op_index += 1 + + def _add_marker(self, marker_name: str) -> None: + """Set the marker's x-axis value.""" + marker_val = len(self.memories_allocated.values()) + self._markers[marker_name] = marker_val + + def _clear_state(self) -> None: + """Clear states when start_monitor() is called.""" + self._operator_names.clear() + self.memories_allocated.clear() + self.memories_active.clear() + self.memories_reserved.clear() + self._markers.clear() + self._cur_module_name = "" + self._op_index = 0 + self._num_cuda_retries = 0 diff --git a/mindnlp/core/distributed/_tools/mod_tracker.py b/mindnlp/core/distributed/_tools/mod_tracker.py new file mode 100644 index 000000000..3d384822d --- /dev/null +++ b/mindnlp/core/distributed/_tools/mod_tracker.py @@ -0,0 +1,238 @@ +# mypy: allow-untyped-defs +import warnings +import weakref +from typing import Callable, Optional, Set + +from mindnlp import core +from core.autograd.graph import register_multi_grad_hook +from core.nn.modules.module import ( + register_module_forward_hook, + register_module_forward_pre_hook, +) +from core.utils._pytree import tree_flatten + + +__all__ = ["ModTracker"] + + +class ModTracker: + """ + ``ModTracker`` is a context manager that tracks the nn.Module hierarchy during execution + so that other system can query which Module is currently being executed (or its backward is being + executed). + + You can access the ``parents`` attribute on this context manager to get the set of all the + Modules currently being executed via their fqn (fully qualified name, also used as the key within + the state_dict). + You can access the ``is_bw`` attribute to know if you are currently running in backward or not. + + Note that ``parents`` is never empty and always contains the "Global" key. The ``is_bw`` flag + will remain ``True`` after the forward until another Module is executed. If you need it to be + more accurate, please submit an issue requesting this. Adding a map from fqn to the module instance + is possible but not done yet, please submit an issue requesting this if you need it. + + Example usage + + .. code-block:: python + + mod = core.nn.Linear(2, 2) + + with ModTracker() as tracker: + # Access anything during the forward pass + def my_linear(m1, m2, bias): + print(f"Current modules: {tracker.parents}") + return core.mm(m1, m2.t()) + bias + core.nn.functional.linear = my_linear + + mod(core.rand(2, 2)) + + """ + + parents: Set[str] + """ + A Set containing the fqn for each module currently running their forward + """ + + def __init__(self): + self.parents = {"Global"} + self._active_module_cnt = {} + self._known_modules: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() + self._seen_modules: weakref.WeakSet = weakref.WeakSet() + self._has_callback = False + self._user_pre_fw_hook = None + self._user_post_fw_hook = None + self._user_pre_bw_hook = None + self._user_post_bw_hook = None + + def _maybe_set_engine_callback(self): + # This assumes no concurrent calls to backward + if self._has_callback: + return + + def callback(): + self.parents = {"Global"} + self._has_callback = False + + core.autograd.Variable._execution_engine.queue_callback(callback) + self._has_callback = True + + @property + def is_bw(self): + """ + A boolean marking if this is currently running during the backward pass or not + """ + return core._C._current_graph_task_id() != -1 + + def get_known_fqn(self, mod): + """ + Return the fqn for the given module if it is known to the ``ModTracker``, otherwise ``None``. + """ + return self._known_modules.get(mod, None) + + def register_user_hooks( + self, + pre_fw_hook: Optional[Callable] = None, + post_fw_hook: Optional[Callable] = None, + pre_bw_hook: Optional[Callable] = None, + post_bw_hook: Optional[Callable] = None, + ): + """ + Registers user-specified hooks to be called before/after the forward/backward pass for each + module tracked by the ``ModTracker``. One or more can be ``None``. + Args: + pre_fw_hook (Callable, optional): A hook to be called before the forward pass for the + module. It should have the following signature: + pre_fw_hook (module, input) -> None + post_fw_hook (Callable, optional): A hook to be called after the forward pass for the + module. It should have the following signature: + post_fw_hook (module, input, output) -> None + pre_bw_hook (Callable, optional): A multi-grad hook to be called on all the outputs of + the module that require gradients. It should have the following signature: + pre_bw_hook (module, grad_output) -> None + post_bw_hook (Callable, optional): A multi-grad hook to be called on all the inputs of + the module that require gradients. It should have the following signature: + post_bw_hook (module, grad_input) -> None + Raises: + AssertionError: If a new hook is provided when one is already registered. + Note: + If the module is not alive during the backward pass, the pre_bw_hook and post_bw_hook will + will receive None as the module argument. + The module fqn will be present in the ``parents`` attribute when each of the hooks is called. + Hooks are intended to be used as markers only not to modify the inputs/outputs. + """ + + def set_hook(hook, user_hook, hook_name): + if hook is not None and user_hook is not None: + raise AssertionError( + f"Only one {hook_name} can be registered at a time" + f" Clear the existing hook by calling ``clear_user_hooks`` before registering a new one" + ) + return hook + + self._user_pre_fw_hook = set_hook( + pre_fw_hook, self._user_pre_fw_hook, "pre_fw_hook" + ) + self._user_post_fw_hook = set_hook( + post_fw_hook, self._user_post_fw_hook, "post_fw_hook" + ) + self._user_pre_bw_hook = set_hook( + pre_bw_hook, self._user_pre_bw_hook, "pre_bw_hook" + ) + self._user_post_bw_hook = set_hook( + post_bw_hook, self._user_post_bw_hook, "post_bw_hook" + ) + + def clear_user_hooks(self): + """ + Clears the user specified hooks registered with ``register_user_hooks`` + """ + self._user_pre_fw_hook = None + self._user_post_fw_hook = None + self._user_pre_bw_hook = None + self._user_post_bw_hook = None + + def _get_mod_name(self, mod): + if mod not in self._known_modules: + self._known_modules[mod] = type(mod).__name__ + mod_name = self._known_modules[mod] + if mod not in self._seen_modules: + for name, submod in mod.named_children(): + self._known_modules[submod] = f"{mod_name}.{name}" + self._get_mod_name(submod) + self._seen_modules.add(mod) + return mod_name + + def _get_append_fn(self, w_mod, name, is_bw): + def fn(*args): + if is_bw: + self._maybe_set_engine_callback() + if name in self.parents and not self.is_bw: + + def custom_formatwarning(msg, category, filename, lineno, line=None): + return f"{filename}:{lineno}: {category.__name__}: {msg} \n" + + warnings.formatwarning = custom_formatwarning + warnings.warn( + "The module hierarchy tracking maybe be messed up." + " Please file a bug to PyTorch, if it is the case." + ) + if name not in self.parents: + self._active_module_cnt[name] = 1 + self.parents.add(name) + else: + self._active_module_cnt[name] += 1 + + if self._user_pre_bw_hook is not None and is_bw: + self._user_pre_bw_hook(w_mod(), args) + + return fn + + def _get_pop_fn(self, w_mod, name, is_bw): + def fn(*args): + if self._user_post_bw_hook is not None and is_bw: + self._user_post_bw_hook(w_mod(), args) + if name in self.parents: + self._active_module_cnt[name] -= 1 + if self._active_module_cnt[name] == 0: + self.parents.remove(name) + elif not self.is_bw: + # Due to some input/output not requiring gradients, we cannot enforce + # proper nesting in backward + raise RuntimeError( + "The Module hierarchy tracking is wrong. Report a bug to PyTorch" + ) + + return fn + + def _fw_pre_hook(self, mod, input): + name = self._get_mod_name(mod) + w_mod = weakref.ref(mod) + self._get_append_fn(w_mod, name, False)() + if self._user_pre_fw_hook is not None: + self._user_pre_fw_hook(mod, input) + args, _ = tree_flatten(input) + tensors = [a for a in args if isinstance(a, core.Tensor) and a.requires_grad] + if not self.is_bw and tensors: + register_multi_grad_hook(tensors, self._get_pop_fn(w_mod, name, True)) + + def _fw_post_hook(self, mod, input, output): + name = self._get_mod_name(mod) + w_mod = weakref.ref(mod) + if self._user_post_fw_hook is not None: + self._user_post_fw_hook(mod, input, output) + self._get_pop_fn(w_mod, name, False)() + args, _ = tree_flatten(output) + tensors = [a for a in args if isinstance(a, core.Tensor) and a.requires_grad] + if not self.is_bw and tensors: + register_multi_grad_hook(tensors, self._get_append_fn(w_mod, name, True)) + + def __enter__(self): + self._fw_pre_handle = register_module_forward_pre_hook(self._fw_pre_hook) + self._fw_post_handle = register_module_forward_hook( + self._fw_post_hook, always_call=True + ) + return self + + def __exit__(self, *args): + self._fw_pre_handle.remove() + self._fw_post_handle.remove() diff --git a/mindnlp/core/distributed/_tools/runtime_estimator.py b/mindnlp/core/distributed/_tools/runtime_estimator.py new file mode 100644 index 000000000..0cb5627e4 --- /dev/null +++ b/mindnlp/core/distributed/_tools/runtime_estimator.py @@ -0,0 +1,527 @@ +# Owner(s): ["module: unknown"] +import math +import os +from collections import defaultdict +from typing import Any, Callable, Dict, List, Set, Tuple +from typing_extensions import Self + +from mindnlp import core +from mindnlp import core.utils._pytree as pytree +from core._guards import active_fake_mode +from core._inductor.utils import get_device_tflops, get_gpu_dram_gbps +from core._subclasses.fake_tensor import FakeTensorMode +from core.distributed._tools.mod_tracker import ModTracker +from core.utils._mode_utils import no_dispatch +from core.utils._python_dispatch import TorchDispatchMode +from core.utils.flop_counter import flop_registry + + +aten = core.ops.aten + +# This value is hard-coded here: +# https://github.com/pytorch/pytorch/blob/5fba5d83f0703ff8077ab65448a998e9ad6598fd/c10/cuda/CUDACachingAllocator.cpp#L117 +_PYTORCH_MIN_ALLOCATE = ( + 2**9 if int(os.environ.get("PYTORCH_NO_CUDA_MEMORY_CACHING", 0)) == 0 else 1 +) + +# No fall-back kernel needed/exists for view ops +_VIEW_OPS = { + aten.lift_fresh, + aten.t, + aten.transpose, + aten.view, + aten.detach, + aten._unsafe_view, + aten.split, + aten.adjoint, + aten.as_strided, + aten.diagonal, + aten.expand, + aten.expand_as, + aten.movedim, + aten.permute, + aten.select, + aten.squeeze, + aten.mT, + aten.mH, + aten.real, + aten.imag, + aten.view_as, + aten.unflatten, + aten.unfold, + aten.unbind, + aten.unsqueeze, + aten.vsplit, + aten.hsplit, + aten.split_with_sizes, + aten.swapaxes, + aten.swapdims, + aten.chunk, +} +# We can ignore benchmarking tensor create ops +_CREATE_OPS = { + aten.randint, + aten.randn, + aten.rand, + aten.randn_like, + aten.rand_like, + aten.randint_like, + aten.arange, + aten.ones_like, + aten.zeros_like, +} + +_IGNORE_OPS = _VIEW_OPS | _CREATE_OPS + +__all__ = ["RuntimeEstimator"] + + +class RuntimeEstimator(TorchDispatchMode): + """ + Estimates the GPU runtime in milliseconds using various estimation methods under the ``FakeTensorMode``. + + This class provides a ``TorchDispatchMode`` based context manager that can be used to estimate the eager + runtime of PyTorch functions. It supports two estimation modes, benchmarking (`operator-level-benchmark`) and + roofline cost modeling (`operator-level-cost-model`). + For modules executed under this context manager, it agggregates the forward and backward operation runtimes + and also records their execution orders. + + Attributes: + mod_runtimes (Dict[str, Dict[str, float]]): A dictionary of module runtimes. The key to the outer dictionary + is the fully qualified name (FQN) of the module. For each module the forward and backward runtimes of the + operations are aggregated in the inner dictionary keyed by 'fw' and 'bw'. + mod_fw_pre_order (List[str]): List of module FQNs in pre-forward execution order. + mod_bw_pre_order (List[str]): List of module FQNs in pre-backward execution order. + mod_fw_post_order (List[str]): List of module FQNs in post-forward execution order. + mod_bw_post_order (List[str]): List of module FQNs in post-backward execution order. + total_runtime (float): The total estimated runtime in milliseconds. + + Note: + 1) The benchmarking estimate mode will execute kernels on GPU and assumes that every operation can run in + isolation without causing an OOM error. It is also designed to be used only under ``FakeTensorMode``. + 2) Currently wrapper tensor sub-classes such as ``DTensor`` won't produce correct estimates. We plan to support + them in future PRs. + 3) We only estimate the compute time, if your code has communication, it will not be considered. Again, we will + support this in future PRs. + + Example usage: + + .. code-block:: python + + runtime_estimator = RuntimeEstimator() + with FakeTensorMode(): + module = ... + optimizer = ... + inp = ... + with runtime_estimator(estimate_mode_type="operator-level-cost-model"): + loss = module(inp) + loss.backward() + optimizer.step() + optimizer.zero_grad() + runtime_estimator.display_modulewise_stats() + """ + + _float_types: Set[core.dtype] = { + core.float16, + core.bfloat16, + core.float32, + core.float64, + } + _no_fallback_kernel: Set[core._ops._OpNamespace] = set() + fake_mode: FakeTensorMode + + def __init__(self) -> None: + super().__init__() + self._estimate: Callable + self._estimate_mode_type: str + self._mod_tracker = ModTracker() + self.mod_runtimes: Dict[str, Dict[str, float]] = defaultdict( + lambda: defaultdict(lambda: 0.0) + ) + self.mod_fw_pre_order: List[str] = [] + self.mod_bw_pre_order: List[str] = [] + self.mod_fw_post_order: List[str] = [] + self.mod_bw_post_order: List[str] = [] + self.total_runtime: float = 0.0 + + # Adapted from: https://github.com/pytorch/pytorch/blob/9b902b3ee3bd608a19543362b66bf06c373dd374/torch/_subclasses/fake_tensor.py#L1969 # noqa: PGH004,B950 + # NB: returns fake tensors + @classmethod + def _maybe_run_and_benchmark_fallback_kernel( # type: ignore[no-untyped-def] + cls, + func, + args, + kwargs, + orig_not_implemented_exception, + ): + """ + Runs and benchmarks a fallback kernel for a given function. + + Args: + func (Callable): The function to benchmark. + args (Tuple): The arguments to pass to the function. + kwargs (Dict[str, Any]): The keyword arguments to pass to the function. + orig_not_implemented_exception (Exception): The original exception to raise if the fallback kernel + is not implemented. + + Returns: + Tuple[Any, float]: A tuple containing the result of the function and + the mean operation time in milliseconds. + """ + # these should all be supported, just to be safe + # avoid fallback for operators which inplace modify metadata + # because the input fake tensors would be umodified + if core.Tag.inplace_view in func.tags: # type: ignore[attr-defined] + raise orig_not_implemented_exception + + inp_impls = {} + flat_args, args_spec = pytree.tree_flatten((args, kwargs)) + # Don't use in_kernel_invocation_manager(fake_mode) as we want to do + # REAL compute (not with meta device) + with no_dispatch(): + + def to_real_tensor(e): # type: ignore[no-untyped-def] + if cls.fake_mode.is_our_fake(e): + if e.dtype in cls._float_types: + out = core.rand_like(e, device=e.fake_device) + else: + out = core.ones_like(e, device=e.fake_device) + if e.is_sparse: + out._coalesced_(e.is_coalesced()) + inp_impls[id(out)] = e + return out + return e + + flat_args = [to_real_tensor(a) for a in flat_args] + args, kwargs = pytree.tree_unflatten(flat_args, args_spec) + r = func(*args, **kwargs) + warmup_iters, actual_iters = 2, 3 + for _ in range(warmup_iters): + func(*args, **kwargs) + start_event = core.cuda.Event(enable_timing=True) + end_event = core.cuda.Event(enable_timing=True) + start_event.record(core.cuda.current_stream()) + for _ in range(actual_iters): + func(*args, **kwargs) + end_event.record(core.cuda.current_stream()) + core.cuda.synchronize() + cuda_time = start_event.elapsed_time(end_event) + mean_op_time = cuda_time / actual_iters + + storages = set() + + for e in flat_args: + if isinstance(e, core.Tensor): + if not e.is_sparse: + storages.add(e._typed_storage()._cdata) + + # TODO: also check metadata change on inputs + # proper aliasing/metadata relationship between outputs and inputs will + # not be set up, bc of conversion to device, unless we can reuse an + # input impl + + def map_out(e): # type: ignore[no-untyped-def] + if id(e) not in inp_impls and ( + isinstance(e, core.Tensor) + and not e.is_sparse + and e._typed_storage()._cdata in storages + ): + raise orig_not_implemented_exception + + if isinstance(e, core.Tensor): + if id(e) in inp_impls: + return inp_impls[id(e)] + else: + return cls.fake_mode.fake_tensor_converter.from_real_tensor( + cls.fake_mode, e + ) + else: + return e + + return (pytree.tree_map(map_out, r), mean_op_time) + + @classmethod + def _benchmark_estimate(cls, func, args, kwargs) -> Tuple[Any, float]: # type: ignore[no-untyped-def] + """ + Estimates the runtime of a function using benchmarking. + + Args: + func: The function to estimate. + args: The arguments to pass to the function. + kwargs: The keyword arguments to pass to the function. + res: The result of the function. + + Returns: + Tuple[Any, float]: A tuple containing the result of the function and + the mean operation time in milliseconds. + """ + assert isinstance( + cls.fake_mode, FakeTensorMode + ), "Initialize/Assign FakeTensorMode before using this function" + mean_op_time = 0.0 + if func._overloadpacket not in _VIEW_OPS: + try: + res, mean_op_time = cls._maybe_run_and_benchmark_fallback_kernel( + func, + args, + kwargs, + NotImplementedError, + ) + return (res, mean_op_time) + except NotImplementedError: + cls._no_fallback_kernel.add(func._overloadpacket) + res = func(*args, **kwargs or {}) + return (res, mean_op_time) + + # Adapted from: https://github.com/pytorch/pytorch/blob/9b902b3ee3bd608a19543362b66bf06c373dd374/torch/_inductor/scheduler.py#L589 # noqa: PGH004,B950 + @classmethod + def _roofline_estimate(cls, func, args, kwargs) -> Tuple[Any, float]: # type: ignore[no-untyped-def] + """ + Estimates the runtime of a function using a roofline cost model. + + Args: + func: The function to estimate. + args: The arguments to pass to the function. + kwargs: The keyword arguments to pass to the function. + out: The output of the function. + + Returns: + Tuple[Any, float]: A tuple containing the result of the function and + the mean operation time in milliseconds. + """ + assert ( + core.cuda.is_available() + ), "Roofline estimation needs to access CUDA capabilities to make estimations" + + def get_num_bytes(t: core.Tensor) -> int: + """ + Calculates the memory consumption of a tensor. + + Args: + t (core.Tensor): The input tensor. + + Returns: + int: The memory consumption of the tensor in bytes. + """ + num_bytes = t.untyped_storage().nbytes() + mem_consumed = ( + math.ceil(num_bytes / _PYTORCH_MIN_ALLOCATE) * _PYTORCH_MIN_ALLOCATE + ) + return mem_consumed + + def get_compute_time(func_packet, args, kwargs, out, out_dtypes) -> float: # type: ignore[no-untyped-def] + """ + Estimates the compute time of an aten operator. + + Args: + func_packet: The operator overload packet. + args: The arguments to the operator. + kwargs: The keyword arguments to the operator. + out: The output of the operator. + out_dtypes: The output data types. + + Returns: + float: The estimated compute time in nanoseconds. + """ + if func_packet in flop_registry: + assert ( + len(out_dtypes) == 1 + ), f"Only support single out dtype got {out_dtypes} for {func_packet}" + dtype = out_dtypes.pop() + # This actually gives peta-FLOPs/s hence multiply by 1e15 to get the FLOPs/s + peak_gpu_flops = get_device_tflops(dtype) * 1e15 + # We can expect to achieve 75% of theoretical peak flops + factor = 0.75 + peak_empirical_flops = factor * peak_gpu_flops + flop_count_func = flop_registry[func_packet] + # We divide by a factor of 2 to get the MACs (multiply and accumulate) + flop_count = flop_count_func(*args, **kwargs, out_val=out) / 2 + # We multiply by 1e9 to get the time in nano seconds + compute_time = (flop_count / peak_empirical_flops) * 1e9 + return compute_time + return 0.0 + + def get_transfer_time(flat_args_kwargs, flat_outs) -> float: # type: ignore[no-untyped-def] + """ + Estimates the memory transfer time of input and output tensors. + + Args: + flat_args_kwargs (List[core.Tensor]): The flat list of arguments and keyword arguments. + flat_outs (List[core.Tensor]): The flat list of outputs. + + Returns: + float: The estimated memory transfer time in nanoseconds. + """ + gpu_memory_bandwidth = get_gpu_dram_gbps() + read_bytes = sum( + get_num_bytes(t) + for t in flat_args_kwargs + if isinstance(t, core.Tensor) + ) + write_bytes = sum( + get_num_bytes(t) for t in flat_outs if isinstance(t, core.Tensor) + ) + counted_bytes = read_bytes + write_bytes + # The GPU memory bandwidth is in GB/s so the transfer time is in nanoseconds + transfer_time = counted_bytes / gpu_memory_bandwidth + return transfer_time + + # Roofline Cost Model Explanation + + # The roofline cost model estimates the execution time of an operator based on + # the device's empirical maximum FLOPs/sec (pi) and device DRAM bandwidth (beta). + + # Variables: + # - pi: Maximum empirical FLOPs/sec of the device + # - beta: Maximum empirical device DRAM bandwidth (bytes/sec) of the device + # - I: Arithmetic intensity of the operator (FLOPs/bytes) + # - op_flops: FLOPs required by the operator + # - op_bytes: Bytes transferred to and from DRAM for the operator + + # Calculation Steps: + # 1. Calculate arithmetic intensity: I = op_flops / op_bytes + # 2. Calculate estimated FLOPs/sec: est_flops_sec = min(pi, beta * I) + # 3. Calculate estimated operator time: estimated_op_time = op_flops / est_flops_sec + # This simplifies to: estimated_op_time = max(op_flops / pi, op_flops / (beta * I)) + # Further simplifying: estimated_op_time = max(op_flops / pi, op_bytes / beta) + + # Simplified Formulas: + # - compute_time = op_flops / pi + # - transfer_time = op_bytes / beta + # - estimated_op_time = max(compute_time, transfer_time) + + kwargs = kwargs if kwargs else {} + out = func(*args, **kwargs) + op_time = 0.0 + func_packet = func._overloadpacket + if func_packet not in _IGNORE_OPS: + flat_args_kwargs, args_spec = pytree.tree_flatten((args, kwargs)) + flat_outs, out_spec = pytree.tree_flatten(out) + transfer_time = get_transfer_time(flat_args_kwargs, flat_outs) + + out_dtypes = { + t.dtype + for t in flat_outs + if isinstance(t, core.Tensor) and t.dtype in cls._float_types + } + + args, kwargs = pytree.tree_unflatten(flat_args_kwargs, args_spec) + out = pytree.tree_unflatten(flat_outs, out_spec) + + compute_time = get_compute_time(func_packet, args, kwargs, out, out_dtypes) + # We get the estimated time as the max of the transfer time and + # compute time. We divide by 1e6 to get the time in ms + op_time = max(transfer_time, compute_time) / 1e6 + + return (out, op_time) + + def display_modulewise_stats(self, depth: int = 2) -> None: + """ + Displays module-wise statistics collected by ``RuntimeEstimator``. + + Prints the pre-forward and pre-backward execution orders. + Displays the module-wise forward and backward runtimes in milliseconds. + + Args: + depth (int): The maximum depth of module hierarchy to display (default to 2). + """ + print("Pre-Forward Execution Order: ") + for mod_fqn in self.mod_fw_pre_order: + mod_depth = mod_fqn.count(".") + 1 + if mod_depth > depth: + continue + print(mod_fqn) + print("Pre-Backward Execution Order: ") + for mod_fqn in self.mod_bw_pre_order: + mod_depth = mod_fqn.count(".") + 1 + if mod_depth > depth: + continue + print(mod_fqn) + for mod_fqn, runtimes in self.mod_runtimes.items(): + mod_depth = mod_fqn.count(".") + 1 + if mod_depth > depth: + continue + print( + f"{mod_fqn} fw: {runtimes.get('fw', 0.0):.3f}ms bw: {runtimes.get('bw', 0.0):.3f}ms" + ) + + def __torch_dispatch__(self, func, types, args=..., kwargs=None): # type: ignore[no-untyped-def] + # TODO: @sanketpurandare: Flatten tensors by desugaring the tensor subclasses + # TODO: @sanketpurandare: Add logic for incorporating communication time + res, op_time = self._estimate(func, args, kwargs) + for par in self._mod_tracker.parents: + if self._mod_tracker.is_bw: + self.mod_runtimes[par]["bw"] += op_time + else: + self.mod_runtimes[par]["fw"] += op_time + self.total_runtime += op_time + return res + + def __call__(self, estimate_mode_type: str) -> Self: + """ + Sets the estimate mode type. + + Currently supported modes: + - "operator-level-benchmark": Estimates runtime using operator benchmarking. + - "operator-level-cost-model": Estimates runtime using roofline cost model. + + Args: + estimate_mode_type (str): The type of estimate mode to use. + + Returns: + RuntimeEstimator: The runtime estimator instance. + + Raises: + NotImplementedError: If the estimate mode type is not supported. + """ + if estimate_mode_type == "operator-level-benchmark": + self._estimate = RuntimeEstimator._benchmark_estimate + elif estimate_mode_type == "operator-level-cost-model": + self._estimate = RuntimeEstimator._roofline_estimate + else: + raise NotImplementedError( + f"estimate_mode_type {estimate_mode_type} not supported" + ) + self._estimate_mode_type = estimate_mode_type + return self + + def __enter__(self) -> Self: + fake_mode = active_fake_mode() + assert isinstance( + fake_mode, FakeTensorMode + ), "No FakeTensorMode found, designed to used under FakeTensorMode" + RuntimeEstimator.fake_mode = fake_mode + self.total_runtime = 0.0 + self.mod_runtimes = defaultdict(lambda: defaultdict(lambda: 0.0)) + self.mod_fw_pre_order.clear() + self.mod_bw_pre_order.clear() + self.mod_fw_post_order.clear() + self.mod_bw_post_order.clear() + self._mod_tracker.register_user_hooks( + pre_fw_hook=lambda mod, inp: self.mod_fw_pre_order.append( + self._mod_tracker.get_known_fqn(mod) + ), + pre_bw_hook=lambda mod, g_out: self.mod_bw_pre_order.append( + self._mod_tracker.get_known_fqn(mod) + ), + post_fw_hook=lambda mod, inp, out: self.mod_fw_post_order.append( + self._mod_tracker.get_known_fqn(mod) + ), + post_bw_hook=lambda mod, g_inp: self.mod_bw_post_order.append( + self._mod_tracker.get_known_fqn(mod) + ), + ) + self._mod_tracker.__enter__() + super().__enter__() + return self + + def __exit__(self, *args: Any) -> None: + print( + f"Estimated ({self._estimate_mode_type})" + f"total_time: {self.total_runtime:.3f} ms" + ) + if len(self._no_fallback_kernel) > 0: + print("no_fallback_kernel: ", list(self._no_fallback_kernel)) + super().__exit__(*args) + self._mod_tracker.clear_user_hooks() + self._mod_tracker.__exit__() diff --git a/mindnlp/core/distributed/_tools/sac_estimator.py b/mindnlp/core/distributed/_tools/sac_estimator.py new file mode 100644 index 000000000..bd903ff07 --- /dev/null +++ b/mindnlp/core/distributed/_tools/sac_estimator.py @@ -0,0 +1,997 @@ +import math +import os +import sys +import warnings +from collections import OrderedDict +from dataclasses import astuple, dataclass +from typing import Any, Dict, List, NamedTuple, Optional, Set, Tuple +from typing_extensions import Self + +from mindnlp import core +from mindnlp.core import nan, nn, UntypedStorage +from core._guards import active_fake_mode +from core._subclasses.fake_tensor import FakeTensorMode +from core.distributed._tools.mod_tracker import ModTracker +from core.distributed._tools.runtime_estimator import RuntimeEstimator +from core.testing._internal.composite_compliance import ( + is_inplace, + is_inplace_view_fn, + is_view_fn, +) +from core.utils._python_dispatch import ( + is_traceable_wrapper_subclass, + TorchDispatchMode, +) +from core.utils._pytree import tree_flatten +from core.utils.checkpoint import SAC_IGNORED_OPS + + +__all__ = ["SACEstimator", "SACStats", "MSPS", "SACTradeOffStats", "SACGreedyOrderMeta"] +aten = core.ops.aten + +_ADDITIONAL_IGNORED_OPS = { + aten.lift_fresh.default, # type: ignore[attr-defined] + core.ops.profiler._record_function_exit._RecordFunction, # type: ignore[attr-defined] + aten.clone.default, # type: ignore[attr-defined] # seems needed for core.compile +} +OPS_TO_ALWAYS_SKIP = SAC_IGNORED_OPS | _ADDITIONAL_IGNORED_OPS +# This value is hard-coded here: +# https://github.com/pytorch/pytorch/blob/5fba5d83f0703ff8077ab65448a998e9ad6598fd/c10/cuda/CUDACachingAllocator.cpp#L117 +_PYTORCH_MIN_ALLOCATE = ( + 2**9 if int(os.environ.get("PYTORCH_NO_CUDA_MEMORY_CACHING", 0)) == 0 else 1 +) + + +def _get_untyped_storages(t: core.Tensor) -> Set[core.UntypedStorage]: + """ + Retrieves untyped storages from a `core.Tensor` or one of its traceable wrapper-subclass. + + Args: + t (core.Tensor): Input `core.Tensor` or traceable wrapper-subclass of `core.Tensor`. + + Returns: + Set[core.UntypedStorage]: Set of untyped storages. + + Warns: + UserWarning: If the flattened input is not a tensor or traceable wrapper-subclass. + """ + unflattened_tensors = [t] + flattened_tensor_storages = set() + while len(unflattened_tensors) > 0: + obj = unflattened_tensors.pop() + if is_traceable_wrapper_subclass(obj): + attrs, _ = obj.__tensor_flatten__() # type: ignore[attr-defined] + unflattened_tensors.extend([getattr(obj, attr) for attr in attrs]) + else: + if not hasattr(obj, "untyped_storage"): + warnings.warn( + f"Expected a tensor or a traceable wrapper-subclass of tensor, but got {type(obj)}", + category=UserWarning, + stacklevel=2, + ) + else: + flattened_tensor_storages.add(obj.untyped_storage()) + return flattened_tensor_storages + + +def _display_stats_tabular(headers: List[str], table_data: List[List[Any]]) -> None: + try: + from tabulate import tabulate + except ImportError as err: + raise ImportError("Please install tabulate.") from err + + # Use tabulate to print the table + print(tabulate(table_data, headers=headers, tablefmt="rst")) + + +# Based on: +# https://github.com/fairinternal/xformers/blob/0ded5697a2ea15711ce45131002d04e72053cc6d/xformers/checkpoint.py#L62 +@dataclass +class _SACMetadata: + """ + Stores metadata for a single operator for SAC. + + Attributes: + func (Any): The operator function. + time_taken (float): The time taken by the operator. + memory_used (float): The memory used by the operator. + curr_idx (int): The current operator index. + output_ids (Tuple[int, ...]): The storage IDs of the operator's outputs. + inplace_info (Tuple[int, ...]): Tuple of self and parent operator for in-place operator. + is_view_like (bool): Whether the operator is view-like. + is_rand_op (bool): Whether the operator is a random operator. + """ + + func: Any + time_taken: float + memory_used: float + curr_idx: int + output_ids: Tuple[int, ...] + inplace_info: Tuple[int, ...] + is_view_like: bool + is_rand_op: bool + + +@dataclass +class _SACModMetadata: + """ + Stores metadata for a module for SAC. + + Attributes: + start_idx (int): The starting index of the module's operators. + force_store_random (bool): Whether to force store random operators in the module. + sac_metadata (List[_SACMetadata]): List of metadata for each operator in the module. + """ + + start_idx: int + force_store_random: bool + sac_metadata: List[_SACMetadata] + + +@dataclass +class SACStats: + """ + A class for storing Activation Checkpointing statistics corresponding to a module. + + Attributes: + func_names (List[str]): List of operator names. + runtimes (List[float]): List of operator runtimes in millliseconds. + memory (List[int]): List of operator memory usage in bytes. + view_like_ops (List[int]): Indices of view-like operators. + rand_ops (List[int]): Indices of random operators. + saved_autograd_ops (List[int]): Indices of operator results saved by autograd engine. + inplace_ops (List[Tuple[int, int]]): Tuple of indices of op and its first parent for Inplace operators. + force_store_random (bool): Whether to force store random operator results. + """ + + func_names: List[str] + runtimes: List[float] + memory: List[int] + view_like_ops: List[int] + rand_ops: List[int] + saved_autograd_ops: List[int] + inplace_ops: List[Tuple[int, int]] + force_store_random: bool + + +class MSPS(NamedTuple): + """ + Represents Memory and Runtime Statistics for an operator/operator group. + + Attributes: + func_names (Set[str]): Set of operator/operator group names. + op_idx (int): Operator index (group head index incase of operator groups). + memory (int): Memory usage in bytes. + runtime (float): Runtime in milliseconds. + msps (float): Memory per second calculated as memory/runtime. + """ + + func_names: Set[str] + op_idx: int + memory: int + runtime: float + msps: float + + +@dataclass +class SACTradeOffStats: + """ + Stores statistics for activation-checkpointing trade-off. + + Attributes: + n_segments (int): Number of piecewise linear segments fitted to the trade-off curve. + slopes (List[float]): Slopes of the pieces of linear segments fitted to the trade-off curve. + intercepts (List[float]): Intercepts of the of the pieces of linear segments fitted to the trade-off curve. + fit_breaks (List[float]): Breakpoints of the of the pieces of linear segments fitted to the trade-off curve. + tradeoff_curve (OrderedDict[float, float]): Trade-off curve data of memory discarded vs recomputation time. + sac_memory (int): Total memory of operations available for activation checkpointing in bytes. + sac_runtime (float): Total runtime of operations available for activation checkpointing in milliseconds. + """ + + n_segments: int + slopes: List[float] + intercepts: List[float] + fit_breaks: List[float] + tradeoff_curve: OrderedDict[float, float] + sac_memory: int + sac_runtime: float + + +@dataclass +class SACGreedyOrderMeta: + """ + Stores metadata for Greedy-order SAC. + + Attributes: + recomputed_ops (Set[int]): Set of operator indices to be recomputed. + stored_ops (Set[int]): Set of operator indices to be stored. + inplace_op_groups (Dict[int, Set[int]]): Dictionary of inplace operator groups from group-head to operators. + random_ops_group (Dict[int, Set[int]]): Dictionary of random op group head to random ops. + msps_meta (List[MSPS]): List of Memory and Runtime Statistics for operators. + """ + + recomputed_ops: Set[int] + stored_ops: Set[int] + inplace_op_groups: Dict[int, Set[int]] + random_ops_group: Dict[int, Set[int]] + msps_meta: List[MSPS] + + +class SACEstimator(TorchDispatchMode): + """ + Estimates the memory and recomputation time trade-offs for applying Selective Activation Checkpointing (SAC). + + This class provides a ``TorchDispatchMode`` based context manager that can be used to estimate the memory and + runtime trade-offs of functions or ``core.nn.Module``s for Selective Activation Checkpointing (SAC). It provides + detailed statistics and metadata information for operators of each module and provides a greedy order for selecting + the operators to be recomputed/checkpointed. It also constructs the per-module trade-off graph of discarded memory + vs recomputation time for the obtained greedy order. Using ``RuntimeEstimator`` under the hood, it supports two + estimation modes, `operator-level-benchmark` and (`operator-level-cost-model` (roofline model). + + Attributes: + sac_mod_stats (Dict[str, SACStats]): Dictionary from module FQN (fuly qualified name) to ``SACStats``. + sac_mod_tradeoff_stats (Dict[str, SACTradeOffStats]): Dictionary from module FQN to ``SACTradeOffStats``. + sac_mod_greedy_order_meta (Dict[str, SACGreedyOrderMeta]): Dictionary from module FQN to ``SACGreedyOrderMeta``. + + Note: + 1) This class is designed to be used under ``FakeTensorMode``. + 2) Currently, it only supports estimation of compute time and memory usage, and does not consider communication. + + Example usage: + + .. code-block:: python + + sac_estimator = SACEstimator() + with FakeTensorMode(): + module = ... + inp = ... + with sac_estimator('operator-level-cost-model'): + output = module(inp) + sac_estimator.display_modulewise_sac_stats(depth=4, print_tabular=True) + """ + + def __init__(self) -> None: + self.sac_mod_stats: Dict[str, SACStats] = {} + self.sac_mod_tradeoff_stats: Dict[str, SACTradeOffStats] = {} + self.sac_mod_greedy_order_meta: Dict[str, SACGreedyOrderMeta] = {} + self._mod_tracker = ModTracker() + self._sac_metadata: List[_SACMetadata] = [] + self._sac_mod_metadata: Dict[str, _SACModMetadata] = {} + self._leaf_modules: Set[str] = set() + self._saved_tensor_hook_ctx = core.autograd.graph.saved_tensors_hooks( + self._pack_hook, lambda x: x + ) + self._saved_tensor_ids: Set[int] = set() + self._estimate_runtime = RuntimeEstimator._roofline_estimate + + def _pack_hook(self, x: core.Tensor) -> core.Tensor: + # Hook function to track underlying storage IDs of tensors + # Updates the _saved_tensor_ids set with the IDs of the tensor's storages + # Used in conjunction with core.autograd.graph.saved_tensors_hooks + untyped_storages = _get_untyped_storages(x) + storage_ids = (hash(st) for st in untyped_storages) + self._saved_tensor_ids.update(storage_ids) + return x + + def _pre_fw_hook(self, mod: nn.Module, inputs: Any) -> None: + # Pre-forward hook function to prepare module metadata + # Tracks module FQN, force store random flag, and ``SACModMetadata`` + # Initializes metadata for non-leaf modules, marks leaf modules + mod_fqn = self._mod_tracker.get_known_fqn(mod) + assert mod_fqn is not None + num_children = sum(1 for _ in mod.children()) + if num_children > 0: + force_store_random = self._get_force_store_random(inputs) + self._sac_mod_metadata[mod_fqn] = _SACModMetadata( + start_idx=len(self._sac_metadata), + force_store_random=force_store_random, + sac_metadata=[], + ) + else: + self._leaf_modules.add(mod_fqn) + + def _post_fw_hook(self, mod: nn.Module, inputs: Any, outputs: Any) -> None: + # 1. Retrieves the module's FQN and checks if it's a leaf module + # 2. If not a leaf module, computes: + # - ``SACStats`` using the module's metadata and force store random flag + # - ``SACGreedyOrderMeta`` using the computed SAC statistics + mod_fqn = self._mod_tracker.get_known_fqn(mod) + assert mod_fqn is not None + if mod_fqn in self._leaf_modules: + return + else: + self.sac_mod_stats[mod_fqn] = self._get_sac_stats( + data=self._sac_mod_metadata[mod_fqn].sac_metadata, + force_store_random=self._sac_mod_metadata[mod_fqn].force_store_random, + ) + self.sac_mod_greedy_order_meta[mod_fqn] = self._get_greedy_order_meta( + self.sac_mod_stats[mod_fqn] + ) + + def _get_force_store_random(self, inputs: Any) -> bool: + flat_inputs, _ = tree_flatten(inputs) + return all(not isinstance(x, core.Tensor) for x in flat_inputs) + + def _get_sac_stats( + self, data: List[_SACMetadata], force_store_random: bool + ) -> SACStats: + # 1. Ignore the operations that should be skipped by SAC such as aten.detach.default because autograd + # inserts those during backward and it breaks the fwd-bwd alignment + filtered_data = [x for x in data if x.func not in OPS_TO_ALWAYS_SKIP] + + ( + ops, + runtimes_, + memory_, + new_ids, + output_ids, + inplace_ops_, + view_like_ops_, + rand_ops_, + ) = zip(*[astuple(x) for x in filtered_data], strict=True) + + # 2. Extract the metadata information + runtimes = list(runtimes_) + memory = list(memory_) + func_names = [op._overloadpacket.__name__ for op in ops] + view_like_ops = [i for i, x in enumerate(view_like_ops_) if x] + rand_ops = [i for i, x in enumerate(rand_ops_) if x] + saved_autograd_ops = [ + i + for i, out_ids in enumerate(output_ids) + if set(out_ids).issubset(self._saved_tensor_ids) + ] + + # 3. Remap the inplace indices as we have removed OPS_TO_ALWAYS_SKIP + # FIXME @sanketpurandare: Fix this by changing the parent of the inplace-op + # to itself if the original parent is in OPS_TO_ALWAYS_SKIP. + try: + inplace_ops = [tuple(map(new_ids.index, x)) for x in inplace_ops_ if x] + except ValueError as err: + raise ValueError( + f"The remapping of inplace ops failed since one of the inplace op parents" + f" must have been present in {OPS_TO_ALWAYS_SKIP}" + ) from err + + # 4. The last operation is always stored as the output of the checkpoint + # block, so we can avoid recomputing it. We set the memory to zero + # instead of adding a new constraint because we want both the 0 and 1 + # endpoints for memory_budget to be valid + # FIXME @sanketpurandare: this heuristic for finding the last non-view non-inplace op + # might not always be correct, which would yield suboptimal policies + last_op = len(ops) - 1 + skip_ops_ = set(view_like_ops) | set({x[0] for x in inplace_ops}) + reversed_skip_ops = sorted(skip_ops_, reverse=True) + for op in reversed_skip_ops: + if op == last_op: + last_op -= 1 + + memory[last_op] = 0 + + # 5. Create a single ``SACStats`` object for the entire block of ``_SACMetadata``. + return SACStats( + func_names=func_names, + runtimes=runtimes, + memory=memory, + view_like_ops=view_like_ops, + rand_ops=rand_ops, + saved_autograd_ops=saved_autograd_ops, + inplace_ops=inplace_ops, # type: ignore[arg-type] + force_store_random=force_store_random, + ) + + def _get_inplace_metadata( + self, func: Any, out_storages: Set[UntypedStorage] + ) -> Tuple[int, Tuple[int, ...], Dict[str, Tuple[int, ...]]]: + # 1. Get the current index of the metadata obtained so far + curr_idx = len(self._sac_metadata) + # 2. Get the set of active modules that are not leaf + active_mod_fqns: Set[str] = { + par for par in self._mod_tracker.parents if par not in self._leaf_modules + } + # 3. Output ids are the identifies of the storage objects corresponding to the tensors + output_ids = tuple(hash(st) for st in out_storages) + # 4. If the function is not inplace, return + if not is_inplace(func): + return curr_idx, output_ids, {mod_fqn: () for mod_fqn in active_mod_fqns} + + op_idx = curr_idx + # 5. Initialize the parent op ids of the inplace op for each of the active modules + mod_op_parent_idxs: Dict[str, int] = { + mod_fqn: -1 for mod_fqn in active_mod_fqns + } + for i, d in enumerate(self._sac_metadata): + # 6. Find the first occurence of a tensor corresponding to each module that + # shares the same storage as the current tensor + past_output_ids = d.output_ids + if set(output_ids).issubset(set(past_output_ids)): + for mod_fqn, op_parent_idx in mod_op_parent_idxs.items(): + if op_parent_idx == -1: + if acm_stats := self._sac_mod_metadata.get(mod_fqn, None): + if i >= acm_stats.start_idx: + mod_op_parent_idxs[mod_fqn] = i + else: + assert mod_fqn == "Global" + mod_op_parent_idxs[mod_fqn] = i + # 7. If no parent tensor is found, then it's probably an inplace op on the arguments + # so one can just store the current-op idx as parent idx + for mod_fqn, op_parent_idx in mod_op_parent_idxs.items(): + if op_parent_idx < 0: + mod_op_parent_idxs[mod_fqn] = op_idx + mod_inplace_info = { + mod_fqn: (op_idx, mod_op_parent_idxs[mod_fqn]) + for mod_fqn in active_mod_fqns + } + return curr_idx, output_ids, mod_inplace_info # type: ignore[return-value] + + def __torch_dispatch__( # type: ignore[no-untyped-def] + self, func, types, args=..., kwargs=None + ): + # 1. Get the runtime estimate + out, op_time = self._estimate_runtime(func, args, kwargs) + flat_outs, _ = tree_flatten(out) + out_storages_cuda: Set[UntypedStorage] = set() + out_storages_cpu: Set[UntypedStorage] = set() + cuda_devices: Set[core.device] = set() + for o in flat_outs: + if isinstance(o, core.Tensor): + if o.device.type == "cuda": + out_storages_cuda.update(_get_untyped_storages(o)) + cuda_devices.add(o.device) + else: + out_storages_cpu.update(_get_untyped_storages(o)) + + # Check if there's more than 1 CUDA device + assert ( + len(cuda_devices) <= 1 + ), f"{func.__name__}'s output has more than 1 CUDA devices {cuda_devices}" + + # 2. Get the memory consumed by output + nbytes_cuda = sum( + math.ceil(st.nbytes() / _PYTORCH_MIN_ALLOCATE) * _PYTORCH_MIN_ALLOCATE + for st in out_storages_cuda + ) + nbytes_cpu = sum(st.nbytes() for st in out_storages_cpu) + nbytes = nbytes_cuda + nbytes_cpu + # 3. Get the current operator index, output storage identifiers and inplace metadata + out_storages = out_storages_cuda | out_storages_cpu + curr_idx, output_ids, mod_inplace_info = self._get_inplace_metadata( + func, out_storages + ) + # 4. Determine if the function is in-place, random-op or a view-like + is_view_like = is_view_fn(func) or is_inplace_view_fn(func) + is_rand_op = core.Tag.nondeterministic_seeded in func.tags + if is_view_like: + nbytes = 0 + # sdpa has non-deterministic seed, but might be deterministic + # if no dropout is applied + if func.overloadpacket.__name__ == "_scaled_dot_product_flash_attention": + is_rand_op = kwargs.get("dropout_p", 0) != 0 + # 5. Create metadata information per active non-leaf module + for mod_fqn in self._mod_tracker.parents: + if mod_fqn in self._leaf_modules: + continue + acm = _SACMetadata( + func=func, + time_taken=op_time, + memory_used=nbytes, + curr_idx=curr_idx, + output_ids=output_ids, + inplace_info=mod_inplace_info[mod_fqn], + is_view_like=is_view_like, + is_rand_op=is_rand_op, + ) + if acm_stats := self._sac_mod_metadata.get(mod_fqn, None): + acm_stats.sac_metadata.append(acm) + else: + assert ( + mod_fqn == "Global" + ), f"Module {mod_fqn} not found in AC Mod Stats" + self._sac_metadata.append(acm) + + return out + + def _get_greedy_order_meta(self, sac_stats: SACStats) -> SACGreedyOrderMeta: + # An inplace-op group is a set of inplace-ops that operate on the same underlying tensor storage. + # 1. inplace_op_groups: A dictionary from the top-most parent of inplace-ops to the inplace-ops in the group + # The top-most op can itself be an inplace-op or can be a non-inplace op. + # 2. inplace_op_to_group_head: A dictionary that maps all the inplace-ops to their respective group heads. + inplace_op_groups: Dict[int, Set[int]] = {} + inplace_op_to_group_head: Dict[int, int] = dict(sac_stats.inplace_ops) + + # Initialize inplace_op_groups using inplace_op_to_group_head + for op_idx, group_head_idx in inplace_op_to_group_head.items(): + op_group = inplace_op_groups.setdefault(group_head_idx, {group_head_idx}) + op_group.add(op_idx) + + # Like inplace ops, all of the random ops in the function/module should all be either recomputed or saved + # as a group. This is because, they affect the ranom seed generator. If force_store_random is set True, + # all of the random ops will be stored by default. For easy of manageability, we store the top-most random op + # as the leader of the random_ops_group. + random_ops_group: Dict[int, Set[int]] = {} + random_group_head_idx = min(sac_stats.rand_ops, default=-1) + has_rand_ops = bool(sac_stats.rand_ops) + if has_rand_ops: + random_ops_group[random_group_head_idx] = set(sac_stats.rand_ops) + + # 1. Random ops are stored if force_store_random is set + # 2. View-like ops are recomputed by default + # 3. For inplace_op_groups: + # a) If the head of this group is an inplace op, then we have to store the entire group. + # b) If any op in the group is random and force_store_random is set, then entire group will be stored. + # c) If none of ops in the group are random and the head of the group is not an in-place op, then + # this group can be considered for recomputation in its entireity + stored_ops: Set[int] = set() + recomputed_ops: Set[int] = set() + # Case 1: + if has_rand_ops and sac_stats.force_store_random: + stored_ops.add(random_group_head_idx) + # Case 2: + recomputed_ops.update(set(sac_stats.view_like_ops)) + + for group_head_idx, op_group in inplace_op_groups.items(): + # Case 3a: + if group_head_idx in inplace_op_to_group_head: + stored_ops.add(group_head_idx) + # Case 3b: + if ( + sac_stats.force_store_random & len(op_group & set(sac_stats.rand_ops)) + > 0 + ): + stored_ops.add(group_head_idx) + + # The potential recompute candidates are populated as: + recompute_candidates: Set[int] = set() + # 1) The random group head if it is not stored + if has_rand_ops and random_group_head_idx not in stored_ops: + recompute_candidates.add(random_group_head_idx) + # 2) The in-place op group heads that are not stored + recompute_candidates.update(set(inplace_op_groups.keys()) - stored_ops) + # 3) The non-inplace and non-random ops that are neither stored nor recomputed by default + recompute_candidates.update( + set(range(len(sac_stats.memory))) + - recomputed_ops + - stored_ops + - set(inplace_op_to_group_head.keys()) + - set(sac_stats.rand_ops) + ) + + # We define msps for a recomp candidate as the ratio of memory/runtime aka memory savings per second + msps_meta: List[MSPS] = [] + for cand_idx in recompute_candidates: + op_indices = {cand_idx} + if cand_idx in inplace_op_groups: + op_indices.update(inplace_op_groups[cand_idx]) + if has_rand_ops and cand_idx == random_group_head_idx: + op_indices.update(sac_stats.rand_ops) + + mem = sum(sac_stats.memory[op_idx] for op_idx in op_indices) + runtime = sum(sac_stats.runtimes[op_idx] for op_idx in op_indices) + func_names = {sac_stats.func_names[op_idx] for op_idx in op_indices} + msps = (mem / runtime) if runtime > 0 else sys.float_info.max + msps_meta.append(MSPS(func_names, cand_idx, mem, runtime, msps)) + # We choose canidates to be recomputed based on increasing msps + msps_meta.sort(key=lambda x: x.msps, reverse=True) + return SACGreedyOrderMeta( + recomputed_ops, stored_ops, inplace_op_groups, random_ops_group, msps_meta + ) + + def _get_sac_tradeoff_pwlf_stats( + self, + sac_stats: SACStats, + greedy_order_meta: SACGreedyOrderMeta, + n_segments: int = 2, + save_tradeoff_graph: bool = False, + filename: str = "ac_tradeoff", + ) -> SACTradeOffStats: + try: + import numpy as np # type: ignore[import-not-found] + import pwlf # type: ignore[import-untyped, import-not-found] + except ImportError as err: + raise ImportError("Please install pwlf and numpy package.") from err + + stored_ops, recomputed_ops, inplace_op_groups, random_ops_group, msps_meta = ( + greedy_order_meta.stored_ops, + greedy_order_meta.recomputed_ops, + greedy_order_meta.inplace_op_groups, + greedy_order_meta.random_ops_group, + greedy_order_meta.msps_meta, + ) + # 1. Intitialize the discarded memory and recomputation runtime to sum of already chosen recomputed_ops + recomp_indices: Set[int] = set() + for r_idx in recomputed_ops: + recomp_indices.add(r_idx) + if r_idx in inplace_op_groups: + recomp_indices.update(inplace_op_groups[r_idx]) + if r_idx in random_ops_group: + recomp_indices.update(random_ops_group[r_idx]) + + discarded_mem = sum(sac_stats.memory[op_idx] for op_idx in recomp_indices) + recomp_runtime = sum(sac_stats.runtimes[op_idx] for op_idx in recomp_indices) + # 2. Initialize the max recomputation time and total recomputation memory + sac_runtime = sum(sac_stats.runtimes) + sac_memory = sum(sac_stats.memory) + # 3. Tradeoff curve stores the KV pair of the dicarded memory to total memory and, + # recomputation time to total runtime incurred. + delta = 1e-2 + tradeoff_curve = OrderedDict() + # 4. Initialize the trade-off curve with the stats of of already chosen recomputed_ops + tradeoff_curve[(discarded_mem / sac_memory) + delta] = ( + recomp_runtime / sac_runtime + ) + # 5. Update the trade-off curve with memory and runtime stats of SAC candidates in the + # greedy order of their ``MSPS``. + for cand in msps_meta: + discarded_mem += cand.memory + recomp_runtime += cand.runtime + tradeoff_curve[(discarded_mem / sac_memory) + delta] = ( + recomp_runtime / sac_runtime + ) + # 6. Finally, we add the memory and recomputation time of the always stored ops. + stored_indices: Set[int] = set() + for s_idx in stored_ops: + stored_indices.add(s_idx) + if s_idx in inplace_op_groups: + stored_indices.update(inplace_op_groups[s_idx]) + if s_idx in random_ops_group: + stored_indices.update(random_ops_group[s_idx]) + discarded_mem += sum(sac_stats.memory[op_idx] for op_idx in stored_indices) + recomp_runtime += sum(sac_stats.runtimes[op_idx] for op_idx in stored_indices) + tradeoff_curve[(discarded_mem / sac_memory) + delta] = ( + recomp_runtime / sac_runtime + ) + x_ = list(tradeoff_curve.keys()) + y_ = list(tradeoff_curve.values()) + # 7. We shift the y values to left and x values to right to upperbound the trade-off function + # TODO: Write a better explanation why this needs to be done + x = x_[: len(x_) - 1] + y = y_[1:] + tradeoff_pwlf = pwlf.PiecewiseLinFit(x, y) + # 8. Fit a piecewise linear function with the specified number of segments to the trade-off curve. + n_segments = max(min(len(x) - 2, n_segments), 1) + tradeoff_pwlf.fit(n_segments=n_segments) + + # save prediction graph + def save_prediction_graph( + pwlf_: pwlf.PiecewiseLinFit, x: List[float], y: List[float], filename: str + ) -> None: + try: + import matplotlib.pyplot as plt # type: ignore[import-not-found] + import numpy as np # type: ignore[import-not-found] + except ImportError as err: + raise ImportError( + "Install matplotlib and numpy using pip: pip install matplotlib numpy" + ) from err + # predict for the determined points + xHat = np.linspace(min(x), max(x), num=10000) + yHat = pwlf_.predict(xHat) + + # plot the results + plt.figure() + plt.plot(x, y, "o", label="Shifted") + plt.plot(xHat, yHat, "-", label="Predicted") + plt.plot(x_, y_, "x", label="Original") + plt.ylabel("Recomp time / Total recomp time") + plt.xlabel("Memory discarded / Total memory") + plt.legend() + plt.title(f"{filename}") + plt.suptitle( + f"Total Memory = {sac_memory} B Total Runtime = {sac_runtime:.4f} ms", + fontsize=10, + ) + folder_name = "tradeoff_graphs" + if not os.path.exists(folder_name): + os.makedirs(folder_name) + # Save the plots in the folder + plt.savefig(os.path.join(folder_name, f"{filename}.png")) + + if save_tradeoff_graph: + save_prediction_graph(tradeoff_pwlf, x, y, filename) + # 9. Obtain the slopes, intercepts and breakpoints of the fitted piecewise linear functions + slopes = tradeoff_pwlf.calc_slopes().tolist() + assert isinstance(tradeoff_pwlf.intercepts, np.ndarray) and isinstance( + tradeoff_pwlf.fit_breaks, np.ndarray + ) + intercepts = tradeoff_pwlf.intercepts.tolist() + fit_breaks = tradeoff_pwlf.fit_breaks.tolist() + return SACTradeOffStats( + n_segments=n_segments, + slopes=slopes, + intercepts=intercepts, + fit_breaks=fit_breaks, + tradeoff_curve=tradeoff_curve, + sac_memory=sac_memory, + sac_runtime=sac_runtime, + ) + + def display_sac_stats( + self, sac_stats: SACStats, print_tabular: bool = False + ) -> None: + """ + Displays the SAC statistics. + + Args: + sac_stats (SACStats): The SAC statistics to display. + print_tabular (bool, optional): Whether to print the statistics in a tabular format. Defaults to False. + + Prints: + 1. Total Memory: The total memory usage in bytes. + 2. Total Runtime: The total runtime in milliseconds. + 3. Store Random: A flag indicating whether to force store random operator results. + + Followed by a table with the following columns: + 1. Op Idx: The operator index. + 2. Op Name: The operator name. + 3. Runtimes (ms): The operator runtime in milliseconds. + 4. Memory (B): The operator memory usage in bytes. + 5. View-like: A flag indicating whether the operator is view-like. + 6. Random: A flag indicating whether the operator is random. + 7. Saved Autograd: A flag indicating whether the operator's result is saved by autograd engine. + 8. In-place: The index of the operator's first parent, or None if not in-place. + + If print_tabular is True, the table is printed in a tabular format. + Otherwise, the table is printed in a plain text format. + """ + print( + f"Total Memory: {sum(sac_stats.memory)} B Total Runtime: {sum(sac_stats.runtimes)} ms" + f" Store Random: {sac_stats.force_store_random}" + ) + table_data = [] + op_parent = dict(sac_stats.inplace_ops) + for i, fn_name in enumerate(sac_stats.func_names): + row = [ + str(i), + fn_name, + f"{sac_stats.runtimes[i]:.4f}", + str(sac_stats.memory[i]), + str(i in sac_stats.view_like_ops), + str(i in sac_stats.rand_ops), + str(i in sac_stats.saved_autograd_ops), + str(op_parent.get(i, None)), + ] + table_data.append(row) + # Define headers + headers = [ + "Op Idx", + "Op Name", + "Runtimes(ms)", + "Memory (B)", + "View-like", + "Random", + "Saved Autograd", + "In-place", + ] + if print_tabular: + _display_stats_tabular(headers, table_data) + else: + max_widths = [0 for _ in range(len(headers))] + table_data.insert(0, headers) + for row in table_data: + for i, elem in enumerate(row): + max_widths[i] = max(max_widths[i], len(elem)) + for row in table_data: + print( + "\t".join( + [f"{elem:<{max_widths[i]}}" for i, elem in enumerate(row)] + ) + ) + + def display_sac_tradeoff_stats( + self, + greedy_order_meta: SACGreedyOrderMeta, + sac_stats: SACStats, + print_tabular: bool = False, + ) -> None: + """ + Displays the SAC trade-off statistics. + + Args: + greedy_order_meta (SACGreedyOrderMeta): The SAC greedy order metadata. + sac_stats (SACStats): The SAC statistics. + print_tabular (bool, optional): Whether to print the statistics in a tabular format. Defaults to False. + + Prints: + A table with the following columns: + 1. Op Id(s): The operator index(es). + 2. Op Name(s): The operator name(s). + 3. Discarded Mem (%): The percentage of discarded memory. + 4. Discarded Mem (B): The discarded memory in bytes. + 5. Recomp time (%): The percentage of recomputed time. + 6. Recomp time (ms): The recomputed time in milliseconds. + 7. MSPS: The memory per second. + 8. Always Stored: A flag indicating whether the operator is always stored. + 9. Always Recomputed: A flag indicating whether the operator is always recomputed. + + If print_tabular is True, the table is printed in a tabular format. + Otherwise, the table is printed in a plain text format. + """ + table_data = [] + total_memory, total_runtime = sum(sac_stats.memory), sum(sac_stats.runtimes) + discarded_mem: int = 0 + recomp_runtime: float = 0.0 + + def append_row( + op_indices: Set[int], + func_names: Set[str], + msps: Optional[float] = None, + stored: Optional[bool] = False, + recomputed: Optional[bool] = False, + ) -> None: + row = [ + str(op_indices), + str(func_names), + f"{discarded_mem / total_memory:.4f}", + str(discarded_mem), + f"{recomp_runtime / total_runtime:.4f}", + str(recomp_runtime), + f"{msps:.2e}" if msps is not None else str(nan), + str(stored), + str(recomputed), + ] + table_data.append(row) + + stored_ops, recomputed_ops, inplace_op_groups, random_ops_group, msps_meta = ( + greedy_order_meta.stored_ops, + greedy_order_meta.recomputed_ops, + greedy_order_meta.inplace_op_groups, + greedy_order_meta.random_ops_group, + greedy_order_meta.msps_meta, + ) + + for op_idx in recomputed_ops: + op_indices: Set[int] = {op_idx} + if op_idx in inplace_op_groups: + op_indices.update(inplace_op_groups[op_idx]) + if op_idx in random_ops_group: + op_indices.update(random_ops_group[op_idx]) + discarded_mem += sum(sac_stats.memory[i] for i in op_indices) + recomp_runtime += sum(sac_stats.runtimes[i] for i in op_indices) + func_names = {sac_stats.func_names[i] for i in op_indices} + append_row(op_indices, func_names, recomputed=True) + + for cand in msps_meta: + discarded_mem += cand.memory + recomp_runtime += cand.runtime + op_indices = {cand.op_idx} + if cand.op_idx in inplace_op_groups: + op_indices.update(inplace_op_groups[cand.op_idx]) + if cand.op_idx in random_ops_group: + op_indices.update(random_ops_group[cand.op_idx]) + append_row(op_indices, cand.func_names, msps=cand.msps) + + for op_idx in stored_ops: + op_indices = {op_idx} + if op_idx in inplace_op_groups: + op_indices.update(inplace_op_groups[op_idx]) + if op_idx in random_ops_group: + op_indices.update(random_ops_group[op_idx]) + discarded_mem += sum(sac_stats.memory[i] for i in op_indices) + recomp_runtime += sum(sac_stats.runtimes[i] for i in op_indices) + func_names = {sac_stats.func_names[i] for i in op_indices} + append_row(op_indices, func_names, stored=True) + + headers = [ + "Op Id(s)", + "Op Name(s)", + "Discarded Mem (%)", + "Discarded Mem (B)", + "Recomp time (%)", + "Recomp time (ms)", + "MSPS", + "Always Stored", + "Always Recomputed", + ] + if print_tabular: + _display_stats_tabular(headers, table_data) + else: + max_widths = [0 for _ in range(len(headers))] + table_data.insert(0, headers) + for row in table_data: + for i, elem in enumerate(row): + max_widths[i] = max(max_widths[i], len(elem)) + for row in table_data: + print( + "\t".join( + [f"{elem:<{max_widths[i]}}" for i, elem in enumerate(row)] + ) + ) + + def pwlf_sac_tradeoff_curve( + self, + n_segments: int = 2, + save_tradeoff_graphs: bool = False, + ) -> None: + """ + Fits a piecewise linear function with the specified sumber of segments to the SAC trade-off curve of + discarded memory vs recomputation time. + + Args: + n_segments (int, optional): The number of segments to be used for fitting the piecewise linear function to + the trade-off curve. Defaults to 2. + save_tradeoff_graphs (bool, optional): Whether to save the trade-off graphs to file. Defaults to False. + + If save_tradeoff_graphs is True, the trade-off graphs are saved to file using the module FQN as the filename. + """ + for mod_fqn, sac_stats in self.sac_mod_stats.items(): + self.sac_mod_tradeoff_stats[mod_fqn] = self._get_sac_tradeoff_pwlf_stats( + sac_stats=sac_stats, + greedy_order_meta=self.sac_mod_greedy_order_meta[mod_fqn], + n_segments=n_segments, + save_tradeoff_graph=save_tradeoff_graphs, + filename=mod_fqn, + ) + + def display_modulewise_sac_stats( + self, depth: int = 2, print_tabular: bool = False + ) -> None: + """ + Displays the SAC and trade-off statistics for each module. + + Args: + depth (int, optional): The maximum depth of modules to display. Defaults to 2. + print_tabular (bool, optional): Whether to print the statistics in a tabular format. Defaults to False. + + Prints: + For each module with depth less than or equal to the specified depth: + 1. The SAC statistics for the module (using display_sac_stats). + 2. The SAC trade-off statistics for the module (using display_sac_tradeoff_stats). + + If print_tabular is True, the statistics are printed in a tabular format. + Otherwise, the statistics are printed in a plain text format. + """ + for mod_fqn, sac_stats in self.sac_mod_stats.items(): + mod_depth = mod_fqn.count(".") + 1 + if mod_depth > depth: + continue + print(f"Module: {mod_fqn}") + self.display_sac_stats(sac_stats, print_tabular) + print(f"AC Trade-off for Module: {mod_fqn} MSPS = Memory/Runtime") + self.display_sac_tradeoff_stats( + self.sac_mod_greedy_order_meta[mod_fqn], sac_stats, print_tabular + ) + + def __call__(self, estimate_mode_type: str) -> Self: + """ + Sets the estimate mode type. + + Currently supported modes: + - "operator-level-benchmark": Estimates runtime using operator benchmarking. + - "operator-level-cost-model": Estimates runtime using roofline cost model. + + Args: + estimate_mode_type (str): The type of estimate mode to use. + + Returns: + SACEstimator: The SAC estimator instance. + + Raises: + NotImplementedError: If the estimate mode type is not supported. + """ + if estimate_mode_type == "operator-level-benchmark": + self._estimate_runtime = RuntimeEstimator._benchmark_estimate + elif estimate_mode_type == "operator-level-cost-model": + self._estimate_runtime = RuntimeEstimator._roofline_estimate + else: + raise NotImplementedError( + f"estimate_mode_type {estimate_mode_type} not supported" + ) + return self + + def __enter__(self) -> Self: # type: ignore[no-untyped-def] + fake_mode = active_fake_mode() + assert isinstance( + fake_mode, FakeTensorMode + ), "SAC Estimator should be called in FakeTensorMode" + RuntimeEstimator.fake_mode = fake_mode + self._mod_tracker.register_user_hooks( + pre_fw_hook=self._pre_fw_hook, + post_fw_hook=self._post_fw_hook, + ) + self._mod_tracker.__enter__() + self._saved_tensor_hook_ctx.__enter__() + return super().__enter__() + + def __exit__(self, *args: Any) -> None: # type: ignore[no-untyped-def] + self._saved_tensor_hook_ctx.__exit__() + self._mod_tracker.__exit__(*args) + super().__exit__(*args) diff --git a/mindnlp/core/distributed/_tools/sac_ilp.py b/mindnlp/core/distributed/_tools/sac_ilp.py new file mode 100644 index 000000000..9efe680c2 --- /dev/null +++ b/mindnlp/core/distributed/_tools/sac_ilp.py @@ -0,0 +1,295 @@ +import logging +import math +from enum import IntEnum +from typing import Dict, List, Optional, Tuple + +from core.distributed._tools.ilp_utils import Graph, is_submodule +from core.distributed._tools.sac_estimator import SACStats + + +try: + from pulp import ( # type: ignore[import-untyped,import-not-found] + lpDot, + LpInteger, + LpMaximize, + LpMinimize, + LpProblem, + LpStatus, + lpSum, + LpVariable, + PULP_CBC_CMD, + value, + ) +except ImportError as err: + raise ImportError( + "Please install pulp package. See: https://github.com/coin-or/pulp." + ) from err + +# Create a logger object +logger = logging.getLogger(__name__) + +# Set the logging level to INFO +logger.setLevel(logging.INFO) + + +def sac_milp( + graph: Graph, + memory_budget: float, + world_size: int = 1, + ac_units: Optional[List[str]] = None, + fsdp_units: Optional[List[str]] = None, +) -> Tuple[Dict[str, float], float, int]: + """ + MILP to decide which modules to AC and how much memory to discard. + The objective is to minimize recomputation time. + The constraint is to ensure peak memory is under budget. + + Args: + graph: graph representation of the model as a module submodule tree + where each node is a submodule with memory & runtime stats + memory_budget: memory budget in GiB + world_size: number of GPUs. In the case of FSDP, world_size will be + used to compute the amount of parameter and gradient memory on each rank + ac_units: a list of user-specified AC units. + fsdp_units: a list of FSDP units. AC units cannot be supermodules of FSDP units. + + Returns: + Dict[str, float]: the optimal SAC solution, mapping from module fqn to + the percentage of activation memory to **discard** + float: the recomputation time of the optimal SAC solution + int: upper bound on the peak memory of the optimal SAC solution. + note that value of -1 means that the ILP solver failed to find a solution. + + """ + num_nodes = len(graph.nodes) + M = 10**2 # note: numerical issue may occur if M is too big + MEM_MULTIPLIER = 2**30 + + # Create a MILP problem + prob = LpProblem("SAC", LpMinimize) + + # Create decision variables + # y_i: indicator for if module i is AC'ed + y = LpVariable.matrix("y", list(range(num_nodes)), 0, 1, LpInteger) + # r_i: percentage of discarded activation memory + r = LpVariable.matrix("r", list(range(num_nodes)), 0, 1) + # d_i: discarded activation memory for module i + d = LpVariable.matrix("d", list(range(num_nodes)), 0) + # a_i: total activation memory at module i + a = LpVariable.matrix("a", list(range(num_nodes)), 0) + # m_i: memory at module i, combining parameters, gradients, and activations + m = LpVariable.matrix("m", list(range(num_nodes)), 0) + # rcp_i: percentage of recomputation time + rcp = LpVariable.matrix("rcp", list(range(num_nodes)), 0) + # rct_i: recomputation time for module i (in ms) + rct = LpVariable.matrix("rct", list(range(num_nodes)), 0) + # max_m: peak memory + max_m = LpVariable("max_m", 0) + + # Add constraints + # [Constraint] User specified AC units + if ac_units: + ac_units_set = set(ac_units) + for i in range(num_nodes): + if graph.nodes[i]["fqn"] not in ac_units_set: + prob += y[i] == 0 + + # [Constraint] AC units cannot be supmodules of user specified FSDP units + if fsdp_units: + for i in range(num_nodes): + if any( + is_submodule(fsdp_unit, graph.nodes[i]["fqn"]) + for fsdp_unit in fsdp_units + ): + prob += y[i] == 0 + + # [Constraint] No nested AC units + for i in range(num_nodes): + for j in range(i + 1, num_nodes): + if graph.ad_matrix[i][j] == 1: + prob += y[i] + y[j] <= 1 + + # [Constraint] Do not AC leaf modules + for i in range(num_nodes): + if graph.nodes[i]["is_leaf"]: + prob += y[i] == 0 + + # [Constraint] Express amount of discarded activation memory + for i in range(num_nodes): + # There are two measures for activation memory: ACM and IA + # 1. IA is the activation memory saved when not using AC + # 2. ACM is the total activation memory, including those + # that are not typically saved when not using AC + # Note: ACM >= IA + if (not graph.nodes[i]["is_leaf"]) and graph.nodes[i][ + "sac_memory" + ] < graph.nodes[i]["act_fw_per_module"]: + logger.warning("For module {%s}: ", graph.nodes[i]["fqn"]) + logger.warning( + "activation memory from memory tracker is {%d},", + graph.nodes[i]["act_fw_per_module"], + ) + logger.warning( + "activation memory from SAC estimator is {%d}.", + graph.nodes[i]["sac_memory"], + ) + logger.warning("Something is wrong. Please check!") + logger.warning("Overriding the latter with the former.") + graph.nodes[i]["sac_memory"] = graph.nodes[i]["act_fw_per_module"] + ACM_i = graph.nodes[i]["sac_memory"] / MEM_MULTIPLIER + IA_i = graph.nodes[i]["act_fw_per_module"] / MEM_MULTIPLIER + prob += d[i] == ACM_i * r[i] - (ACM_i - IA_i) * y[i] + + # [Constraint] Ensure correctness of r_i + # There are two parts to its correctness + # 1. r_i > 0 only if y_i == 1 (discard only if it is an AC unit) + # 2. r_i needs to be large enough to cover the difference between + # ACM and IA. Otherwise, we are not saving any memory + for i in range(num_nodes): + prob += y[i] >= r[i] + if graph.nodes[i]["is_leaf"]: + continue + ACM_i = graph.nodes[i]["sac_memory"] / MEM_MULTIPLIER + IA_i = graph.nodes[i]["act_fw_per_module"] / MEM_MULTIPLIER + prob += r[i] >= (ACM_i - IA_i) / ACM_i * y[i] + + # [Constraint] Express total activation memory in the backward pass + for i in range(num_nodes): + AG_i = graph.nodes[i]["act_grad_per_module"] / MEM_MULTIPLIER + TA_i = graph.nodes[i]["act_total"] / MEM_MULTIPLIER + # related to discarded amount of memory + pos = graph.nodes[i]["pos_fw_post_order"] + coeff = [0] * num_nodes + for p in range(pos): + j = graph.name2node[graph.fw_post_order[p]]["index"] + coeff[j] = 1 + prob += a[i] == TA_i + AG_i - lpDot(coeff, d) + + # [Constraint] Express the total amount of memory at each module + # Note that unsharded parameters and gradients are not included here + P_1 = graph.nodes[0]["param_per_module"] / MEM_MULTIPLIER + for i in range(num_nodes): + TG_i = graph.nodes[i]["grad_total"] / MEM_MULTIPLIER + prob += m[i] == a[i] + (P_1 + TG_i) / world_size + + # [Constraint] Express peak memory + for i in range(num_nodes): + prob += max_m >= m[i] + + # [Constraint] Express percentage of recomputation time + for i in range(num_nodes): + for s in range(graph.nodes[i]["n_segments"]): + slope = graph.nodes[i]["slopes"][s] + intercept = graph.nodes[i]["intercepts"][s] + prob += rcp[i] >= slope * r[i] + intercept + + # [Constraint] Express recomputation time + # rct_i = (rcp_i * ACT_i) if y_i == 1 else 0 + for i in range(num_nodes): + ACT_i = graph.nodes[i]["sac_runtime"] + prob += rct[i] <= M * y[i] + prob += rct[i] <= ACT_i * rcp[i] + prob += rct[i] >= ACT_i * rcp[i] - M * (1 - y[i]) + + # [Constraint] Peak memory should be below budget + prob += max_m <= memory_budget + + # Set Objeictive + prob += lpSum(rct) + + # Solve + solver = PULP_CBC_CMD(gapRel=0.05, timeLimit=180, msg=0) + status = prob.solve(solver) + + # If solver fails, print status and return empty solution + if status != 1: + logger.error("Solver failed to find a solution: %s", LpStatus[status]) + return {}, 0, -1 + + # Gather and return solution if optimal solution is found + ac_decisions = {} + for i in range(num_nodes): + if round(y[i].varValue) == 1: + ac_decisions[graph.nodes[i]["fqn"]] = round(r[i].varValue, 4) + recomputation_time = round(value(prob.objective), 2) + peak_mem = round(max_m.varValue * MEM_MULTIPLIER) + + return ac_decisions, recomputation_time, peak_mem + + +class SACDecision(IntEnum): + RECOMPUTE = 0 + SAVE = 1 + + +def get_optimal_checkpointing_policy_per_module( + sac_stats: SACStats, memory_budget: float +) -> List[int]: + """ + This is adapted from -- + https://github.com/facebookresearch/xformers/blob/c6c0ac31f1b08542a0bc27278c6ed10f825f6963/xformers/checkpoint.py#L375 + + Given the SACStats of a module, including list of operators, their memory, runtimes, and metadata, + decide via MILP an optimal set of operators to checkpoint under a given ``memory_budget``. + + Args: + sac_stats: the SACStats object of the module + memory_budget: a float between zero and one + + Returns: + List[int]: the decision whether each operator should be saved (1) or recomptued (0). + """ + if not (0 <= memory_budget <= 1): + raise ValueError( + f"`memory_budget` must be a float between 0 and 1. Got {memory_budget}." + ) + num_ops = len(sac_stats.func_names) + + # Create a MILP problem + prob = LpProblem("SAC-per-module", LpMaximize) + + # Create decision variables + # x[i] = 1 means the i-th operator should be saved, otherwise it should be recomputed + x = LpVariable.matrix("x", list(range(num_ops)), 0, 1, LpInteger) + + # Add constraints + # [Constraint] random ops should be saved if ``force_store_random`` is True + # otherwise, random ops should either be all recomputed or all saved + if sac_stats.force_store_random: + for i in sac_stats.rand_ops: + prob += x[i] == SACDecision.SAVE.value + else: + for i1, i2 in zip(sac_stats.rand_ops[:-1], sac_stats.rand_ops[1:]): + prob += x[i1] == x[i2] + + # [Constraint] view-like ops should always be recomputed + for i in sac_stats.view_like_ops: + prob += x[i] == SACDecision.RECOMPUTE.value + + # [Constraint] inplace ops should always be done in conjunction with its parent op + for op, op_parent in sac_stats.inplace_ops: + if op != op_parent: + prob += x[op] == x[op_parent] + else: + prob += x[op] == SACDecision.SAVE.value + + # [Constraint] saved memory should be under the ``memory_budget`` + max_memory = math.ceil(memory_budget * sum(sac_stats.memory)) + prob += lpDot(x, sac_stats.memory) <= max_memory + + # [Objective] minimize recomputation time, note the ILP is a maximization problem + # because x[i] == 1 means the op is saved (not recomputed), and thus recomputation + # time is sum(sac_stats.runtimes) - lpDot(x, sac_stats.runtimes) + prob += lpDot(x, sac_stats.runtimes) + + # Solve + solver = PULP_CBC_CMD(gapRel=0.05, timeLimit=10, msg=0) + status = prob.solve(solver) + + # If solver fails, print status and return empty solution + if status != 1: + logger.error("Solver failed to find a solution: %s", LpStatus[status]) + return [] + + # Gather and return solution if optimal solution is found + return [round(x[i].varValue) for i in range(num_ops)] diff --git a/mindnlp/core/distributed/algorithms/__init__.py b/mindnlp/core/distributed/algorithms/__init__.py new file mode 100644 index 000000000..b6650ed30 --- /dev/null +++ b/mindnlp/core/distributed/algorithms/__init__.py @@ -0,0 +1 @@ +from .join import Join, Joinable, JoinHook diff --git a/mindnlp/core/distributed/algorithms/_checkpoint/__init__.py b/mindnlp/core/distributed/algorithms/_checkpoint/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/mindnlp/core/distributed/algorithms/_checkpoint/checkpoint_wrapper.py b/mindnlp/core/distributed/algorithms/_checkpoint/checkpoint_wrapper.py new file mode 100644 index 000000000..a1534c61b --- /dev/null +++ b/mindnlp/core/distributed/algorithms/_checkpoint/checkpoint_wrapper.py @@ -0,0 +1,323 @@ +# mypy: allow-untyped-defs +import warnings +from abc import ABC, abstractmethod +from enum import auto, Enum +from functools import partial +from typing import Any, Callable, Dict, Iterator, Optional, Tuple + +from mindnlp import core +from mindnlp import core.nn as nn +from core.autograd.graph import save_on_cpu +from core.distributed.utils import _pack_kwargs, _replace_by_prefix, _unpack_kwargs +from core.utils.checkpoint import checkpoint as torch_utils_checkpoint + + +_CHECKPOINT_WRAPPED_MODULE = "_checkpoint_wrapped_module" +_CHECKPOINT_PREFIX = _CHECKPOINT_WRAPPED_MODULE + "." + + +class CheckpointImpl(Enum): + REENTRANT = auto() + NO_REENTRANT = auto() + + +class ActivationWrapper(core.nn.Module, ABC): + """ + Base class for Activation Checkpoint and Activation Offload. + + Not meant to be instantiated directly. + """ + + def __init__(self, mod): + super().__init__() + self._checkpoint_wrapped_module = mod + # state_dict post hook to remove prefix to allow loading into a + # non-checkpoint wrapped module. + self._register_state_dict_hook(self._post_state_dict_hook) + # load_state_dict pre-hook to allow loading back into + # checkpoint-wrapped module. + self.register_load_state_dict_pre_hook(self._pre_load_state_dict_hook) + + @abstractmethod + def forward(self, *args, **kwargs): + raise ValueError("Subclasses should implement forward().") + + def __getattr__(self, name: str) -> Any: + """Forward missing attributes to wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + return getattr(self._checkpoint_wrapped_module, name) + + def __getitem__(self, key: int) -> Any: + """Forward indexing calls in case the module is a nn.Sequential.""" + return self._checkpoint_wrapped_module.__getitem__(key) # type: ignore[operator] + + def named_parameters( + self, + *args, + **kwargs, + ) -> Iterator[Tuple[str, core.nn.Parameter]]: + """ + Override :meth:`named_parameters()` to intercept parameter names. + + remove all occurrences of ``_CHECKPOINT_PREFIX``. + """ + for param_name, param in super().named_parameters(*args, **kwargs): + yield param_name.replace(_CHECKPOINT_PREFIX, ""), param + + @staticmethod + def _post_state_dict_hook( + module: nn.Module, + state_dict: Dict[str, Any], + prefix: str, + *args: Any, + ) -> Dict[str, Any]: + """ + _post_state_dict_hook() is called after the state_dict() of this FSDP module is executed. + + For ``checkpoint_wrapper``, it will strip checkpoint-wrapped module prefix, + so that this module can be loaded into non-checkpointed modules. + It would still be able to be loaded into checkpoint-wrapped modules as this class, + adds the prefix back before loading the state_dict. + """ + _replace_by_prefix(state_dict, f"{prefix}{_CHECKPOINT_PREFIX}", prefix) + return state_dict + + @staticmethod + def _pre_load_state_dict_hook( + module: nn.Module, + state_dict: Dict[str, Any], + prefix: str, + *args: Any, + ) -> None: + """ + ``_pre_state_dict_hook` is called before ``self._load_from_state_dict()`` is called. + + For ``checkpoint_wrapper``, it will add back the module + prefix so that non-checkpointed modules can be loaded into + checkpoint_wrapper modules properly. + """ + _replace_by_prefix(state_dict, prefix, prefix + f"{_CHECKPOINT_PREFIX}") + + +class OffloadWrapper(ActivationWrapper): + def __init__(self, mod): + super().__init__(mod) + + def forward(self, *args, **kwargs): + with save_on_cpu(pin_memory=True): + return self._checkpoint_wrapped_module(*args, **kwargs) + + +class CheckpointWrapper(ActivationWrapper): + """ + An ``nn.Module`` that wraps another ``nn.Module`` with checkpointing. + + Note that this module is not meant to be used directly but instead, + it is to be used through the ``checkpoint_wrapper`` function. + """ + + def __init__( + self, + mod: core.nn.Module, + checkpoint_impl: CheckpointImpl = CheckpointImpl.NO_REENTRANT, + checkpoint_fn=None, + **checkpoint_fn_kwargs, + ): + super().__init__(mod) + self.checkpoint_impl = checkpoint_impl + if checkpoint_fn is None: + # use core.utils.checkpoint + self.checkpoint_fn = partial( + torch_utils_checkpoint, + use_reentrant=(self.checkpoint_impl == CheckpointImpl.REENTRANT), + **checkpoint_fn_kwargs, + ) + else: + # Construct user-specified checkpoint function. + self.checkpoint_fn = partial( + checkpoint_fn, + **checkpoint_fn_kwargs, + ) + + def forward(self, *args, **kwargs): + # Support keyword arguments for reentrant checkpoint. Note that this + # only works if user has specified self.checkpoint_impl and is not + # using their own custom checkpoint_fn. + if self.checkpoint_impl == CheckpointImpl.REENTRANT and kwargs != {}: + # Pack the args and kwargs + flat_args, kwarg_keys = _pack_kwargs(*args, **kwargs) + + # Function that only takes (packed) args, but can unpack them + # into the original args and kwargs for the checkpointed + # function, and runs that function. + def my_function(*inputs): + # unpack back into args and kwargs + unpacked_args, unpacked_kwargs = _unpack_kwargs(inputs, kwarg_keys) + # run original module + return self._checkpoint_wrapped_module( + *unpacked_args, **unpacked_kwargs + ) + + # Pass the function that only takes packed args into reentrant + # checkpoint API. + return self.checkpoint_fn( # type: ignore[misc] + my_function, + *flat_args, + ) + else: + return self.checkpoint_fn( # type: ignore[misc] + self._checkpoint_wrapped_module, *args, **kwargs + ) + + +def offload_wrapper(module: core.nn.Module) -> core.nn.Module: + """ + Wrap a module for activation offloading to CPU. + + Offloads intermediate activations to the CPU for modules wrapped with this function. + Wrappers with activation offload can be composed with ones that do recomputation-based + checkpoint to trade off increased compute versus increased CPU + memory usage and additional H2D transfers. + + Usage:: + offloaded_module = offload_wrapper(module) + outputs = checkpointed_module(inputs) + Args: + module (nn.Module): + The module to be wrapped + Returns: + (nn.Module): + Wrapped module + """ + return OffloadWrapper(module) + + +def checkpoint_wrapper( + module: core.nn.Module, + checkpoint_impl: CheckpointImpl = CheckpointImpl.NO_REENTRANT, + checkpoint_fn=None, + **checkpoint_fn_kwargs, +) -> core.nn.Module: + """ + Wrap a module for activation checkpointing. + + If the module is wrapped with this function, all subsequent calls to the module will, + automatically perform checkpointing without the user having to explicitly call ``checkpoint`` function. + + Usage:: + checkpointed_module = checkpoint_wrapper(module) + outputs = checkpointed_module(inputs) + Args: + module (nn.Module): + The module to be wrapped + checkpoint_impl (Optional[CheckpointImpl]): + The checkpointing implementation to use. Note that this will only + be passed into the ``core.utils.checkpoint.checkpoint`` + implementation, and is ignored if a custom ``checkpoint_fn`` is + specified. Note that for implementations using reentrant checkpoint + from ``core.utils.checkpoint``, keyword arguments will only be + supported if ``checkpoint_impl`` is passed as ``CheckpointImpl.REENTRANT`. + checkpoint_fn (Optional[Callable]): + Functional checkpoint implementation to use. If this is specified, + it will be used over the default ``core.utils.checkpoint.checkpoint`` + implementation and the `checkpoint_impl` argument will be ignored. + **checkpoint_fn_kwargs: (Dict[str, Any]): Keyword arguments to pass into `checkpoint_fn`. + + Returns: + (nn.Module): + Wrapped module + """ + + if checkpoint_impl == CheckpointImpl.REENTRANT: + warnings.warn( + f"Please specify {CheckpointImpl.NO_REENTRANT} as " + f"{CheckpointImpl.REENTRANT} will soon be removed as " + "the default and eventually deprecated.", + FutureWarning, + stacklevel=2, + ) + return CheckpointWrapper( + module, + checkpoint_impl, + checkpoint_fn, + **checkpoint_fn_kwargs, + ) + + +def apply_activation_checkpointing( + model, + checkpoint_wrapper_fn=checkpoint_wrapper, + check_fn=lambda _: True, + auto_wrap_policy: Optional[Callable[[nn.Module, bool, int], bool]] = None, +): + """ + Apply :func:`checkpoint_wrapper` to modules within `model` based on a user-defined configuration. + + For each module within `model`, the `check_fn` is used to decide + whether `module` should be wrapped with :func:`checkpoint_wrapper` or not. + + Note:: + This function modifies `model` in place and replaces appropriate layers with + their checkpoint-wrapped modules. + Note:: + This function will not wrap the overall root module. If this is needed, please directly use + :func:`checkpoint_wrapper` or :func:`offload_wrapper`. + Usage:: + model = nn.Sequential( + nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10) + ) + check_fn = lambda l: isinstance(l, nn.Linear) + # checkpoint activations + apply_activation_checkpointing(model, checkpoint_wrapper_fn=checkpoint_wrapper, check_fn=check_fn) + # Or offload activations to CPU + apply_activation_checkpointing(model, checkpoint_wrapper_fn=offload_wrapper, check_fn=check_fn) + Args: + model (nn.Module): + The model whose submodules should be wrapped with activation checkpointing. + checkpoint_wrapper_fn (Optional[Callable[nn.Module]]) + A ``Callable`` which will wrap modules + check_fn (Optional[Callable[nn.Module, nn.Module]]) + A lambda function which will be passed each child submodule of ``model`` and returns + ``True`` or ``False`` depending on whether the submodule should be wrapped. + auto_wrap_policy (Optional[Callable[[nn.Module, bool, int], bool]]): A policy to wrap model's + submodules with AC. Note that if this is specified, it takes precedence over ``check_fn``. + Returns: None (`model` is modified inplace) + """ + # TODO: Importing inside function to avoid circular import issue between FSDP and + # checkpoint_wrapper. This can be resolved once wrap() APIs are decoupled from FSDP code. + from core.distributed.fsdp._wrap_utils import _construct_wrap_fn, _post_order_apply + from core.distributed.fsdp.wrap import ( + _Policy, + _recursive_wrap, + lambda_auto_wrap_policy, + ) + + policy = ( + auto_wrap_policy + if auto_wrap_policy is not None + else partial(lambda_auto_wrap_policy, lambda_fn=check_fn) + ) + if not callable(policy): + if not isinstance(policy, _Policy): + raise ValueError( + f"Expected {policy} to be callable or be a pre-defined wrap policy" + ) + target_module_to_kwargs = policy._run_policy( + model, ignored_modules=set(), root_kwargs={} + ) + wrap_fn = _construct_wrap_fn( + model, target_module_to_kwargs, checkpoint_wrapper_fn + ) + _post_order_apply(model, wrap_fn) + return + + _recursive_wrap( + module=model, + auto_wrap_policy=policy, # type: ignore[arg-type] + wrapper_cls=checkpoint_wrapper_fn, + ignored_modules=set(), + ignored_params=set(), + only_wrap_children=True, + ) diff --git a/mindnlp/core/distributed/algorithms/_comm_hooks/__init__.py b/mindnlp/core/distributed/algorithms/_comm_hooks/__init__.py new file mode 100644 index 000000000..eca9f2168 --- /dev/null +++ b/mindnlp/core/distributed/algorithms/_comm_hooks/__init__.py @@ -0,0 +1,7 @@ +from . import default_hooks as default + + +LOW_PRECISION_HOOKS = [ + default.fp16_compress_hook, + default.bf16_compress_hook, +] diff --git a/mindnlp/core/distributed/algorithms/_comm_hooks/default_hooks.py b/mindnlp/core/distributed/algorithms/_comm_hooks/default_hooks.py new file mode 100644 index 000000000..5aafed3b4 --- /dev/null +++ b/mindnlp/core/distributed/algorithms/_comm_hooks/default_hooks.py @@ -0,0 +1,192 @@ +# mypy: allow-untyped-defs +import functools +from typing import Optional + +from mindnlp import core +from mindnlp import core.distributed as dist + + +class DefaultState: + r""" + Stores state needed to perform the default communication algorithm within a communication hook. + + Args: + process_group (ProcessGroup): The process group to be used. + """ + + __slots__ = [ + "process_group", + "world_size", + "gradient_predivide_factor", + "gradient_postdivide_factor", + ] + + def __init__(self, process_group: dist.ProcessGroup): + if process_group is None: + raise ValueError(f"Expected to pass in an explicit ProcessGroup to {self}.") + self.process_group = process_group + self.world_size = dist.get_world_size(process_group) + # Setting two factors `self.gradient_predivide_factor` + # and `self.gradient_postdivide_factor` to avoid underflow and overflow + self.gradient_predivide_factor = self._get_gradient_predivide_factor( + self.world_size + ) + self.gradient_postdivide_factor = ( + self.world_size / self.gradient_predivide_factor + ) + + @staticmethod + def _get_gradient_predivide_factor(world_size: int) -> float: + factor: int = 1 + while world_size % factor == 0 and world_size / factor > factor: + factor *= 2 + return float(factor) + + +class LowPrecisionState(DefaultState): + r""" + Stores state needed to perform gradient communication in a lower precision within a communication hook. + + Communication hook will cast gradients back to the original + parameter precision specified by ``parameter_type`` (default: core.float32). + Builds on top of the :class:`DefaultState`. + + Args: + parameter_type (core.dtype): The precision of model's parameters. + Required for a hook to cast gradients back to a parameter's precision. + """ + + __slots__ = [ + "parameter_type", + ] + + def __init__( + self, + process_group, + parameter_type=core.float32, + ): + super().__init__(process_group) + self.parameter_type = parameter_type + + +def _decompress(state: LowPrecisionState, grad: core.Tensor): + """ + Casts gradients back to full parameter precision so that further computation happens in full precision. + """ + orig_grad_data = grad.data + grad.data = grad.data.to(state.parameter_type) + device_type = "" + try: + if grad.device.type == "privateuse1": + device_type = core._C._get_privateuse1_backend_name() + else: + device_type = grad.device.type + backend = getattr(torch, device_type) + except AttributeError as e: + raise AttributeError( + f"Device {grad.device} does not have a \ + corresponding backend registered as 'core.device_type'." + ) from e + + # Don't let this memory get reused until after the transfer. + orig_grad_data.record_stream(backend.current_stream()) # type: ignore[arg-type] + + +def allreduce_hook(state: DefaultState, grad: core.Tensor): + r""" + Implement the FSDP communication hook for ``all_reduce`` algorithm and a necessary pre- and post-division of gradients. + + Args: + state (DefaultState): State information, configures pre- and post-division factors. + grad (core.Tensor): A gradient for the local batch that needs to be communicated across ranks. + """ + # Average grad by pre-division factor. Together pre- and post-division factors + # lead to an overall averaging by world_size, required for consistency with PyTorch DDP. + # This is a two-step process to avoid potential underflow and overflow. + if state.gradient_predivide_factor > 1: + grad.div_(state.gradient_predivide_factor) + dist.all_reduce(grad, group=state.process_group) + # Average grad by post-division factor. + if state.gradient_postdivide_factor > 1: + grad.div_(state.gradient_postdivide_factor) + + +def reduce_scatter_hook(state: DefaultState, grad: core.Tensor, output: core.Tensor): + r""" + Implement the FSDP communication hook for ``reduce_scatter`` algorithm. + + For sharded FSDP strategies and a necessary pre- and post-division of gradients. + + Args: + state (DefaultState): State information, configures pre- and post-division factors. + grad (core.Tensor): An unsharded gradient for the local batch that needs to be + communicated across ranks. + output (core.Tensor): Stores a single shard of the gradient after ``reduce_scatter``. + """ + # Average grad by pre-division factor. + if state.gradient_predivide_factor > 1: + grad.div_(state.gradient_predivide_factor) + dist.reduce_scatter_tensor(output, grad, group=state.process_group) + # Average grad's shard by post-division factor. + if state.gradient_postdivide_factor > 1: + output.div_(state.gradient_postdivide_factor) + + +def _low_precision_hook( + prec: core.dtype, + state: LowPrecisionState, + grad: core.Tensor, + output: Optional[core.Tensor], +): + if grad.dtype != prec: + grad.data = grad.data.to(prec) + if output is not None: + if output.dtype != prec: + output.data = output.data.to(prec) + reduce_scatter_hook(state, grad, output) + _decompress(state, output) + else: + allreduce_hook(state, grad) + _decompress(state, grad) + + +def fp16_compress_hook( + state: LowPrecisionState, grad: core.Tensor, output: Optional[core.Tensor] = None +): + r""" + Implement FSDP communication hook for a simple gradient compression approach. + Casts ``grad`` to half-precision floating-point format (``core.float16``). + + It also averages gradients by ``world_size`` in two steps: first it pre-divides gradients by a + ``state.gradient_predivide_factor``, and after a communication step (``all_reduce`` or ``reduce_scatter``) + gradients are averaged by a ``state.gradient_postdivide_factor``. + Once post-division is done, compressed gradients are casted back to parameters' precision. + + Args: + state (LowPrecisionState): State information, configures pre- and post-division factors, parameters' precision. + grad (core.Tensor): A gradient for the local batch that needs to be communicated across ranks in a lower precision. + output (core.Tensor): Stores a single shard of the gradient after ``reduce_scatter``. + """ + fp16_hook = functools.partial(_low_precision_hook, core.float16) + return fp16_hook(state, grad, output) + + +def bf16_compress_hook( + state: LowPrecisionState, grad: core.Tensor, output: Optional[core.Tensor] = None +): + r""" + Implement FSDP communication hook for a simple gradient compression approach . + Casts ``grad`` to half-precision floating-point format. + + It also averages gradients by ``world_size`` in two steps: first it pre-divides gradients by a + ``state.gradient_predivide_factor``, and after a communication step (``all_reduce`` or ``reduce_scatter``) + gradients are averaged by a ``state.gradient_postdivide_factor``. + Once post-division is done, compressed gradients are casted back to parameters' precision. + + Args: + state (LowPrecisionState): State information, configures pre- and post-division factors, parameters' precision. + grad (core.Tensor): A gradient for the local batch that needs to be communicated across ranks in a lower precision. + output (core.Tensor): Stores a single shard of the gradient after ``reduce_scatter``. + """ + bf16_hook = functools.partial(_low_precision_hook, core.bfloat16) + return bf16_hook(state, grad, output) diff --git a/mindnlp/core/distributed/algorithms/_optimizer_overlap/__init__.py b/mindnlp/core/distributed/algorithms/_optimizer_overlap/__init__.py new file mode 100644 index 000000000..9460c12ce --- /dev/null +++ b/mindnlp/core/distributed/algorithms/_optimizer_overlap/__init__.py @@ -0,0 +1 @@ +from .optimizer_overlap import _as_overlapped_optim diff --git a/mindnlp/core/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py b/mindnlp/core/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py new file mode 100644 index 000000000..f3baf4d42 --- /dev/null +++ b/mindnlp/core/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py @@ -0,0 +1,97 @@ +# mypy: allow-untyped-defs +import inspect +from abc import ABC, abstractmethod +from typing import Dict, Type + +from core.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_hook +from core.distributed.algorithms.ddp_comm_hooks.optimizer_overlap_hooks import ( + _hook_then_optimizer, + _OptimizerHookState, +) +from core.distributed.fsdp import FullyShardedDataParallel +from core.distributed.optim import as_functional_optim +from core.nn.parallel import DistributedDataParallel +from core.optim import Optimizer + + +# Contains the mappings between the regular and overlapped optimizer types. +_registered_overlapped_optims: Dict[Type, Type] = {} + + +def register_overlapped(optim_cls): + def decorator(target_overlapped_optim_cls): + if target_overlapped_optim_cls in _registered_overlapped_optims: + raise ValueError( + f"{target_overlapped_optim_cls} already registered with optim_cls " + f"{_registered_overlapped_optims[optim_cls]} {optim_cls}, trying to" + f"re-register it for {optim_cls} is not supported." + ) + _registered_overlapped_optims[optim_cls] = target_overlapped_optim_cls + return target_overlapped_optim_cls + + return decorator + + +class OverlappedOptimizer(ABC): + def __init__(self, optim_cls: Type) -> None: + """ + Initialize the OverlappedOptimizer. + + Overlappedoptimizer is a base class that child classes can implement to + specify how different optimizers will register themselves with DDP. + """ + self.optim_cls = optim_cls + + @abstractmethod + def register_ddp(self, ddp: DistributedDataParallel) -> None: + """Registers the overlapped optimizer with DDP.""" + raise NotImplementedError( + f"{self.__class__.__name__} does not support overlapped DDP." + ) + + @abstractmethod + def register_fsdp(self, fsdp: FullyShardedDataParallel) -> None: + """Registers the overlapped optimizer with FSDP.""" + raise NotImplementedError( + f"{self.__class__.__name__} does not support overlapped FSDP." + ) + + +@register_overlapped(Optimizer) +class _OverlappedStandardOptimizer(OverlappedOptimizer): + """Overlaps a regular ``Optimizer``.""" + + def __init__(self, optim_cls: Type, params, *optim_args, **optim_kwargs) -> None: + super().__init__(optim_cls) + f_optim = as_functional_optim(self.optim_cls, *optim_args, **optim_kwargs) + self._opt_hook_state = _OptimizerHookState(f_optim, params) + + def register_ddp(self, ddp_inst: DistributedDataParallel): + # NOTE: using a custom communication hook and fused optimizer is not + # yet supported. + ddp_inst.register_comm_hook( # type: ignore[operator] + None, # wrapped hook state + _hook_then_optimizer(allreduce_hook, self._opt_hook_state), + ) + + # TODO: register_fsdp once FSDP supports communication hook. + def register_fsdp(self, fsdp: FullyShardedDataParallel) -> None: + """Register the overlapped optimizer with FSDP.""" + raise NotImplementedError( + f"{self.__class__.__name__} does not support overlapped FSDP." + ) + + +def _as_overlapped_optim(optim_cls: Type, params, *args, **kwargs): + """Return a new ``OverlappedOptimizer`` instance that supports ``optim_cls``.""" + for clz in inspect.getmro(optim_cls): + try: + return _registered_overlapped_optims[clz]( + optim_cls, params, *args, **kwargs + ) + except KeyError: + pass + + # Fallback to standard overlapped optimizer, which will raise errors if user + # is attempting to use an unsupported optimizer. + return _OverlappedStandardOptimizer(optim_cls, params, *args, **kwargs) diff --git a/mindnlp/core/distributed/algorithms/_quantization/__init__.py b/mindnlp/core/distributed/algorithms/_quantization/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/mindnlp/core/distributed/algorithms/_quantization/quantization.py b/mindnlp/core/distributed/algorithms/_quantization/quantization.py new file mode 100644 index 000000000..665282b79 --- /dev/null +++ b/mindnlp/core/distributed/algorithms/_quantization/quantization.py @@ -0,0 +1,150 @@ +# mypy: allow-untyped-defs +import functools +from enum import Enum + +from mindnlp import core +from mindnlp import core.distributed as dist + + +TORCH_HALF_MIN = core.finfo(core.float16).min +TORCH_HALF_MAX = core.finfo(core.float16).max + + +class DQuantType(Enum): + """ + Different quantization methods for auto_quantize API are identified here. + + auto_quantize API currently supports fp16 and bfp16 methods. + """ + + FP16 = ("fp16",) + BFP16 = "bfp16" + + def __str__(self) -> str: + return self.value + + +def _fp32_to_fp16_with_clamp(tensor: core.Tensor) -> core.Tensor: + return core.clamp(tensor, TORCH_HALF_MIN, TORCH_HALF_MAX).half() + + +def _quantize_tensor(tensor, qtype): + if not isinstance(tensor, core.Tensor): + raise RuntimeError( + f"_quantize_tensor expecting core.Tensor as input but found {type(tensor)}" + ) + if qtype == DQuantType.FP16: + return _fp32_to_fp16_with_clamp(tensor) + elif qtype == DQuantType.BFP16: + return core.ops.quantization._FloatToBfloat16Quantized(tensor) + else: + raise RuntimeError(f"Quantization type {qtype} is not supported") + + +def _quantize_tensor_list(tensor_list, qtype): + if not isinstance(tensor_list, list) or not all( + isinstance(p, core.Tensor) for p in tensor_list + ): + raise RuntimeError( + f"_quantize_tensor_list expecting list of core.Tensor as input but found {type(tensor_list)}" + ) + quantized_tensor_list = [_quantize_tensor(t, qtype) for t in tensor_list] + return quantized_tensor_list + + +def _dequantize_tensor(tensor, qtype, quant_loss=None): + if not isinstance(tensor, core.Tensor): + raise RuntimeError( + f"_dequantize_tensor expecting core.Tensor as input but found {type(tensor)}" + ) + if qtype == DQuantType.FP16: + if tensor.dtype != core.float16: + raise RuntimeError( + f"tensor dtype is {tensor.dtype} while expected to be FP16." + ) + elif tensor.dtype == core.float16 and quant_loss is None: + return tensor.float() + else: + return tensor.float() / quant_loss + elif qtype == DQuantType.BFP16: + if tensor.dtype != core.float16: + raise RuntimeError( + f"tensor dtype is {tensor.dtype} while expected to be FP16." + ) + else: + return core.ops.quantization._Bfloat16QuantizedToFloat(tensor) + else: + raise RuntimeError(f"Quantization type {qtype} is not supported") + + +def _dequantize_tensor_list(tensor_list, qtype, quant_loss=None): + if not isinstance(tensor_list, list) or not all( + isinstance(p, core.Tensor) for p in tensor_list + ): + raise RuntimeError( + f"_dequantize_tensor_list expecting list of core.Tensor as input but found {type(tensor_list)}" + ) + dequantized_tensor_list = [_dequantize_tensor(t, qtype) for t in tensor_list] + return dequantized_tensor_list + + +def auto_quantize(func, qtype, quant_loss=None): + """ + Quantize the input tensors, choose the precision types, and pass other necessary arguments and then dequantizes the output. + + Currently it only supports: + . FP16 and BFP16 quantization method supported for gloo and nccl backends + . all_gather, all_to_all collective ops + Note: BFP16 only supports 2D tensors. + Args: + func (Callable): A function representing collective operations. + qtype (QuantType): Quantization method + quant_loss (float, optional): This can be used to improve accuracy in the dequantization. + Returns: + (Callable): the same collective as func but enables automatic quantization/dequantization. + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + group = kwargs.get("group", None) + async_op = kwargs.get("async_op", False) + if async_op is True: + raise RuntimeError("The async_op=True mode is not supported yet.") + if func == dist.all_gather: + tensors = args[0] + input_tensors = _quantize_tensor(args[1], qtype) + out_tensors = _quantize_tensor_list(tensors, qtype) + dist.all_gather(out_tensors, input_tensors, group=group, async_op=async_op) + for i, t in enumerate( + _dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss) + ): + tensors[i] = t + + elif func == dist.all_to_all: + tensors = args[0] + input_tensors = _quantize_tensor_list(args[1], qtype) + out_tensors = _quantize_tensor_list(tensors, qtype) + dist.all_to_all(out_tensors, input_tensors, group=group, async_op=async_op) + for i, t in enumerate( + _dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss) + ): + tensors[i] = t + + elif func == dist.all_to_all_single: + tensors = args[0] + out_splits = kwargs.get("out_splits", None) + in_splits = kwargs.get("in_splits", None) + # Quantizing the input/output tensor + input_tensors = _quantize_tensor(args[1], qtype) + out_tensors = _quantize_tensor(tensors, qtype) + dist.all_to_all_single( + out_tensors, input_tensors, out_splits, in_splits, group=group + ) + for i, t in enumerate( + _dequantize_tensor(out_tensors, qtype, quant_loss=quant_loss) + ): + tensors[i] = t + else: + raise RuntimeError(f"The collective op {func} is not supported yet") + + return wrapper diff --git a/mindnlp/core/distributed/algorithms/ddp_comm_hooks/__init__.py b/mindnlp/core/distributed/algorithms/ddp_comm_hooks/__init__.py new file mode 100644 index 000000000..9ba9bae59 --- /dev/null +++ b/mindnlp/core/distributed/algorithms/ddp_comm_hooks/__init__.py @@ -0,0 +1,110 @@ +# mypy: allow-untyped-defs +from enum import Enum +from functools import partial + +from mindnlp import core.distributed as dist + +from . import ( + debugging_hooks as debugging, + default_hooks as default, + optimizer_overlap_hooks as optimizer_overlap, + powerSGD_hook as powerSGD, + quantization_hooks as quantization, +) + + +__all__ = ["DDPCommHookType", "register_ddp_comm_hook"] + + +def _ddp_comm_hook_wrapper(comm_hook, model, state): + model.register_comm_hook(state, comm_hook) + + +def _powerSGD_comm_hook_wrapper( + comm_hook, + model, + state, + matrix_approximation_rank, + start_powerSGD_iter=1_000, +): + """ + Wrap PowerSGD communication hook. + + To be consistent with the wrappers of other DDP comm hooks, the input state only needs to be a process group, + which will be wrapped up with other state info. + """ + powerSGD_state = powerSGD.PowerSGDState( + process_group=state, + matrix_approximation_rank=matrix_approximation_rank, + start_powerSGD_iter=start_powerSGD_iter, + ) + model.register_comm_hook(powerSGD_state, comm_hook) + + +class DDPCommHookType(Enum): + """ + Enumerate ``ddp_comm_hooks`` and ``ddp_comm_hook_wrapper`` communucation hook types. + + DDPCommHookType enumerates the hooks of ``core.distributed.algorithms.ddp_comm_hooks`` + as names and ``ddp_comm_hook_wrapper`` partials with hook specified. As an example, + you can register allreduce hook by + ``DDPCommHookType.ALLREDUCE.value(model=model, state=process_group)``. + """ + + ALLREDUCE = partial(_ddp_comm_hook_wrapper, comm_hook=default.allreduce_hook) + FP16_COMPRESS = partial( + _ddp_comm_hook_wrapper, comm_hook=default.fp16_compress_hook + ) + BF16_COMPRESS = partial( + _ddp_comm_hook_wrapper, comm_hook=default.bf16_compress_hook + ) + QUANTIZE_PER_TENSOR = partial( + _ddp_comm_hook_wrapper, comm_hook=quantization.quantization_pertensor_hook + ) + QUANTIZE_PER_CHANNEL = partial( + _ddp_comm_hook_wrapper, comm_hook=quantization.quantization_perchannel_hook + ) + POWER_SGD = partial( + _powerSGD_comm_hook_wrapper, + comm_hook=powerSGD.powerSGD_hook, + matrix_approximation_rank=1, + ) + # Rank-2 PowerSGD can give a higher accuracy than the default rank-1 version, + # but it runs slower and consumes more memory. + POWER_SGD_RANK2 = partial( + _powerSGD_comm_hook_wrapper, + comm_hook=powerSGD.powerSGD_hook, + matrix_approximation_rank=2, + ) + # Batching can lead to a faster training at the cost of accuracy. + BATCHED_POWER_SGD = partial( + _powerSGD_comm_hook_wrapper, + comm_hook=powerSGD.batched_powerSGD_hook, + matrix_approximation_rank=1, + ) + BATCHED_POWER_SGD_RANK2 = partial( + _powerSGD_comm_hook_wrapper, + comm_hook=powerSGD.batched_powerSGD_hook, + matrix_approximation_rank=2, + ) + NOOP = partial( + _ddp_comm_hook_wrapper, + comm_hook=debugging.noop_hook, + ) + + +def register_ddp_comm_hook(comm_hook_type: DDPCommHookType, model, state=None): + """ + Register ``ddp_comm_hooks`` to DDP model. + + Registers the hooks of ``core.distributed.algorithms.ddp_comm_hooks`` + to the DDP model. User can specify the type of hook as an enum + ``DDPCommHookType`` type using ``comm_hook_type`` input. State input will + be passed to the model. + Uses Python comm hook implementations. + + Example:: + >>> # xdoctest: +SKIP + >>> register_ddp_comm_hook(DDPCommHookType.FP16_COMPRESS, model, state) + """ + comm_hook_type.value(model=model, state=state) diff --git a/mindnlp/core/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py b/mindnlp/core/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py new file mode 100644 index 000000000..92970f838 --- /dev/null +++ b/mindnlp/core/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py @@ -0,0 +1,460 @@ +# mypy: allow-untyped-defs +import weakref +from typing import Any, Callable, List, Optional + +from mindnlp import core +from mindnlp import core.distributed as dist +from core.distributed.optim import ZeroRedundancyOptimizer +from core.distributed.optim.zero_redundancy_optimizer import _OverlapStatus +from core.nn.parallel.distributed import DistributedDataParallel + + +__all__ = ["hook_with_zero_step", "hook_with_zero_step_interleaved"] + +# Functional optimizers require passing a list of gradients to their `step()` +# method, and ZeRO requires a functional optimizer to overlap with DDP +# Passing a `None` instead of an actual gradient indicates to the optimizer +# to not update the corresponding parameter +_NO_PARAM_UPDATE: None = None + + +def _perform_local_step( + bucket: dist.GradBucket, + zero: ZeroRedundancyOptimizer, + rank: int, +): + r""" + Perform a local optimizer step using the gradients provided by ``bucket``. + + Arguments: + bucket (dist.GradBucket): the bucket providing the gradients. + zero (ZeroRedundancyOptimizer): the :class:`ZeroRedundancyOptimizer` + instance to perform the :meth:`_local_step`. + rank (int): the calling process's rank. + + .. warning:: + This function assumes that appropriate synchronization has taken place + so that the bucket's gradients can be used. + """ + overlap_info = zero._overlap_info + bucket_index = bucket.index() + assert ( + len(zero.optim.param_groups) == 1 + ), "Overlapping DDP with ZeRO only supports a single parameter group" + + # Construct the `gradients` input for the local optimizer step, which + # expects `None` in a list position to indicate that the corresponding + # parameter should not be updated + num_local_optim_params = len(zero.optim.param_groups[0]["params"]) + gradients: List[Optional[core.Tensor]] = [ + _NO_PARAM_UPDATE for _ in range(num_local_optim_params) + ] + assert ( + bucket_index in overlap_info.offsets + ), f"Bucket index {bucket_index} was not assigned to rank {rank}" + gradients_offset = overlap_info.offsets[bucket_index] + bucket_assignment = zero._bucket_assignments_per_rank[rank][bucket_index] + bucket_offset = bucket_assignment.offset + length = len(bucket_assignment.parameters) + bucket_gradients = bucket.gradients()[bucket_offset : bucket_offset + length] + for i, grad in enumerate(bucket_gradients): + gradients[gradients_offset + i] = grad + + zero._local_step(gradients) + + +def _broadcast_bucket( + bucket_index: int, + zero: ZeroRedundancyOptimizer, +): + r""" + Broadcasts a bucket's parameters. + + Arguments: + bucket_index (int): the index of the bucket corresponding to the + parameters to broadcast. + zero (ZeroRedundancyOptimizer): the calling process's + :class:`ZeroRedundancyOptimizer` instance. + """ + overlap_info = zero._overlap_info + assert ( + len(overlap_info.assigned_ranks_per_bucket) > bucket_index + ), "`assigned_ranks_per_bucket` is not fully constructed" + # Sort to ensure the same ordering across ranks + assigned_ranks = sorted(overlap_info.assigned_ranks_per_bucket[bucket_index]) + assert len(assigned_ranks) > 0, ( + f"Bucket {bucket_index} should be " "assigned to at least one rank" + ) + for assigned_rank in assigned_ranks: + bucket_assignments = zero._bucket_assignments_per_rank[assigned_rank] + if bucket_index in bucket_assignments: + send_tensor = bucket_assignments[bucket_index].tensor + assert send_tensor is not None + overlap_info.broadcast_handles.append( + dist.broadcast( + send_tensor, + src=dist.get_global_rank(zero.process_group, assigned_rank), + group=zero.process_group, + async_op=True, + ) + ) + + +def _save_ddp_bucket_info( + bucket: dist.GradBucket, + zero: ZeroRedundancyOptimizer, +): + r""" + Save :class:`DistributedDataParallel` gradient bucket information for :class:`ZeroRedundancyOptimizer` instance ``zero``. + + In particular, this function is meant to be called upon seeing each + gradient bucket to use when overlapping, meaning it does not save or compute any global + information. + + Arguments: + bucket (dist.GradBucket): the current gradient bucket. + zero (ZeroRedundancyOptimizer): the calling process's + :class:`ZeroRedundancyOptimizer` instance. + """ + overlap_info = zero._overlap_info + bucket_params = bucket.parameters() + assert len(bucket_params) > 0, "Empty bucket" + + # Save the parameters in the bucket + overlap_info.params_per_bucket.append(bucket_params) + if overlap_info.shard_buckets: + # Additionally save the bucket size for the assignment heuristic to use + bucket_size = 0 + for param in bucket_params: + bucket_size += param.numel() + assert overlap_info.total_size is not None + overlap_info.total_size += bucket_size + + +def _hook_with_zero_step_setup( + ddp_ref: weakref.ReferenceType, + zero: ZeroRedundancyOptimizer, + bucket: dist.GradBucket, +): + r""" + Encapsulate the setup logic for :func:`hook_with_zero_step` and :func:`hook_with_zero_step_interleaved`. + + This means the logic to run in the + hook before the backward pass and optimizer step can actually be + overlapped. This is factored out since it is common to both + :func:`hook_with_zero_step` and :func:`hook_with_zero_step_interleaved`. + + Arguments: + ddp_ref (weakref.ReferenceType): weak reference to the process's + :class:`DistributedDataParallel` instance. + zero (ZeroRedundancyOptimizer): the calling process's + :class:`ZeroRedundancyOptimizer` instance. + bucket (dist.GradBucket): the current gradient bucket. + """ + # Proceed as normal until the DDP buckets have been rebuilt + if not ddp_ref()._has_rebuilt_buckets: # type: ignore[union-attr] + assert zero._overlap_info.status == _OverlapStatus.UNINITIALIZED + return + + bucket_index = bucket.index() + overlap_info = zero._overlap_info + if overlap_info.status == _OverlapStatus.UNINITIALIZED: + overlap_info.status = _OverlapStatus.DDP_HAS_REBUILT_BUCKETS + + if overlap_info.status == _OverlapStatus.DDP_HAS_REBUILT_BUCKETS: + if bucket_index == 0 and len(overlap_info.params_per_bucket) > 0: + # This corresponds to the first bucket of the backward pass + # immediately after all information has been saved, so we + # can perform the delayed ZeRO initialization + zero._init_zero_for_overlap() + else: + # Once DDP buckets have been rebuilt but ZeRO has not been + # properly initialized yet, save the information needed + _save_ddp_bucket_info(bucket, zero) + + +def hook_with_zero_step( + hook: Callable[[Any, dist.GradBucket], core.futures.Future], + ddp: DistributedDataParallel, + zero: ZeroRedundancyOptimizer, + shard_buckets: bool = False, +) -> Callable[[Any, dist.GradBucket], core.futures.Future[core.Tensor]]: + r""" + Modify ``hook`` to overlap :class:`ZeroRedundancyOptimizer` optimizer step with :class:`DistributedDataParallel` backward pass. + + This approach overlaps the optimizer computation and communication with the + backward communication. In particular, the backward computation proceeds + contiguously, and the optimizer computation follows, overlapping with + outstanding backward communication (i.e. all-reduces) and possibly other + optimizer communication (i.e. broadcasts). + The optimizer step computation begins after the last gradient bucket computation has finished. + + This approach may be preferred over :meth:`hook_with_zero_step_interleaved` + if communication is relatively slow compared to computation. + + Arguments: + hook (Callable[[Any, dist.GradBucket], core.futures.Future]): the hook + to modify. + ddp (DistributedDataParallel): the :class:`DistributedDataParallel` + instance to use. + zero (ZeroRedundancyOptimizer): the :class:`ZeroRedundancyOptimizer` + instance to use. + shard_buckets (bool): if ``True``, then the assignment of each + :class:`DistributedDataParallel` bucket is partitioned across + possibly multiple :class:`ZeroRedundancyOptimizer` instances (i.e. + across possibly multiple ranks) to approximate uniformity; if + ``False``, then each bucket is wholly assigned to a single + :class:`ZeroRedundancyOptimizer` instance (i.e. to a single rank). + + Returns: + The modified hook. + + Raises: + ValueError: if ``zero`` was constructed with ``overlap_with_ddp=False``. + RuntimeError: if using any backend other than NCCL/HCCL since currently + Gloo may hang. + + .. warning:: + Given the way that overlapping :class:`DistributedDataParallel` with + :class:`ZeroRedundancyOptimizer` is currently implemented, the first + two or three training iterations do not perform parameter updates in + the optimizer step, depending on if ``static_graph=False`` or + ``static_graph=True``, respectively. This is because it needs + information about the gradient bucketing strategy used by + :class:`DistributedDataParallel`, which is not finalized until the + second forward pass if ``static_graph=False`` or until the third + forward pass if ``static_graph=True``. + """ + if not zero._overlap_with_ddp: + raise ValueError( + "ZeroRedundancyOptimizer must be constructed with " + "`overlap_with_ddp=True` to use this hook properly" + ) + ddp_ref = weakref.ref(ddp) + + # NOTE: Gloo may hang with this overlapping approach, so we require + # NCCL/HCCL backend for now; see https://github.com/pytorch/pytorch/issues/62300 + pg = dist.get_backend(ddp_ref().process_group) # type: ignore[union-attr] + if (pg != dist.Backend.NCCL) and (pg != "hccl"): + raise RuntimeError( + "Overlapping DDP with ZeRO using this approach currently requires " + "NCCL/HCCL backend to avoid hangs" + ) + + if shard_buckets: + zero._overlap_info.shard_buckets = True + zero._overlap_info.total_size = 0 + + def hook_with_zero_fn( + state: Any, + bucket: dist.GradBucket, + ) -> core.futures.Future[core.Tensor]: + r""" + Return :class:`Future` that runs the optimizer step if this corresponds to the last gradient bucket. + + Perform equivalent of :class:`ZeroRedundancyOptimizer` :meth:`step` if ``bucket`` is last gradient bucket. + The function gives a gradient bucket tensor and + performs additional computation on the iteration that + the :class:`DistributedDataParallel` buckets are rebuilt to collect + information used to implement the modified hook. + + Arguments: + state (Any): any state for the hook. + bucket (dist.GradBucket): the :class:`DistributedDataParallel` + gradient bucket. + """ + fut = hook(state, bucket) + _hook_with_zero_step_setup(ddp_ref, zero, bucket) + if zero._overlap_info.status != _OverlapStatus.INITIALIZED: + return fut + + overlap_info = zero._overlap_info + bucket_index = bucket.index() + rank = zero.global_rank + + assert overlap_info.status == _OverlapStatus.INITIALIZED + assert ( + len(overlap_info.assigned_ranks_per_bucket) > bucket_index + ), "`assigned_ranks_per_bucket` is not fully constructed" + assigned_to_bucket = ( + rank in overlap_info.assigned_ranks_per_bucket[bucket_index] + ) + + # Save the bucket reference and all-reduce future for the final bucket + if assigned_to_bucket: + overlap_info.bucket_index_to_bucket[bucket_index] = bucket + overlap_info.bucket_index_to_future[bucket_index] = fut + + # Check that buckets are indexed incrementally starting from 0 in the + # order of their autograd hooks firing + if len(overlap_info.bucket_indices_seen) > 0: + assert ( + overlap_info.bucket_indices_seen[-1] == bucket_index - 1 + ), "Bucket indices are not in incremental order" + else: + assert bucket_index == 0, "Bucket indices do not start from 0" + overlap_info.bucket_indices_seen.append(bucket_index) + + # Directly return the future without any optimizer computation if this + # is not the last bucket + num_buckets = len(overlap_info.params_per_bucket) + is_last_bucket = bucket_index == num_buckets - 1 + if not is_last_bucket: + return fut + + # Perform partial optimizer step on all buckets after the final + # bucket has been computed + # NOTE: This should not be chained as a callback to the last bucket's + # all-reduce future since that would add synchronization that delays + # all optimizer computation to wait for that last all-reduce + for bucket_index in range(num_buckets): + assigned_ranks = overlap_info.assigned_ranks_per_bucket[bucket_index] + if rank in assigned_ranks: + # Wait on the bucket's all-reduce future to ensure correct + # gradients + assert bucket_index in overlap_info.bucket_index_to_future, ( + f"All-reduce future for bucket {bucket_index} not saved " + f"on rank {rank}" + ) + allreduce_future = overlap_info.bucket_index_to_future[bucket_index] + allreduce_future.wait() + + # Perform the partial optimizer step + curr_bucket = overlap_info.bucket_index_to_bucket[bucket_index] + _perform_local_step(curr_bucket, zero, rank) + + _broadcast_bucket(bucket_index, zero) + + # Ensure that all parameter updates are finished before the + # next forward pass + overlap_info.wait_for_broadcasts() + overlap_info.clear_per_iter_info() + + return fut + + return hook_with_zero_fn + + +def hook_with_zero_step_interleaved( + hook: Callable[[Any, dist.GradBucket], core.futures.Future], + ddp: DistributedDataParallel, + zero: ZeroRedundancyOptimizer, + shard_buckets: bool = False, +) -> Callable[[Any, dist.GradBucket], core.futures.Future[core.Tensor]]: + r""" + Modify ``hook`` to overlap :class:`ZeroRedundancyOptimizer` optimizer step with :class:`DistributedDataParallel` backward pass + + This approach overlaps the optimizer computation and communication with the + backward computation and communication. In particular, once a bucket's + gradients have been computed, the optimizer computation using those + gradients is launched (though the actual computation must wait for the + bucket's all-reduce to complete). This yields an interleaving of all- + reduces and broadcasts in the communication stream. + + This approach may be preferred over :meth:`hook_with_zero_step` if + communication is relatively fast compared to computation. + + Arguments: + hook (Any * dist.GradBucket -> core.futures.Future): the hook to + modify. + ddp (DistributedDataParallel): the :class:`DistributedDataParallel` + instance to use. + zero (ZeroRedundancyOptimizer): the :class:`ZeroRedundancyOptimizer` + instance to use. + shard_buckets (bool): if ``True``, then the assignment of each + :class:`DistributedDataParallel` bucket is partitioned across + possibly multiple :class:`ZeroRedundancyOptimizer` instances (i.e. + across possibly multiple ranks) to approximate uniformity; if + ``False``, then each bucket is wholly assigned to a single + :class:`ZeroRedundancyOptimizer` instance (i.e. to a single rank). + + Returns: + The modified hook. + + Raises: + ValueError: if ``zero`` was constructed with ``overlap_with_ddp=False``. + RuntimeError: if using any backend other than NCCL since currently + Gloo may hang. + + .. warning:: + Given the way that overlapping :class:`DistributedDataParallel` with + :class:`ZeroRedundancyOptimizer` is currently implemented, the first + two or three training iterations do not perform parameter updates in + the optimizer step, depending on if ``static_graph=False`` or + ``static_graph=True``, respectively. This is because it needs + information about the gradient bucketing strategy used by + :class:`DistributedDataParallel`, which is not finalized until the + second forward pass if ``static_graph=False`` or until the third + forward pass if ``static_graph=True``. + """ + if not zero._overlap_with_ddp: + raise ValueError( + "ZeroRedundancyOptimizer must be constructed with " + "`overlap_with_ddp=True` to use this hook properly" + ) + ddp_ref = weakref.ref(ddp) + + # NOTE: Gloo may hang with this overlapping approach, so we require + # NCCL/HCCL backend for now; see https://github.com/pytorch/pytorch/issues/62300 + pg = dist.get_backend(ddp_ref().process_group) # type: ignore[union-attr] + if (pg != dist.Backend.NCCL) and (pg != "hccl"): + raise RuntimeError( + "Overlapping DDP with ZeRO using this approach currently requires " + "NCCL/HCCL backend to avoid hangs" + ) + + if shard_buckets: + zero._overlap_info.shard_buckets = True + zero._overlap_info.total_size = 0 + + def hook_with_zero_interleaved_fn( + state, + bucket: dist.GradBucket, + ) -> core.futures.Future[core.Tensor]: + r""" + Return :class:`Future` that gives gradient bucket tensor and performs partial :class:`ZeroRedundancyOptimizer` :meth:`step`. + + This function uses the gradients in gradient in given bucket to perform a partial + :class:`ZeroRedundancyOptimizer` :meth:`step` + + Arguments: + state: any state for the hook. + bucket (dist.GradBucket): the :class:`DistributedDataParallel` + gradient bucket. + """ + fut = hook(state, bucket) + _hook_with_zero_step_setup(ddp_ref, zero, bucket) + if zero._overlap_info.status != _OverlapStatus.INITIALIZED: + return fut + + def zero_step(fut: core.futures.Future) -> core.Tensor: + r""" + Perform partial :class:`ZeroRedundancyOptimizer` :meth:`step` using gradients in the :class:`DistributedDataParallel`. + + Returns: + A :class:`core.Tensor` representing the contents of the + gradient bucket. + """ + overlap_info = zero._overlap_info + bucket_index = bucket.index() + rank = zero.global_rank + + assigned_ranks = overlap_info.assigned_ranks_per_bucket[bucket_index] + overlap_info.bucket_indices_seen.append(bucket_index) + if rank in assigned_ranks: + _perform_local_step(bucket, zero, rank) + + _broadcast_bucket(bucket_index, zero) + + num_buckets = len(overlap_info.params_per_bucket) + if len(overlap_info.bucket_indices_seen) == num_buckets: + # Ensure that all parameter updates are finished before the + # next forward pass + overlap_info.wait_for_broadcasts() + overlap_info.clear_per_iter_info() + + return bucket.buffer() + + return fut.then(zero_step) + + return hook_with_zero_interleaved_fn diff --git a/mindnlp/core/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py b/mindnlp/core/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py new file mode 100644 index 000000000..fa891fa8f --- /dev/null +++ b/mindnlp/core/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py @@ -0,0 +1,29 @@ +from typing import Any + +from mindnlp import core +from core.distributed import GradBucket + + +__all__ = ["noop_hook"] + + +def noop_hook(_: Any, bucket: GradBucket) -> core.futures.Future[core.Tensor]: + """ + Return a future that wraps the input, so it is a no-op that does not incur any communication overheads. + + This hook should **only** be used for headroom analysis of allreduce optimization, + instead of the normal gradient synchronization. + For example, if only less than 10% speedup of training time can be observed after this hook is registered, + it usually implies that allreduce is not a performance bottleneck for this case. + Such instrumentation can be particularly useful + if GPU traces cannot be easily retrieved or the trace analysis is complicated + some factors such as the overlap between allreduce and computation or the desynchronization across ranks. + + Example:: + >>> # xdoctest: +SKIP + >>> ddp_model.register_comm_hook(None, noop_hook) + """ + fut: core.futures.Future[core.Tensor] = core.futures.Future() + fut.set_result(bucket.buffer()) + + return fut diff --git a/mindnlp/core/distributed/algorithms/ddp_comm_hooks/default_hooks.py b/mindnlp/core/distributed/algorithms/ddp_comm_hooks/default_hooks.py new file mode 100644 index 000000000..7358bbd2a --- /dev/null +++ b/mindnlp/core/distributed/algorithms/ddp_comm_hooks/default_hooks.py @@ -0,0 +1,205 @@ +# mypy: allow-untyped-defs +from typing import Any, Callable, cast, Tuple + +from mindnlp import core +from mindnlp import core.distributed as dist + + +__all__ = [ + "allreduce_hook", + "fp16_compress_hook", + "bf16_compress_hook", + "fp16_compress_wrapper", + "bf16_compress_wrapper", +] + + +def _allreduce_fut( + process_group: dist.ProcessGroup, tensor: core.Tensor +) -> core.futures.Future[core.Tensor]: + """Average the input gradient tensor by allreduce and returns a future.""" + group_to_use = process_group if process_group is not None else dist.group.WORLD + + # Apply the division first to avoid overflow, especially for FP16. + tensor.div_(group_to_use.size()) + + return ( + dist.all_reduce(tensor, group=group_to_use, async_op=True) + .get_future() + .then(lambda fut: fut.value()[0]) + ) + + +def allreduce_hook( + process_group: dist.ProcessGroup, bucket: dist.GradBucket +) -> core.futures.Future[core.Tensor]: + """ + Call ``allreduce`` using ``GradBucket`` tensors. + + Once gradient tensors are aggregated across all workers, its ``then`` + callback takes the mean and returns the result. + + If user registers this DDP communication hook, + DDP results is expected to be same as the case where no hook was registered. + Hence, this won't change behavior of DDP and user can use this as a reference + or modify this hook to log useful information or any other purposes while + unaffecting DDP behavior. + + Example:: + >>> # xdoctest: +SKIP + >>> ddp_model.register_comm_hook(process_group, allreduce_hook) + """ + return _allreduce_fut(process_group, bucket.buffer()) + + +def _compress_hook( + dtype: core.dtype, + process_group: dist.ProcessGroup, + bucket: dist.GradBucket, +) -> core.futures.Future[core.Tensor]: + group_to_use = process_group if process_group is not None else dist.group.WORLD + world_size = group_to_use.size() + + buffer = ( + cast(Tuple[core.Tensor, ...], bucket)[0] + if isinstance(bucket, tuple) + else bucket.buffer() + ) + compressed_tensor = buffer.to(dtype).div_(world_size) + + def decompress(fut): + decompressed_tensor = buffer + # Decompress in place to reduce the peak memory. + # See: https://github.com/pytorch/pytorch/issues/45968 + value = fut if isinstance(fut, core.Tensor) else fut.value()[0] + decompressed_tensor.copy_(value) + return decompressed_tensor + + if core._utils.is_compiling(): + grad = dist._functional_collectives.all_reduce( + compressed_tensor, "sum", group_to_use + ) + return decompress(grad) + else: + fut = dist.all_reduce( + compressed_tensor, group=group_to_use, async_op=True + ).get_future() + return fut.then(decompress) + + +def fp16_compress_hook( + process_group: dist.ProcessGroup, + bucket: dist.GradBucket, +) -> core.futures.Future[core.Tensor]: + """ + Compress by casting ``GradBucket`` to ``core.float16`` divided by process group size. + + This DDP communication hook implements a simple gradient compression + approach that casts ``GradBucket`` tensor to half-precision floating-point format (``core.float16``) + and then divides it by the process group size. + It allreduces those ``float16`` gradient tensors. Once compressed gradient + tensors are allreduced, the chained callback ``decompress`` casts it back to the input data type (such as ``float32``). + + Example:: + >>> # xdoctest: +SKIP + >>> ddp_model.register_comm_hook(process_group, fp16_compress_hook) + """ + return _compress_hook(core.float16, process_group, bucket) + + +def bf16_compress_hook( + process_group: dist.ProcessGroup, + bucket: dist.GradBucket, +) -> core.futures.Future[core.Tensor]: + """ + Warning: This API is experimental, and it requires NCCL version later than 2.9.6. + + This DDP communication hook implements a simple gradient compression + approach that casts ``GradBucket`` tensor to half-precision + `Brain floating point format `_ (``core.bfloat16``) + and then divides it by the process group size. + It allreduces those ``bfloat16`` gradient tensors. Once compressed gradient + tensors are allreduced, the chained callback ``decompress`` casts it back to the input data type (such as ``float32``). + + Example:: + >>> # xdoctest: +SKIP + >>> ddp_model.register_comm_hook(process_group, bf16_compress_hook) + """ + return _compress_hook(core.bfloat16, process_group, bucket) + + +def fp16_compress_wrapper( + hook: Callable[[Any, dist.GradBucket], core.futures.Future[core.Tensor]] +) -> Callable[[Any, dist.GradBucket], core.futures.Future[core.Tensor]]: + """ + Cast input tensor to ``core.float16``, cast result of hook back to input dtype. + + This wrapper casts the input gradient tensor of a given DDP communication hook to half-precision + floating point format (``core.float16``), and casts the resulting tensor of the given hook back to + the input data type, such as ``float32``. + Therefore, ``fp16_compress_hook`` is equivalent to ``fp16_compress_wrapper(allreduce_hook)``. + + Example:: + >>> # xdoctest: +SKIP + >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10) + >>> ddp_model.register_comm_hook(state, fp16_compress_wrapper(powerSGD_hook)) + """ + + def fp16_compress_wrapper_hook( + hook_state, bucket: dist.GradBucket + ) -> core.futures.Future[core.Tensor]: + # Cast bucket tensor to FP16. + bucket.set_buffer(bucket.buffer().to(core.float16)) + + fut = hook(hook_state, bucket) + + def decompress(fut): + decompressed_tensor = bucket.buffer() + # Decompress in place to reduce the peak memory. + # See: https://github.com/pytorch/pytorch/issues/45968 + decompressed_tensor.copy_(fut.value()) + return decompressed_tensor + + # Decompress after hook has run. + return fut.then(decompress) + + return fp16_compress_wrapper_hook + + +def bf16_compress_wrapper( + hook: Callable[[Any, dist.GradBucket], core.futures.Future[core.Tensor]] +) -> Callable[[Any, dist.GradBucket], core.futures.Future[core.Tensor]]: + """ + Warning: This API is experimental, and it requires NCCL version later than 2.9.6. + + This wrapper casts the input gradient tensor of a given DDP communication hook to half-precision + `Brain floating point format `_ (``core.bfloat16``), + and casts the resulting tensor of the given hook back to the input data type, such as ``float32``. + + Therefore, ``bf16_compress_hook`` is equivalent to ``bf16_compress_wrapper(allreduce_hook)``. + + Example:: + >>> # xdoctest: +SKIP + >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10) + >>> ddp_model.register_comm_hook(state, bf16_compress_wrapper(powerSGD_hook)) + """ + + def bf16_compress_wrapper_hook( + hook_state, bucket: dist.GradBucket + ) -> core.futures.Future[core.Tensor]: + # Cast bucket tensor to BF16. + bucket.set_buffer(bucket.buffer().to(core.bfloat16)) + + fut = hook(hook_state, bucket) + + def decompress(fut): + decompressed_tensor = bucket.buffer() + # Decompress in place to reduce the peak memory. + # See: https://github.com/pytorch/pytorch/issues/45968 + decompressed_tensor.copy_(fut.value()) + return decompressed_tensor + + # Decompress after hook has run. + return fut.then(decompress) + + return bf16_compress_wrapper_hook diff --git a/mindnlp/core/distributed/algorithms/ddp_comm_hooks/mixed_precision_hooks.py b/mindnlp/core/distributed/algorithms/ddp_comm_hooks/mixed_precision_hooks.py new file mode 100644 index 000000000..bee797dca --- /dev/null +++ b/mindnlp/core/distributed/algorithms/ddp_comm_hooks/mixed_precision_hooks.py @@ -0,0 +1,86 @@ +from dataclasses import dataclass +from typing import Any, no_type_check + +from mindnlp import core +from mindnlp import core.distributed as dist +from core.autograd import Variable +from core.distributed.utils import _free_storage + + +@dataclass +class _AllreduceUpcastHookState: + """ + State to manage DDP mixed precision in backward / gradient communication. + + This contains a weakref to the DDP module for access to reducer and process + group, and a stream to run parameter and gradient upcasts. + """ + + ddp_weakref: Any + upcast_stream: core.Stream + wait_for_stream_enqueued: bool = False + + +@no_type_check +def _reducer_allreduce_and_upcast_hook( + hook_state: _AllreduceUpcastHookState, bucket: dist.GradBucket +) -> core.futures.Future[core.Tensor]: + """ + Perform allreduce in precision ``reduce_dtype``, upcast to prepare for optimizer. + + Performs allreduce in the reduced precision given by DDP's mixed precision + reduce_dtype, and upcasts parameters and gradients to fp32 in preparation + to run the optimizer. + """ + ddp_weakref = hook_state.ddp_weakref + reducer, process_group = ddp_weakref().reducer, ddp_weakref().process_group + # Cast bucket if different than param_dtype. + if ( + ddp_weakref().mixed_precision.param_dtype + != ddp_weakref().mixed_precision.reduce_dtype + ): + # Cast bucket tensor to reduce_dtype + bucket.set_buffer( + bucket.buffer().to(ddp_weakref().mixed_precision.reduce_dtype) + ) + fut = reducer._run_allreduce_hook(bucket) + ret_fut = core.futures.Future() + stream = hook_state.upcast_stream + with core.get_device_module().stream(stream): + fut.wait() + bucket.buffer().div_(process_group.size()) + ret_fut.set_result(bucket.buffer()) + + # Upcast parameters and gradients so optimizer step can run in fp32. + for p in bucket.parameters(): + p.data = p._fp_param + # free storage for mp param as it will be allocated again in next + # forward pass. + _free_storage(p._mp_param) + p.grad.data = p.grad.to(p.data.dtype) + + # enqueue a callback to wait for this stream at end of backward + def wait_for_stream_cb(): + core.accelerator.current_stream().wait_stream(stream) + # Remove post-backward hooks since they are re-installed in next + # iteration, similar to FSDP. + # Parameters that don't require grad still needed to be casted since + # they may participate in computation. However, they would not be recast + # by hook above as they don't have a grad hook installed, so cast them + # back here. + for _, p in ddp_weakref().module.named_parameters(): + if hasattr(p, "_ddp_mp_hook_state"): + p._ddp_mp_hook_state[1].remove() + delattr(p, "_ddp_mp_hook_state") + if not p.requires_grad and not hasattr(p, "_ddp_ignored"): + p.data = p._fp_param + + # reset for next backward pass + hook_state.wait_for_stream_enqueued = False + + if not hook_state.wait_for_stream_enqueued: + Variable._execution_engine.queue_callback(wait_for_stream_cb) + # mark that the callback is enqueued + hook_state.wait_for_stream_enqueued = True + + return ret_fut diff --git a/mindnlp/core/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py b/mindnlp/core/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py new file mode 100644 index 000000000..45a7b33d7 --- /dev/null +++ b/mindnlp/core/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py @@ -0,0 +1,162 @@ +# mypy: allow-untyped-defs +from dataclasses import dataclass +from functools import partial +from typing import Any, Callable, List, no_type_check + +from mindnlp import core +from mindnlp import core.distributed as dist +from core.autograd import Variable + + +__all__: List[str] = [] + +_FUNCTIONAL_OPTIM_STEP_METHOD_NAME = "step_param" + + +class _OptimizerHookState: + """ + Holds state for running optimizer in-line after DDP communication hook. + + Currently contains only optimizer class which must have a method `step_param`. + """ + + __slots__ = ["functional_optimizer", "params_to_optimize"] + + def __init__(self, functional_optim, params=None): + self.functional_optimizer = functional_optim + self._check_valid_functional_optim() + self._set_params_to_optimize(params) + + def _set_params_to_optimize(self, params): + if params is not None: + self.params_to_optimize = set(params) + + def _check_valid_functional_optim(self): + if not hasattr(self.functional_optimizer, _FUNCTIONAL_OPTIM_STEP_METHOD_NAME): + raise ValueError( + f"Class {type(self.functional_optimizer)} must implement method " + f"{_FUNCTIONAL_OPTIM_STEP_METHOD_NAME}." + ) + + +@dataclass +class _OptimInBackwardHookState: + optim_stream: core.Stream + wait_for_optim_stream_enqueued: bool + + +@no_type_check +def _apply_optim_in_backward_hook( + gradient_is_bucket_view: bool, +) -> Callable[[Any, dist.GradBucket], core.futures.Future[core.Tensor]]: + r""" + Register hook to apply the optimizer in backward. + + If core.distributed.optim._apply_optimizer_in_backward is used to overlap + optimizer with backward pass, DDP will run the below hook to run optimizer + step for parameters after gradient communication has taken place. + """ + optim_in_bwd_state = _OptimInBackwardHookState( + optim_stream=core.Stream(), + wait_for_optim_stream_enqueued=False, + ) + + def apply_optim_in_backward_hook( + hook_state: Any, + bucket: dist.GradBucket, + optim_stream_state, + ) -> core.futures.Future[core.Tensor]: + # Run original hook + ddp_weakref = hook_state + ddp_inst = ddp_weakref() + reducer, process_group = ddp_inst.reducer, ddp_inst.process_group + fut = reducer._run_allreduce_hook(bucket) + optimizer_stream = optim_stream_state.optim_stream + with core.get_device_module().stream(optimizer_stream): + fut.wait() + # Apply gradient division since C++ side only allreduces and does + # not average. TODO: (rohan-varma) the div factor may be different + # when running with join hook + bucket.buffer().div_(process_group.size()) + model_params = bucket.parameters() + grads = bucket.gradients() + # TODO (rohan-varma): upcast as needed for DDP mixed precision, + # once optimizer in backward + DDP mixed precision is supported. + for p, g in zip(model_params, grads): + if hasattr(p, "_in_backward_optimizers"): + # Note: need to set grad to the bucket's grad, because + # running allreduce results in the bucket's grad being + # reduced, but not grad field. + if not gradient_is_bucket_view: + p.grad = g + for optim in p._in_backward_optimizers: + optim.step() + + # Need to return a Future[Tensor] to obey comm hook API contract. + ret_fut = core.futures.Future() + ret_fut.set_result(bucket.buffer()) + + # enqueue a callback to wait for this optimizer stream at the end of + # backward and set all DDP managed grads to None. + def wait_for_optim_stream_callback(): + core.accelerator.current_stream().wait_stream( + optim_stream_state.optim_stream + ) + # Set DDP managed grads to None + for param in ddp_inst._get_data_parallel_params(ddp_inst.module): + if hasattr(param, "_in_backward_optimizers"): + param.grad = None + + # reset for the next backwards pass + optim_stream_state.wait_for_optim_stream_enqueued = False + + if not optim_stream_state.wait_for_optim_stream_enqueued: + Variable._execution_engine.queue_callback(wait_for_optim_stream_callback) + # mark that the callback is enqueued + optim_stream_state.wait_for_optim_stream_enqueued = True + + return ret_fut + + comm_hook = partial( + apply_optim_in_backward_hook, optim_stream_state=optim_in_bwd_state + ) + # These are needed for DDP's logging of comm hooks + comm_hook.__name__ = apply_optim_in_backward_hook.__name__ + comm_hook.__qualname__ = apply_optim_in_backward_hook.__qualname__ + + return comm_hook + + +def _hook_then_optimizer( + hook: Callable[[Any, dist.GradBucket], core.futures.Future[core.Tensor]], + optimizer_state: _OptimizerHookState, +) -> Callable[[Any, dist.GradBucket], core.futures.Future[core.Tensor]]: + r"""Run optimizer in a functional fashion after DDP communication hook.""" + has_set_params = ( + hasattr(optimizer_state, "params_to_optimize") + and optimizer_state.params_to_optimize is not None + ) + + def hook_then_optimizer_wrapper( + hook_state, bucket: dist.GradBucket + ) -> core.futures.Future[core.Tensor]: + # Run original hook + fut = hook(hook_state, bucket) + + def optimizer_step(fut): + gradient_tensors = bucket.gradients() + model_params = bucket.parameters() + for grad_tensor, model_param in zip(gradient_tensors, model_params): + if ( + not has_set_params + or model_param in optimizer_state.params_to_optimize + ): + optimizer_state.functional_optimizer.step_param( + model_param, + grad_tensor, + ) + return bucket.buffer() + + return fut.then(optimizer_step) + + return hook_then_optimizer_wrapper diff --git a/mindnlp/core/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py b/mindnlp/core/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py new file mode 100644 index 000000000..ac08aa7ce --- /dev/null +++ b/mindnlp/core/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py @@ -0,0 +1,124 @@ +# mypy: allow-untyped-defs +import logging + +from mindnlp import core +from mindnlp import core.distributed as dist + +from . import default_hooks as default + + +logger = logging.getLogger(__name__) + + +class PostLocalSGDState: + r""" + Store state for all-reducing gradients globally until given step, then locally after. + + Stores the state for all-reducing gradients globally using ``process_group`` until step ``start_localSGD_iter``, + and all-reducing gradients locally using ``subgroup`` afterwards. + + If ``process_group`` is ``None``, the global process group will be used. + If ``subgroup`` is ``None``, the intra-node process group on each machine will be used. + + Additionally, ``post_local_gradient_allreduce`` may be worth tuning, + because both true and false may give a faster convergence. + """ + + __slots__ = [ + "process_group", + "subgroup", + "start_localSGD_iter", + "post_local_gradient_allreduce", + "iter", + ] + + def __init__( + self, + process_group, + subgroup, + start_localSGD_iter, + post_local_gradient_allreduce=True, + ): + """Initialize state object with given parameters and log when localSGD start.""" + logger.info( + "Local SGD will be started after %s iterations", start_localSGD_iter + ) + + # The group used for all-reducing gradients globally. + self.process_group = process_group + # The group used for all-reducing gradients locally. + self.subgroup = subgroup + self.start_localSGD_iter = start_localSGD_iter + # Allreduce gradients locally since iteration `start_localSGD_iter`. + # This may help with the convergence efficiency at the cost of relatively cheap intra-subgroup communication. + self.post_local_gradient_allreduce = post_local_gradient_allreduce + # Iteration/step in the training loop. + self.iter = 0 + + def maybe_increase_iter(self, bucket): + """Track iterations and trigger log message at start of local SGD.""" + # Since bucket 0 is the last bucket to allreduce in an iteration. + # Only increase `iter` when bucket 0 is processed. + if bucket.is_last(): + self.iter += 1 + + if self.iter == self.start_localSGD_iter: + logger.info("Start to apply local SGD after %s iterations.", self.iter) + + +def post_localSGD_hook( + state: PostLocalSGDState, bucket: dist.GradBucket +) -> core.futures.Future[core.Tensor]: + """ + Run post-localSGD algorithm. + + This DDP communication hook is used for running post-localSGD algorithm, + by combining with a model averaging component (e.g., + :class:`~core.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager`) + that runs after the optimizer step. + + Args: + state (PostLocalSGDState): State information to run post-localSGD. + Users mainly need to tune ``start_localSGD_iter`` to determine when to start local SGD. + bucket (dist.GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors. + Note that since DDP comm hook only supports single process single device mode, + only exactly one tensor is stored in this bucket. + + Returns: + Future handler of the communication, which updates the gradients in place. + + Example:: + >>> # xdoctest: +SKIP + >>> state = PostLocalSGDState(process_group=process_group, subgroup=subgroup, + start_localSGD_iter=10) + >>> ddp_model.register_comm_hook(state, post_localSGD_hook) + >>> # Also need to establish a model averaging module and run model averaging after ``optimizer.step()``. + >>> # Please refer to the examples in ``core.distributed.algorithms.model_averaging.averagers`` module. + """ + global_group_to_use = ( + state.process_group if state.process_group is not None else dist.group.WORLD + ) + + # The input tensor is a flattened 1D tensor. + input_tensor = bucket.buffer() + + # Run allreduce using `global_group_to_use` in the first `start_localSGD_iter` iterations. + if state.iter < state.start_localSGD_iter: + state.maybe_increase_iter(bucket) + return default._allreduce_fut(global_group_to_use, input_tensor) # type: ignore[arg-type] + + # If `post_local_gradient_allreduce` is not set, + # then no gradient synchronization after the first `start_localSGD_iter` iterations. + if not state.post_local_gradient_allreduce: + fut: core.futures.Future[core.Tensor] = core.futures.Future() + fut.set_result(input_tensor) + return fut + + # Run allreduce using `subgroup` after the first `start_localSGD_iter` iterations. + # Note that by default, a separate subgroup for each node is created which + # causes an intra-node allreduce to be done at each training step. + # From this moment, model averaging should run after the optimizer step, + # to globally allreduce all the parameters. + if state.subgroup is None: + state.subgroup, _ = dist.new_subgroups() + return default._allreduce_fut(state.subgroup, input_tensor) diff --git a/mindnlp/core/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py b/mindnlp/core/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py new file mode 100644 index 000000000..90fbb1b5d --- /dev/null +++ b/mindnlp/core/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py @@ -0,0 +1,861 @@ +# mypy: allow-untyped-defs +import logging +import math +from collections import defaultdict +from typing import Dict + +from mindnlp import core +from mindnlp import core.distributed as dist +from core.distributed import distributed_c10d +from core.utils._typing_utils import not_none + +from . import default_hooks as default + + +__all__ = ["PowerSGDState", "powerSGD_hook", "batched_powerSGD_hook"] + +logger = logging.getLogger(__name__) + + +def _orthogonalize(matrices, epsilon=0): + """ + Decide between Gram-Schmidt or QR factorization to orthogonalize a batch of matrices. + + QR factorization doesn't work with half-precision, but it is usually faster with a rank > 2. + """ + assert len(matrices.shape) == 3 and matrices.shape[2] <= matrices.shape[1] + + num_matrices = matrices.shape[0] + rank = matrices.shape[2] + dtype = matrices.dtype + if rank <= 2 or dtype in [core.float16, core.bfloat16]: + _orthogonalize_gram_schmidt(matrices, epsilon=epsilon) + else: + core.linalg.qr( + matrices, + out=( + matrices, + core.empty( + num_matrices, rank, rank, device=matrices.device, dtype=dtype + ), + ), + ) + + +def _orthogonalize_gram_schmidt(matrices, epsilon=0): + """ + Apply Gram-Schmidt procedure to orthogonalize a batch of matrices. + + If epsilon is 0, this is equivalent to `core.qr(matrices, out=(matrices, _))`, + """ + num_cols = matrices.shape[2] + for i in range(num_cols): + # Normalize the i'th column. + col = matrices[:, :, i : i + 1] + # If no epsilon is added here, division by zero may be caused by vanishing gradients. + # This epsilon is not needed if the input batch of matrices covers the gradients of at least one entire layer + # in the neural network. + if epsilon == 0: + # Note that col ** 2 can underflow/overflow if we use FP16. + # May need to consider multiplying a scaling factor and dividing it later, or using bfloat16 instead. + try: + col /= core.norm(col, dim=1, keepdim=True) + except ZeroDivisionError: + logger.error( + "The matrices to be orthogonalized has at least a column of all 0s. Please set a small value such as 1e-8 " + "as `orthogonalization_epsilon` in PowerSGD state." + ) + # Recover the values from NaNs to 0s. + col.fill_(0.0) + else: + col /= core.norm(col, dim=1, keepdim=True) + epsilon + # Project it on the rest and remove it. + if i + 1 < num_cols: + rest = matrices[:, :, i + 1 :] + rest -= core.sum(col * rest, dim=1, keepdim=True) * col + + +def _should_compress( + num_rows, num_cols, matrix_approximation_rank, min_compression_rate +): + """ + Recommend if tensor given is worth compressing. + + Returns a recommendation as to whether the 2D tensor described by the arguments is worth compressing, + including statistics describing the expected savings from compression. We consider a tensor worth + compressing when ``min_compression_rate`` < uncompressed size / compressed size, where + uncompressed size = ``num_rows`` * ``num_cols``, + and compressed size = (``num_rows`` + ``num_cols``) * ``matrix_approximation_rank``. + + The result of this function is a tuple of the form (compression_recommendation, uncompressed_el_count, compressed_el_count), where: + + compression_recommendation is true if the tensor is worth compressing, and false otherwise (see above); + + uncompressed_el_count is the uncompressed element count, i.e. ``num_rows`` * ``num_cols``; and, + + compress_el_count is the element count after compression, i.e. (``num_rows`` + ``num_cols``) * ``matrix_approximation_rank``. + """ # noqa: B950 + uncompressed_size = num_rows * num_cols + compressed_size = (num_rows + num_cols) * matrix_approximation_rank + return ( + compressed_size * min_compression_rate < uncompressed_size, + uncompressed_size, + compressed_size, + ) + + +def _report_compression_stats(bucket, state): + """Report compression stats at frequency of ``compression_stats_logging_frequency`` specified in PowerSGD state.""" + if bucket.is_last() and state.iter >= state.next_stats_report: + stats = state.compression_stats() + logger.info( + "Compression stats: iter %s, total before compression %s, total after compression %s, " + "rate %s", + state.iter, + stats[1], + stats[2], + stats[0], + ) + state.next_stats_report = state.iter + state.compression_stats_logging_frequency + + +class PowerSGDState: + r""" + Store both the algorithm's hyperparameters and internal state for all gradients during training. + + Particularly, ``matrix_approximation_rank`` and ``start_powerSGD_iter`` are the main hyperparameters that should be tuned by the user. + For performance, we suggest to keep binary hyperparameters ``use_error_feedback`` and ``warm_start`` on. + + 1. ``matrix_approximation_rank`` controls the size of compressed low-rank tensors, which determines the compression rate. The lower the rank, the stronger the compression. + + 1.1. If ``matrix_approximation_rank`` is too low, the full model quality will need more training steps to reach or will never reach and yield loss in accuracy. + + 1.2. The increase of ``matrix_approximation_rank`` can substantially increase the computation costs of the compression, and the accuracy may not be further improved beyond a certain ``matrix_approximation_rank`` threshold. + + To tune ``matrix_approximation_rank``, we suggest to start from 1 and increase by factors of 2 (like an exponential grid search, 1, 2, 4, ...), until a satisfactory accuracy is reached. Typically only a small value 1-4 is used. For some NLP tasks (as shown in Appendix D of the original paper), this value has been increased to 32. + + 2. ``start_powerSGD_iter`` defers PowerSGD compression until step ``start_powerSGD_iter``, and vanilla allreduce runs prior to step ``start_powerSGD_iter``. This hybrid scheme of **vanilla allreduce + PowerSGD** can effectively improve the accuracy, even a relatively small ``matrix_approximation_rank`` is used. This is because that, the beginning of training phase is usually very sensitive to inaccurate gradients, and compressing gradients too early may make the training quickly take a suboptimal trajectory, which can result in an irrecoverable impact on the accuracy. + + To tune ``start_powerSGD_iter``, we suggest to start with 10% of total training steps, and increase it until a satisfactory accuracy is reached. If there is a warm-up stage in the training, ``start_powerSGD_iter`` typically should be no less than the number of warm-up steps. + + 3. ``min_compression_rate`` is the minimum compression rate required when a layer is compressed. Due to the computation overheads incurred by the compression, a tensor is worth compressing only if there can be sufficient saving in bandwidth, where ``(num_rows + num_cols) * matrix_approximation_rank * min_compression_rate < num_rows * num_cols``. If the specified compression rate threshold cannot be satisfied, the tensor will be directly allreduced without compression. + + Compression statistics are logged every ``compression_stats_logging_frequency`` iterations once PowerSGD compression starts. + + 4. ``orthogonalization_epsilon`` can be a very small value (e.g., 1e-8) added to every normalized matrix column in orthogonalization step, to prevent div-by-zero error if any column has all 0s. If this can already be prevented (e.g., by batch normalization), an epsilon of 0 is recommended for accuracy. + + 5. ``batch_tensors_with_same_shape`` controls whether to compress and decompress tensors with same shape in a batched operation to achieve higher parallelism. Note that you should also increase the bucket size (i.e., ``bucket_cap_mb`` arg in DDP constructor) to make more same-shaped tensors appear in the same bucket, however this may reduce the overlap between computation and communication, and increase the memory footprint due to stacking the tensors of the same shape. Set to ``True`` if the compression / decompression computation is a bottleneck. + + .. warning :: + If error feedback or warm-up is enabled, the minimum value of ``start_powerSGD_iter`` allowed in DDP is 2. + This is because there is another internal optimization that rebuilds buckets at iteration 1 in DDP, + and this can conflict with any tensor memorized before the rebuild process. + """ # noqa: B950 + + __slots__ = [ + "process_group", + # The fields below are the hyperparameters that often need to be tuned by the user. + "matrix_approximation_rank", + "start_powerSGD_iter", + # The fields below are the hyperparameters that seldom need be tuned by the user. + "min_compression_rate", + "orthogonalization_epsilon", + # The fields below are the binary hyperparameters recommended to be turned on for performance and accuracy. + "use_error_feedback", + "warm_start", + "batch_tensors_with_same_shape", + # The fields below are internal state. + "rng", + "error_dict", + "p_memory_dict", + "q_memory_dict", + "iter", + # The fields below are for recording compression stats. + "total_numel_before_compression", + "total_numel_after_compression", + "compression_stats_logging_frequency", + "next_stats_report", + ] + + def __init__( + self, + process_group, + matrix_approximation_rank=1, + start_powerSGD_iter=1_000, + min_compression_rate=2, + use_error_feedback=True, + warm_start=True, + orthogonalization_epsilon=0, + random_seed=0, + compression_stats_logging_frequency=10_000, + batch_tensors_with_same_shape: bool = False, + ): + logger.info( + "PowerSGD config: matrix_approximation_rank = %s; start_powerSGD_iter = %s; " + "min_compression_rate = %s; orthogonalization_epsilon = %s; use_error_feedback = %s; warm_start = %s; " + "random_seed = %s; compression_stats_logging_frequency = %s; batch_tensors_with_same_shape = %s", + matrix_approximation_rank, + start_powerSGD_iter, + min_compression_rate, + orthogonalization_epsilon, + use_error_feedback, + warm_start, + random_seed, + compression_stats_logging_frequency, + batch_tensors_with_same_shape, + ) + + self.process_group = process_group + self.matrix_approximation_rank = matrix_approximation_rank + # Deferring PowerSGD compression util step 'start_powerSGD_iter' can have two advantages: + # 1) It turns out that PowerSGD may lead to a non-trivial accuracy loss, + # even if the matrix approximation rank is increased to a large value. + # To mitigate the accuracy loss, a simple yet effective way is mixing vanilla allreduce + # (or a more conservative compression such as FP16 compression) with PowerSGD. + # 2) There is an internal optimization of rebuilding buckets process in DDP, + # in order to save the memory space. + # This step takes place after the first iteration. + # However, this means that the shape of input bucketized tensors is subject to change, + # which will complicate the implementations of error feedback and warm-up. + # Running vanilla allreduce in the first few iterations can avoid this complexity. + if (use_error_feedback or warm_start) and start_powerSGD_iter <= 1: + raise ValueError( + "Expect `start_powerSGD_iter` > 1 if `use_error_feedback` or `warm_start` is enabled, " + "because PowerSGD can only be applied after the first two iterations in DDP." + ) + self.start_powerSGD_iter = start_powerSGD_iter + self.min_compression_rate = min_compression_rate + # Error feedback is usually crucial for both for convergence and generalization, + # because PowerSGD is a biased compressor, + # i.e., compressing and decompressing a random gradient does not yield the original in expectation. + # This mechanism requires a temporary copy of the input gradients, + # so it increases the peak memory consumption by the size of the gradient tensor. + # However, if the target matrices are known to be exactly low-ranked (instead of just low stable rank), + # sometimes it is possible to converge to the optima without error feedback. + # See: http://proceedings.mlr.press/v54/yurtsever17a/yurtsever17a.pdf + self.use_error_feedback = use_error_feedback + # Warm-start reuses P(s) and Q(s) from the previous iteration. + # This can improve the approximation quality and hence improve the accuracy. + # Additionally, by avoiding the initialization of these low-rank tensors at every step, + # this can also accelerate training. + # However, this is at the cost of extra memory. + self.warm_start = warm_start + # Can use a very small value to prevent div-by-zero error caused by orthogonalization of vanishing gradients. + self.orthogonalization_epsilon = orthogonalization_epsilon + # The purpose of this RNG is to generate different random seeds for initializing Q across iterations, + # but in the same order for all the DDP replicas. + # Different random seeds across iterations indicate different 'projections' of the gradients at different SGD steps. + # If the same random projection is used, + # there will be differences between the gradients that are never synchronized. + import numpy as np + + self.rng = np.random.RandomState(random_seed) + # Since there is only a single state instance for all the input buckets, + # need to maintain a dictionary that maps each bucket index to the local error. + self.error_dict: Dict[int, core.Tensor] = {} + self.p_memory_dict: Dict[int, core.Tensor] = {} + self.q_memory_dict: Dict[int, core.Tensor] = {} + # Iteration/step in the training loop. + self.iter = 0 + # Compression stats accumulators + self.total_numel_before_compression = 0 + self.total_numel_after_compression = 0 + # We'll report compression stats every 'compression_stats_logging_frequency' iterations + # Note that we always report compression stats at least once. + self.compression_stats_logging_frequency = max( + 1, compression_stats_logging_frequency + ) + self.next_stats_report = 0 + # Batching tensors with same shape can increase parallelism in compression / decompression computation. + # This requires a larger bucket size to make more same-shaped tensor to appear in one bucket, however + # this may reduce the overlap between computation and communication, and increase the memory footprint + # due to stacking tensors. + # Turn on if compression / decompression computation is a bottleneck. + self.batch_tensors_with_same_shape = batch_tensors_with_same_shape + + def __getstate__(self): + r""" + Return a ``Dict[str, Any]`` which will be pickled and saved. + + ``process_group`` is not serializable and excluded from + a returned state. + """ + logger.warning( + "NOTE: Process group is not serializable and excluded from a saved state." + ) + return { + slot: getattr(self, slot) + for slot in self.__slots__ + if slot != "process_group" + } + + def __setstate__(self, state): + r""" + Take a provided ``state`` and set to this ``PowerSGDState`` instance. + + ``process_group`` is set to default. + """ + self.process_group = distributed_c10d._get_default_group() + logger.warning( + "NOTE: Process group will be set to a default group (i.e. the world size).\ + If a different group is desired, please set `self.process_group` after PowerSGD state is loaded." + ) + for slot, value in state.items(): + setattr(self, slot, value) + + def maybe_increase_iter(self, bucket): + """Track iterations and trigger log message at start of local SGD.""" + # Since bucket 0 is the last bucket to allreduce in an iteration. + # Only increase `iter` when bucket 0 is processed. + if bucket.is_last(): + self.iter += 1 + + if self.iter == self.start_powerSGD_iter: + logger.info("Start to apply PowerSGD after %s iterations.", self.iter) + + def compression_stats(self): + r""" + Return latest compression statistics as tuple. + + Returns tuple of form (compress_rate, numel_before_compression, numel_after_compression) where: + + compress_rate is the effective compression rate i.e. (number of elements before compression) / (number of elements after compression); + + numel_before_compression is the total number of elements before compression was applied; and, + + numel_after_compression is the total number of elements after compression was applied. + """ # noqa: B950 + compress_rate = ( + self.total_numel_before_compression / self.total_numel_after_compression + if self.total_numel_after_compression > 0 + else 0 + ) + return ( + compress_rate, + self.total_numel_before_compression, + self.total_numel_after_compression, + ) + + +def powerSGD_hook( + state: PowerSGDState, bucket: dist.GradBucket +) -> core.futures.Future[core.Tensor]: + r""" + Implement PowerSGD algorithm. + + This DDP communication hook implements PowerSGD gradient compression + algorithm described in the `paper `_. + Once gradient tensors are aggregated across all workers, this hook applies + compression as follows: + + 1. Views the input flattened 1D gradient tensor as a list of per-parameter tensors, and divides all the tensors into two groups: + + 1.1 The tensors that should be compressed before allreduce, because the compression can give enough saving in bandwidth. + + 1.2 Rest of the tensors will be directly allreduced without compression, including all the vector tensors (for biases). + + 2. Handles uncompressed tensors: + + 2.1. Allocate contiguous memory for those uncompressed tensors, and allreduces all the uncompressed tensors as a batch, without compression; + + 2.2. Copies the individual uncompressed tensors from the contiguous memory back to the input tensor. + + 3. Handles the tensors that should be compressed by PowerSGD compression: + + 3.1. For each tensor M, creates two low-rank tensors P and Q for decomposing M, + such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized; + + 3.2. Computes each P in Ps, which is equal to MQ; + + 3.3. Allreduces Ps as a batch; + + 3.4. Orthogonalizes each P in Ps; + + 3.5. Computes each Q in Qs, which is approximately equal to M^TP; + + 3.6. Allreduces Qs as a batch; + + 3.7. Computes each M among all the compressed tensors, which is approximately equal to PQ^T. + + Note that this communication hook enforces vanilla allreduce for the first ``state.start_powerSGD_iter`` iterations. + This not only gives the user more control over the tradeoff between speedup and accuracy, + but also helps abstract away some complexity of the internal optimization of DDP for future communication hook developers. + + Args: + state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc. + To tune the compression configs, mainly need to tune ``matrix_approximation_rank``, ``start_powerSGD_iter`` + and ``min_compression_rate``. + bucket (dist.GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors. + Note that since DDP comm hook only supports single process single device mode, + only exactly one tensor is stored in this bucket. + + Returns: + Future handler of the communication, which updates the gradients in place. + + Example:: + >>> # xdoctest: +SKIP + >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, + start_powerSGD_iter=10, min_compression_rate=0.5) + >>> ddp_model.register_comm_hook(state, powerSGD_hook) + """ # noqa: B950 + process_group = state.process_group + group_to_use = ( + process_group if process_group is not None else not_none(dist.group.WORLD) + ) + world_size = group_to_use.size() + + # The input tensor is a flattened 1D tensor. + input_tensor = bucket.buffer() + + # Run vanilla allreduce in the first `start_powerSGD_iter` iterations. + if state.iter < state.start_powerSGD_iter: + state.maybe_increase_iter(bucket) + return default._allreduce_fut(group_to_use, input_tensor) + + # Apply PowerSGD after `start_powerSGD_iter` iterations. + device = input_tensor.device + dtype = input_tensor.dtype + + # Incorporate the error from the previous state into the gradients. + bucket_index = bucket.index() + input_tensor_cp = None + total_length = input_tensor.shape[0] + if state.use_error_feedback: + if bucket_index in state.error_dict: + input_tensor.add_(state.error_dict[bucket_index]) + else: + logger.info( + "A zero tensor of length %s that represents local error is created.", + total_length, + ) + state.error_dict[bucket_index] = core.zeros( + total_length, device=device, dtype=dtype + ) + + # Keep a copy of the input tensor, + # so that we can compute the local error caused by compression later, + # by comparing this copy and the input tensor updated after decompression. + input_tensor_cp = core.clone(input_tensor).detach() + + # Unflatten the input tensor into per-parameter tensors, for layer-wise compression. + tensors = bucket.gradients() + + # Step I: Divide all the tensors into two groups, + # one will be compressed before allreduce and the other will be directly allreduced without compression. + tensors_to_compress, uncompressed_tensors = [], [] + total_Ps_size = 0 + total_Qs_size = 0 + for tensor in tensors: + matrix = tensor.view(tensor.shape[0], -1) + n, m = matrix.shape + matrix_approximation_rank = min(n, m, state.matrix_approximation_rank) + compress_test = _should_compress( + n, m, matrix_approximation_rank, state.min_compression_rate + ) + state.total_numel_before_compression += compress_test[1] + if compress_test[0]: + tensors_to_compress.append(matrix) + total_Ps_size += n * matrix_approximation_rank + total_Qs_size += m * matrix_approximation_rank + state.total_numel_after_compression += compress_test[2] + else: + uncompressed_tensors.append(tensor) + state.total_numel_after_compression += compress_test[1] + + _report_compression_stats(bucket, state) + + # Step II: Handle uncompressed tensors. + # Allocate contiguous memory for these tensors to allreduce efficiently. + uncompressed_tensors_memory = ( + core.cat([tensor.view(-1) for tensor in uncompressed_tensors]) + if uncompressed_tensors + else core.tensor([], device=device, dtype=dtype) + ) + + # Step III: Handle the tensors that should be compressed. + # Allocate contiguous memory for Ps and Qs to allreduce efficiently. + # If warm-start is enabled, reuse Ps and Qs from the previous iteration if possible. + # The memory spaces of Ps and Qs need to be allocated in the first iteration when PowerSGD is applied. + need_randomize_qs = False + if not state.warm_start or bucket_index not in state.p_memory_dict: + need_randomize_qs = True + # If warm-start is disabled, low-rank tensors will be initialized at every step. + # Only log this if warm-start to avoid spamming. + if state.warm_start: + logger.info( + "Allocating contiguous memory of length %s for Ps, and of length %s for Qs, respectively.", + total_Ps_size, + total_Qs_size, + ) + state.p_memory_dict[bucket_index] = core.empty( + total_Ps_size, device=device, dtype=dtype + ) + state.q_memory_dict[bucket_index] = core.empty( + total_Qs_size, device=device, dtype=dtype + ) + + # Batch tensors to compress by shape. + shape_to_tensors = defaultdict(list) + for tensor in tensors_to_compress: + shape_to_tensors[tensor.shape].append(tensor) + + # This function decides whether to batch tensors with same shape or not according to the argument, + # so the following process could share the same code. + def maybe_batched_tensors_to_compress(): + for tensors in shape_to_tensors.values(): + if state.batch_tensors_with_same_shape: + batch_size = len(tensors) + if batch_size == 1: + # Use the original tensor to avoid copy. + yield tensors[0].unsqueeze(0) + else: + yield core.stack(tensors) + else: + for tensor in tensors: + yield tensor.unsqueeze(0) + + # Create Ps and Qs that point to the allocated memory. + tensors_to_compress = [] + ps = [] + qs = [] + p_idx = 0 + q_idx = 0 + for tensor in maybe_batched_tensors_to_compress(): + batch_size, n, m = tensor.shape + matrix_approximation_rank = min(n, m, state.matrix_approximation_rank) + tensors_to_compress.append(tensor) + ps.append( + state.p_memory_dict[bucket_index][ + p_idx : p_idx + batch_size * n * matrix_approximation_rank + ].view(batch_size, n, matrix_approximation_rank) + ) + qs.append( + state.q_memory_dict[bucket_index][ + q_idx : q_idx + batch_size * m * matrix_approximation_rank + ].view(batch_size, m, matrix_approximation_rank) + ) + p_idx += batch_size * n * matrix_approximation_rank + q_idx += batch_size * m * matrix_approximation_rank + + # If warm-start is enabled, reuse Qs from the previous iteration if possible and skip filling random values. + # The exception is the first iteration when PowerSGD is applied. + if not need_randomize_qs: + for q in qs: + _orthogonalize(q, state.orthogonalization_epsilon) + else: + with core.random.fork_rng(devices=[]): + # Fork this RNG to avoid changing the seed globally and affecting the random sampling anywhere else in the training. + # The seed makes sure that the initial random values are the same across all the DDP replicas. + # This seed should differ at every step. + # Since it is very slow to fork RNG state across all the CUDA devices, + # only fork on CPU and then move the generated tensor to the CUDA device (by overwriting q). + core.manual_seed(state.rng.randint(1_000_000_000)) + for q in qs: + q.copy_( + core.randn( + *q.shape, + device="cpu", + dtype=dtype, + ) + ) + _orthogonalize(q, state.orthogonalization_epsilon) + + # Compute Ps. + for tensor, q, p in zip(tensors_to_compress, qs, ps): + core.bmm(tensor, q, out=p) + + # This allreduce is only applied to uncompressed tensors, + # so it should have been kicked off before the above computation on the compressed tensors to hide more communication costs. + # However, this somehow requires a separate future chain at this time. + allreduce_contiguous_uncompressed_tensors_fut = dist.all_reduce( + uncompressed_tensors_memory, group=group_to_use, async_op=True + ).get_future() + + def unpack_uncompressed_tensors_and_allreduce_ps(fut): + uncompressed_tensors_memory = fut.value()[0].div_(world_size) + idx = 0 + for tensor in uncompressed_tensors: + tensor.copy_( + uncompressed_tensors_memory[idx : idx + tensor.numel()].view_as(tensor) + ) + idx += tensor.numel() + + # Since these Ps will be orthogonalized later, no need to divide them by world size. + return ( + dist.all_reduce( + state.p_memory_dict[bucket_index], group=group_to_use, async_op=True + ) + .get_future() + .wait()[0] + ) + + def compute_qs(fut): + state.p_memory_dict[bucket_index] = fut.value() + for p in ps: + _orthogonalize(p, state.orthogonalization_epsilon) + + # Compute Qs. + for tensor, p, q in zip(tensors_to_compress, ps, qs): + core.bmm(tensor.transpose(1, 2), p, out=q) + + # TODO: The above procedure does two matmul+allreduce steps per iteration -- + # one left multiplication and one right multiplication. + # For warm-start, can take one such step at a time, and alternate between them. + + # Allreduce Qs. + return ( + dist.all_reduce( + state.q_memory_dict[bucket_index], group=group_to_use, async_op=True + ) + .get_future() + .wait()[0] + ) + + def decompress(fut): + state.q_memory_dict[bucket_index] = fut.value().div_(world_size) + + for p, q, tensor in zip(ps, qs, tensors_to_compress): + core.bmm(p, q.transpose(1, 2), out=tensor) + + # Copy batched tensors back to original buffer. + if state.batch_tensors_with_same_shape: + for tensor in tensors_to_compress: + if tensor.shape[0] == 1: + # Skip tensor with batch_size == 1 since itself is the original tensor. + continue + original_tensors = shape_to_tensors[tensor.shape[1:]] + for i, original_tensor in enumerate(original_tensors): + original_tensor.copy_(tensor[i]) + + if core.cuda.is_available(): + core.cuda.synchronize(device) + + if state.use_error_feedback: + # Memorize the local errors. + state.error_dict[bucket_index] = input_tensor_cp - input_tensor + if not state.warm_start: + state.p_memory_dict.clear() + state.q_memory_dict.clear() + + state.maybe_increase_iter(bucket) + + return input_tensor + + return ( + allreduce_contiguous_uncompressed_tensors_fut.then( + unpack_uncompressed_tensors_and_allreduce_ps + ) + .then(compute_qs) + .then(decompress) + ) + + +def batched_powerSGD_hook( + state: PowerSGDState, bucket: dist.GradBucket +) -> core.futures.Future[core.Tensor]: + r""" + Implement simplified PowerSGD algorithm. + + This DDP communication hook implements a simplified PowerSGD gradient compression + algorithm described in the `paper `_. + This variant does not compress the gradients layer by layer, + but instead compresses the flattened input tensor that batches all the gradients. + Therefore, it is **faster** than :meth:`powerSGD_hook`, + but usually results in a **much lower accuracy**, unless ``matrix_approximation_rank`` is 1. + + .. warning :: + Increasing ``matrix_approximation_rank`` here may not necessarily increase the accuracy, + because batching per-parameter tensors without column/row alignment can destroy low-rank structure. + Therefore, the user should always consider :meth:`powerSGD_hook` first, + and only consider this variant when a satisfactory accuracy can be achieved when ``matrix_approximation_rank`` is 1. + + Once gradient tensors are aggregated across all workers, this hook applies + compression as follows: + + 1. Views the input flattened 1D gradient tensor as a square-shaped tensor M with 0 paddings; + + 2. Creates two low-rank tensors P and Q for decomposing M, such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized; + + 3. Computes P, which is equal to MQ; + + 4. Allreduces P; + + 5. Orthogonalizes P; + + 6. Computes Q, which is approximately equal to M^TP; + + 7. Allreduces Q; + + 8. Computes M, which is approximately equal to PQ^T. + + 9. Truncates the input tensor to the original length. + + Note that this communication hook enforces vanilla allreduce for the first ``state.start_powerSGD_iter`` iterations. + This not only gives the user more control over the tradeoff between speedup and accuracy, + but also helps abstract away some complexity of the internal optimization of DDP for future communication hook developers. + + Args: + state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc. + To tune the compression configs, mainly need to tune ``matrix_approximation_rank`` and ``start_powerSGD_iter``. + bucket (dist.GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors. + Note that since DDP comm hook only supports single process single device mode, + only exactly one tensor is stored in this bucket. + + Returns: + Future handler of the communication, which updates the gradients in place. + + Example:: + >>> # xdoctest: +SKIP + >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1) + >>> ddp_model.register_comm_hook(state, batched_powerSGD_hook) + """ # noqa: B950 + process_group = state.process_group + group_to_use = ( + process_group if process_group is not None else not_none(dist.group.WORLD) + ) + world_size = group_to_use.size() + + # The input tensor is a flattened 1D tensor. + input_tensor = bucket.buffer() + + # Run vanilla allreduce in the first `start_powerSGD_iter` iterations. + if state.iter < state.start_powerSGD_iter: + state.maybe_increase_iter(bucket) + return default._allreduce_fut(group_to_use, input_tensor) + + # Apply PowerSGD after `start_powerSGD_iter` iterations. + device = input_tensor.device + total_length = input_tensor.shape[0] + state.total_numel_before_compression += total_length + + # View the input tensor as a 2D square-shape tensor, and pad 0s if necessary. + square_side_length = math.ceil(math.sqrt(total_length)) + state.total_numel_after_compression += ( + square_side_length * state.matrix_approximation_rank * 2 + ) + padded_total_length = square_side_length**2 + input_tensor.resize_(padded_total_length) + input_tensor[total_length:padded_total_length].fill_(0) + + _report_compression_stats(bucket, state) + + # Incorporate the error from the previous state into the gradients. + bucket_index = bucket.index() + input_tensor_cp = None + if state.use_error_feedback: + if bucket_index in state.error_dict: + input_tensor.add_(state.error_dict[bucket_index]) + else: + logger.info( + "A zero tensor of length %s that represents local error is created.", + padded_total_length, + ) + state.error_dict[bucket_index] = core.zeros( + padded_total_length, device=device, dtype=input_tensor.dtype + ) + + # Keep a copy of the input tensor, + # so that we can compute the local error caused by compression later, + # by comparing this copy and the input tensor updated after decompression. + input_tensor_cp = core.clone(input_tensor).detach() + matrix = input_tensor.view(square_side_length, square_side_length) + + # Reuse P and Q from the previous iteration if possible. + # The memory spaces of P and Q need to be allocated in the first iteration when PowerSGD is applied. + if not state.warm_start or bucket_index not in state.p_memory_dict: + # If warm-start is disabled, low-rank tensors will be initialized at every step. + # Only log this if warm-start to avoid spamming. + if state.warm_start: + logger.info( + "Initializing low-rank tensors P and Q, each of which has a shape of %s x %s.", + square_side_length, + state.matrix_approximation_rank, + ) + + def create_low_rank_tensor(fill_random_values, rng): + """Return a low-rank 2D tensor of square_side_length * matrix_approximation_rank.""" + if fill_random_values: + with core.random.fork_rng(devices=[]): + # Fork this RNG to avoid changing the seed globally and affecting the random sampling + # anywhere else in the training. + # The seed makes sure that the initial random values are the same across all the DDP replicas. + # This seed should differ at every step. + # Since it is very slow to fork RNG state across all the CUDA devices, + # only fork on CPU and then move the generated tensor to the CUDA device. + core.manual_seed(rng.randint(1_000_000_000)) + return core.randn( + square_side_length, + state.matrix_approximation_rank, + device="cpu", + dtype=input_tensor.dtype, + ).to(device) + else: + return core.empty( + square_side_length, + state.matrix_approximation_rank, + device=device, + dtype=input_tensor.dtype, + ) + + state.p_memory_dict[bucket_index] = create_low_rank_tensor( + fill_random_values=False, rng=state.rng + ) + state.q_memory_dict[bucket_index] = create_low_rank_tensor( + fill_random_values=True, rng=state.rng + ) + _orthogonalize(state.q_memory_dict[bucket_index]) + + core.matmul( + matrix, state.q_memory_dict[bucket_index], out=state.p_memory_dict[bucket_index] + ) + allreduce_p_fut = dist.all_reduce( + state.p_memory_dict[bucket_index], group=group_to_use, async_op=True + ).get_future() + + def compute_q(fut): + state.p_memory_dict[bucket_index] = fut.value()[0] + _orthogonalize(state.p_memory_dict[bucket_index]) + + core.matmul( + matrix.t(), + state.p_memory_dict[bucket_index], + out=state.q_memory_dict[bucket_index], + ) + + # TODO: The above procedure does two matmul+allreduce steps per iteration -- + # one left multiplication and one right multiplication. + # For warm-start, can take one such step at a time, and alternate between them. + + return ( + dist.all_reduce( + state.q_memory_dict[bucket_index], group=group_to_use, async_op=True + ) + .get_future() + .wait()[0] + ) + + def decompress(fut): + state.q_memory_dict[bucket_index] = fut.value().div_(world_size) + core.matmul( + state.p_memory_dict[bucket_index], + state.q_memory_dict[bucket_index].t(), + out=matrix, + ) + + if state.use_error_feedback: + # Memorize the local errors. + state.error_dict[bucket_index] = input_tensor_cp - input_tensor + # Removing this seemingly unnecessary sync somehow may cause failures. + # See: https://github.com/pytorch/pytorch/pull/54838 + if core.cuda.is_available(): + core.cuda.synchronize(device) + if not state.warm_start: + state.p_memory_dict.clear() + state.q_memory_dict.clear() + ret = input_tensor.resize_(total_length) + + state.maybe_increase_iter(bucket) + + return ret + + return allreduce_p_fut.then(compute_q).then(decompress) diff --git a/mindnlp/core/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py b/mindnlp/core/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py new file mode 100644 index 000000000..93cc180cf --- /dev/null +++ b/mindnlp/core/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py @@ -0,0 +1,218 @@ +# mypy: allow-untyped-defs +from mindnlp import core +from mindnlp import core.distributed as dist +from mindnlp.core import nn + + +def _quantize_per_tensor_backend(x, scale, zero_point): + y = core.round(x / scale) + zero_point + y = core.clamp(y, 0, 255).to(core.uint8) + return y + + +def _dequantize_per_tensor_backend(y, scale, zero_point): + x = scale * (y.to(core.float32) - zero_point) + return x + + +def _quantize_per_channel_backend(x, scale, zero_point): + y = core.zeros(x.size(), device=x.device) + for i in range(x.size()[0]): + y[i, :] = core.round(x[i, :] / scale[i]) + zero_point[i] + y = core.clamp(y, 0, 255).to(core.uint8) + return y + + +def _dequantize_per_channel_backend(y, scale, zero_point): + y = y.to(core.float32).to(y.device) + x = core.zeros_like(y, device=y.device) + for i in range(x.size()[0]): + x[i, :] = scale[i] * (y[i, :] - zero_point[i]) + return x + + +def _get_allgather_out_list(all_gather_in_list, world_size): + out_list = [ + core.zeros_like( + all_gather_in_list, + device=all_gather_in_list.device, + dtype=all_gather_in_list.dtype, + ) + for _ in range(world_size) + ] + return out_list + + +def quantization_pertensor_hook( + process_group: dist.ProcessGroup, bucket: dist.GradBucket +) -> core.futures.Future[core.Tensor]: + """ + Apply ``core.quantize_per_tensor`` logic to DDP using ``allgather`` protocol. + + Workers first allgather the scale and zero point of their own + ``GradBucket`` prior to the quantization. After all workers have that information, + the first ``then`` callback called ``quantize_and_allgather`` quantizes worker's + own gradient tensor, and uses ``allgather`` to communicate these across all workers. + The final ``then`` callback called ``dequantize_and_aggregate``, dequantizes and + aggregates each quantized gradient tensor locally and returns the mean. + + .. warning :: + This is experimental, and uses ``allgather`` protocol which is considerably slower than + ``allreduce`` protocol. It works only with flattened grads. + + Example:: + >>> # xdoctest: +SKIP + >>> ddp_model.register_comm_hook(process_group, quantization_pertensor_hook) + """ + group_to_use = process_group if process_group is not None else dist.group.WORLD + rank = process_group.rank() if process_group is not None else dist.get_rank() + world_size = group_to_use.size() + + tensor = bucket.buffer() + + myObserver = core.ao.quantization.MinMaxObserver().to(tensor.device) + myObserver(tensor) + + s, z = myObserver.calculate_qparams() + s_and_z = core.FloatTensor([s, z]).to(tensor.device) + + all_ranks_s_and_z = _get_allgather_out_list(s_and_z, world_size) + + # First, allgather scale and zeros. + fut = dist.all_gather( + all_ranks_s_and_z, s_and_z, group=group_to_use, async_op=True + ).get_future() + + def quantize_and_allgather(fut): + # Store scale and zeros across all workers. + all_ranks_s_and_z = fut.wait()[0] + # All workers quantize their own ``GradBucket`` tensors. + quantized_tensor = _quantize_per_tensor_backend( + tensor, all_ranks_s_and_z[rank][0], all_ranks_s_and_z[rank][1] + ) + # Allgather quantized tensors. + fut = dist.all_gather( + _get_allgather_out_list(quantized_tensor, world_size), + quantized_tensor, + group=group_to_use, + async_op=True, + ).get_future() + + return fut.wait() + + def dequantize_and_aggregate(fut): + all_ranks_quantized_tensor = fut.wait()[0] + + aggregated_dequantized_tensor = core.zeros_like( + all_ranks_quantized_tensor[0], device=tensor.device, dtype=core.float32 + ) + # Using previously allgathered scales and zeros, dequantize gradient tensors + # locally and then aggregate them. + for r, quantized_tensor in enumerate(all_ranks_quantized_tensor): + aggregated_dequantized_tensor += _dequantize_per_tensor_backend( + quantized_tensor, all_ranks_s_and_z[r][0], all_ranks_s_and_z[r][1] + ) + + return aggregated_dequantized_tensor / world_size + + return fut.then(quantize_and_allgather).then(dequantize_and_aggregate) + + +def quantization_perchannel_hook( + process_group: dist.ProcessGroup, bucket: dist.GradBucket, bucket_size=512 +) -> core.futures.Future[core.Tensor]: + """ + Apply``core.quantize_per_channel`` logic to DDP using ``allgather`` protocol. + + Compared to per-tensor, the main motivation of per-channel is + for considerably large tensors such as a tensor that contains 6 million + elements quantizing per a bucket size of 512 (or 128) elements may significantly + increase the resolution. + + It first splits ``GradBucket`` tensor into multiple chunks (channels) of ``bucket_size`` + elements. Then, workers allgather the scales and zero points of their own + ``GradBucket`` prior to the quantization. After all workers have that information, + the first ``then`` callback called ``quantize_and_allgather`` quantizes worker's + own gradient tensor, and uses ``allgather`` to communicate these across all workers. + The final ``then`` callback called ``dequantize_and_aggregate``, dequantizes, flattens, and + aggregates each quantized gradient tensor locally and returns the mean. + + .. warning :: + This is experimental, and uses ``allgather`` protocol which is considerably slower than + ``allreduce`` protocol. It works only with flattened grads. + + Example:: + >>> # xdoctest: +SKIP + >>> ddp_model.register_comm_hook(process_group, quantization_perchannel_hook) + """ + group_to_use = process_group if process_group is not None else dist.group.WORLD + rank = process_group.rank() if process_group is not None else dist.get_rank() + world_size = group_to_use.size() + + tensor = bucket.buffer() + + tensor_in_channels = ( + nn.functional.pad( + input=tensor, + pad=(0, bucket_size - len(tensor) % bucket_size), + mode="constant", + value=0, + ) + .view(-1, bucket_size) + .to(tensor.device) + ) + + myPerChannelObserver = core.ao.quantization.PerChannelMinMaxObserver().to( + tensor.device + ) + myPerChannelObserver(tensor_in_channels) + + s_ch, z_ch = myPerChannelObserver.calculate_qparams() + s_and_z = core.stack((s_ch, z_ch)).to(tensor.device) + + all_ranks_s_and_z = _get_allgather_out_list(s_and_z, world_size) + # First, allgather scale and zeros. + fut = dist.all_gather( + all_ranks_s_and_z, s_and_z, group=group_to_use, async_op=True + ).get_future() + + def quantize_and_allgather(fut): + # Store scale and zeros across all workers. + all_ranks_s_and_z = fut.wait()[0] + # All workers quantize their corresponding ``GradBucket`` tensors. + quantized_tensor = _quantize_per_channel_backend( + tensor_in_channels, + all_ranks_s_and_z[rank, 0, :], + all_ranks_s_and_z[rank, 1, :], + ) + # Allgather quantized tensors. + fut = dist.all_gather( + _get_allgather_out_list(quantized_tensor, world_size), + quantized_tensor, + group=group_to_use, + async_op=True, + ).get_future() + + return fut.wait() + + def dequantize_and_aggregate(fut): + all_ranks_quantized_tensor = fut.wait()[0] + + aggregated_dequantized_tensor = core.zeros_like( + all_ranks_quantized_tensor[0], device=tensor.device, dtype=core.float32 + ) + # Using previously allgathered scales and zeros, dequantize gradient tensors + # locally and then aggregate them. + for r, quantized_tensor in enumerate(all_ranks_quantized_tensor): + aggregated_dequantized_tensor += _dequantize_per_channel_backend( + quantized_tensor, all_ranks_s_and_z[r][0], all_ranks_s_and_z[r][1] + ) + + return ( + core.flatten(aggregated_dequantized_tensor).to(tensor.device)[ + : tensor.size()[0] + ] + / world_size + ) + + return fut.then(quantize_and_allgather).then(dequantize_and_aggregate) diff --git a/mindnlp/core/distributed/algorithms/join.py b/mindnlp/core/distributed/algorithms/join.py new file mode 100644 index 000000000..131aeb7e5 --- /dev/null +++ b/mindnlp/core/distributed/algorithms/join.py @@ -0,0 +1,349 @@ +# mypy: allow-untyped-defs +import warnings +from abc import ABC, abstractmethod +from types import TracebackType +from typing import Any, List, NamedTuple, Optional, Type + +from mindnlp import core +from mindnlp.core import distributed as dist + + +__all__ = ["JoinHook", "Joinable", "Join"] + + +class JoinHook: + r""" + This defines a join hook, which provides two entry points in the join context manager. + + Entry points : a main hook, which is called repeatedly while there exists a non-joined + process, and a post-hook, which is called once all processes have joined. + + To implement a join hook for the generic join context manager, define a + class that inherits from :class:`JoinHook` and override ``main_hook()`` and + ``post_hook()`` as appropriate. + """ + + def main_hook(self) -> None: + r"""Call this hook while there exists a non-joined process to shadow collective communications in a training iteration. + + Training iteration i.e., in one forward pass, backward pass, and optimizer step. + """ + + def post_hook(self, is_last_joiner: bool) -> None: + r""" + Call hook after all processes have joined. + + It is passed an additional ``bool`` argument ``is_last_joiner``, which indicates if the rank is one of the last to join. + + Arguments: + is_last_joiner (bool): ``True`` if the rank is one of the last to + join; ``False`` otherwise. + """ + + +class Joinable(ABC): + r""" + This defines an abstract base class for joinable classes. + + A joinable class + (inheriting from :class:`Joinable`) should implement :meth:`join_hook`, + which returns a :class:`JoinHook` instance, in addition to + :meth:`join_device` and :meth:`join_process_group` that return device and + process group information, respectively. + """ + + @abstractmethod + def __init__(self) -> None: + super().__init__() + self._join_config = _JoinConfig.construct_disabled_join_config() + + @abstractmethod + def join_hook(self, **kwargs) -> JoinHook: + r""" + Return a :class:`JoinHook` instance for the given :class:`Joinable`. + + Arguments: + kwargs (dict): a :class:`dict` containing any keyword arguments + to modify the behavior of the join hook at run time; all + :class:`Joinable` instances sharing the same join context + manager are forwarded the same value for ``kwargs``. + """ + ... + + @property + @abstractmethod + def join_device(self) -> core.device: + r"""Return the device from which to perform collective communications needed by the join context manager.""" + ... + + @property + @abstractmethod + def join_process_group(self) -> Any: + r"""Returns the process group for the collective communications needed by the join context manager itself.""" + ... + + +class _JoinConfig(NamedTuple): + r"""This includes all fields needed from a :class:`Joinable` instance for the join context manager side.""" + + enable: bool + throw_on_early_termination: bool + is_first_joinable: bool + + @staticmethod + def construct_disabled_join_config(): + r"""Return a :class:`_JoinConfig` instance indicating that join-related logic should be disabled. + + e.g. if the caller is not in a join context manager. + """ + return _JoinConfig( + enable=False, throw_on_early_termination=False, is_first_joinable=False + ) + + +class Join: + r""" + This class defines the generic join context manager, which allows custom hooks to be called after a process joins. + + These hooks should shadow the + collective communications of non-joined processes to prevent hanging and + erroring and to ensure algorithmic correctness. Refer to :class:`JoinHook` + for details about the hook definition. + + .. warning:: + The context manager requires each participating :class:`Joinable` to + call the method :meth:`notify_join_context()` before its own per- + iteration collective communications to ensure correctness. + + .. warning:: + The context manager requires that all ``process_group`` attributes in + the :class:`JoinHook` objects are the same. If there are multiple + :class:`JoinHook` objects, then the ``device`` of the first is used. + The process group and device information is used for checking for non- + joined processes and for notifying processes to throw an exception if + ``throw_on_early_termination`` is enabled, both of which using an all- + reduce. + + Arguments: + joinables (List[Joinable]): a list of the participating + :class:`Joinable` s; their hooks are iterated over in the given + order. + + enable (bool): a flag enabling uneven input detection; setting to + ``False`` disables the context manager's functionality and should + only be set when the user knows the inputs will not be uneven + (default: ``True``). + + throw_on_early_termination (bool): a flag controlling whether to throw an + exception upon detecting uneven inputs (default: ``False``). + + Example:: + + >>> import os + >>> from mindnlp import core + >>> from mindnlp import core.distributed as dist + >>> from mindnlp import core.multiprocessing as mp + >>> # xdoctest: +SKIP + >>> from mindnlp import core.nn.parallel.DistributedDataParallel as DDP + >>> from mindnlp import core.distributed.optim.ZeroRedundancyOptimizer as ZeRO + >>> from core.distributed.algorithms.join import Join + >>> + >>> # On each spawned worker + >>> def worker(rank): + >>> dist.init_process_group("nccl", rank=rank, world_size=2) + >>> model = DDP(core.nn.Linear(1, 1).to(rank), device_ids=[rank]) + >>> optim = ZeRO(model.parameters(), core.optim.Adam, lr=0.01) + >>> # Rank 1 gets one more input than rank 0 + >>> inputs = [core.tensor([1.]).to(rank) for _ in range(10 + rank)] + >>> with Join([model, optim]): + >>> for input in inputs: + >>> loss = model(input).sum() + >>> loss.backward() + >>> optim.step() + >>> # All ranks reach here without hanging/erroring + """ + + def __init__( + self, + joinables: List[Joinable], + enable: bool = True, + throw_on_early_termination: bool = False, + **kwargs, + ): + if len(joinables) == 0: + raise ValueError("The join context manager requires at least one joinable") + self._joinables = joinables + self._join_hooks = [ + joinable.join_hook(**kwargs) for joinable in self._joinables + ] + self._enable = enable + self._throw_on_early_termination = throw_on_early_termination + self._set_joinable_configs() + self._extract_dist_info() + + def _set_joinable_configs(self) -> None: + r"""Set the :class:`_JoinConfig` of each participating :class:`Joinable`.""" + assert len(self._joinables) > 0 + is_first_joinable = True + for joinable in self._joinables: + joinable._join_config = _JoinConfig( + enable=self._enable, + throw_on_early_termination=self._throw_on_early_termination, + is_first_joinable=is_first_joinable, + ) + is_first_joinable = False + + def _extract_dist_info(self) -> None: + r""" + Extract the process group and device information from the joinables. + + If there are multiple joinables, then the context manager uses the + first specified device. + + Preconditions: + ``self._joinables`` is not ``None`` and is non-empty. + + Raises: + ValueError + If there are multiple conflicting ``process_group`` attributes + among the ``Joinable`` objects. + """ + process_group = None + device = None + for joinable in self._joinables: + if process_group is None: + process_group = joinable.join_process_group + elif process_group != joinable.join_process_group: + raise ValueError( + "Using join context manager with multiple process groups" + ) + if device is None: + device = joinable.join_device + self._process_group = process_group + self._rank = dist.get_rank(self._process_group) + self._device = device + + def __enter__(self): + ... + + def __exit__( + self, + type: Optional[Type[BaseException]], + value: Optional[BaseException], + traceback: Optional[TracebackType], + ): + r""" + Repeatedly runs the main hooks until all processes join; then, runs the post-hooks. + + Raises: + RuntimeError + If ``throw_on_early_termination=True``. + """ + if not self._enable or type: + return # propagate the exception directly if one was raised + + all_procs_joined = False + is_last_joiner = True + + i = 0 + WARN_THRESHOLD = 1000 + warnings.simplefilter("once") + + while not all_procs_joined: + if i > WARN_THRESHOLD: + warnings.warn( + "Detected uneven input skew of greater than " + f"{WARN_THRESHOLD}. This means that rank " + f"{self._rank} has at least {WARN_THRESHOLD} " + f"fewer inputs than other currently-active ranks. " + "This level of skew could lead to performance " + "degradation during training." + ) + # Shadow the all-reduce in non-joined processes + num_nonjoined_procs = self._get_num_nonjoined_procs() + if num_nonjoined_procs == 0: + all_procs_joined = True + else: + if self._throw_on_early_termination: + self._notify_procs_to_terminate() + + # Run main hooks + for join_hook in self._join_hooks: + join_hook.main_hook() + + is_last_joiner = False + i += 1 + + # Run post-hooks + for join_hook in self._join_hooks: + join_hook.post_hook(is_last_joiner) + + def _get_num_nonjoined_procs(self): + r"""Return the number of non-joined processes by shadowing an all-reduce in the non-joined processes.""" + num_nonjoined_procs = core.zeros(1, device=self._device) + dist.all_reduce(num_nonjoined_procs, group=self._process_group) + return num_nonjoined_procs.item() + + def _notify_procs_to_terminate(self): + r"""Schedule an all-reduce to notify non-joined processes to terminate. + + Also raise a ``RuntimeError`` indicating that the current process has exhausted its inputs. + """ + ones = core.ones(1, device=self._device) + dist.all_reduce(ones, group=self._process_group) + raise RuntimeError(f"Rank {self._rank} exhausted all inputs.") + + @staticmethod + def notify_join_context(joinable: Joinable): + r""" + Notifies the join context manager that the calling process has not yet joined. + + Then, if ``throw_on_early_termination=True``, checks if uneven inputs have been detected + (i.e. if one process has already joined) and throws an exception if so. + + This method should be called from a :class:`Joinable` object before + its per-iteration collective communications. For example, this should + be called at the beginning of the forward pass in + :class:`DistributedDataParallel`. + + Only the first :class:`Joinable` object passed into the context + manager performs the collective communications in this method, and + for the others, this method is vacuous. + + Arguments: + joinable (Joinable): the :class:`Joinable` object calling this + method. + + Returns: + An async work handle for the all-reduce meant to notify the context + manager that the process has not yet joined if ``joinable`` is the + first one passed into the context manager; ``None`` otherwise. + """ + assert hasattr(joinable, "_join_config"), ( + f"Check that the {type(joinable)} constructor calls the " + "``Joinable`` constructor" + ) + + join_config = joinable._join_config + # First joinable is responsible for the collective communications + if not join_config.is_first_joinable or not join_config.enable: + return None + + device = joinable.join_device + process_group = joinable.join_process_group + + # Schedule an all-reduce to indicate that the caller has not yet joined + ones = core.ones(1, device=device) + work = dist.all_reduce(ones, group=process_group, async_op=True) + + if join_config.throw_on_early_termination: + # Check if uneven inputs have been detected + zeros = core.zeros(1, device=device) + dist.all_reduce(zeros, group=process_group) + should_throw = zeros.item() + if should_throw: + raise RuntimeError( + "Detected at least one rank that exhausted inputs. " + "Throwing across all ranks." + ) + return work diff --git a/mindnlp/core/distributed/algorithms/model_averaging/__init__.py b/mindnlp/core/distributed/algorithms/model_averaging/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/mindnlp/core/distributed/algorithms/model_averaging/averagers.py b/mindnlp/core/distributed/algorithms/model_averaging/averagers.py new file mode 100644 index 000000000..165e8f649 --- /dev/null +++ b/mindnlp/core/distributed/algorithms/model_averaging/averagers.py @@ -0,0 +1,129 @@ +# mypy: allow-untyped-defs +import warnings +from abc import ABC, abstractmethod +from typing import Dict, Iterable, Optional, Union + +from mindnlp import core +from mindnlp import core.distributed as dist +from mindnlp import core.distributed.algorithms.model_averaging.utils as utils +from core.utils._typing_utils import not_none as _not_none + + +__all__ = ["ModelAverager", "PeriodicModelAverager"] + + +class ModelAverager(ABC): + r"""Base class for all model averagers. + + Args: + process_group: The process group to be used for all-reduce. + If ``None``, the default process group, which + is created by :func:`core.distributed.init_process_group`, + will be used. (default: ``None``) + """ + + def __init__(self, process_group: Optional[dist.ProcessGroup] = None): + self.process_group = ( + process_group if process_group is not None else _not_none(dist.group.WORLD) + ) + self.step = 0 + + @abstractmethod + def average_parameters(self, params): + raise NotImplementedError + + +class PeriodicModelAverager(ModelAverager): + r""" + Averages parameters periodically after the warm-up stage. + + This can be used for running `post-local SGD `_, + by running :class:`~core.nn.DistributedDataParallel` (DDP) + using the subgroups created by :meth:`~core.distributed.new_subgroups`. + + Args: + period (int): The number of steps per model averaging. + Usually the period should be greater than ``1`` to reduce the communication cost. + Otherwise, only DDP needs to be used. + warmup_steps (int): The number of warm-up steps. During this stage, + model averaging is skipped. + process_group: The process group to be used for all-reduce. + If ``None``, the default process group, which + is created by :func:`core.distributed.init_process_group`, + will be used. (default: ``None``) + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> from mindnlp import core + >>> from mindnlp import core.distributed as dist + >>> from mindnlp import core.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD + >>> from mindnlp import core.distributed.algorithms.model_averaging.averagers as averagers + >>> from mindnlp import core.nn as nn + >>> + >>> dist.init_process_group("nccl", rank=rank, world_size=16) + >>> core.cuda.set_device(rank) + >>> module = nn.Linear(1, 1, bias=False).cuda() + >>> model = nn.parallel.DistributedDataParallel( + >>> module, device_ids=[rank], output_device=rank + >>> ) + >>> # Register a post-localSGD communication hook. + >>> state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100) + >>> model.register_comm_hook(state, post_localSGD_hook) + >>> + >>> # In the first 100 steps, run global gradient averaging like normal DDP at every step. + >>> # After 100 steps, run model averaging every 4 steps. + >>> # Note that ``warmup_steps`` must be the same as ``start_localSGD_iter`` used in ``PostLocalSGDState``. + >>> averager = averagers.PeriodicModelAverager(period=4, warmup_steps=100) + >>> for step in range(0, 200): + >>> optimizer.zero_grad() + >>> loss = loss_fn(output, labels) + >>> loss.backward() + >>> optimizer.step() + >>> # Will average model parameters globally every 4 steps. Thus, + >>> # inter-node communication only occurs every 4 iterations after + >>> # the initial ``warmup_steps`` period. + >>> averager.average_parameters(model.parameters()) + """ + + def __init__( + self, period, warmup_steps=0, process_group: Optional[dist.ProcessGroup] = None + ): + super().__init__(process_group) + if warmup_steps < 0: + raise ValueError("Arg ``warmup_steps`` must be a non-negative number.") + self.warmup_steps = warmup_steps + if period < 1: + raise ValueError("Arg ``period`` must be a positive value.") + elif period == 1: + warnings.warn( + "When period is 1, no need to use model averaging because the communication cost " + "of all-reducing parameters will be no less than the cost of all-reducing gradients " + "by DistributedDataParallel in the backward pass. Therefore, only " + "DistributedDataParallel should be used for this case." + ) + self.period = period + + def average_parameters( + self, + params: Union[ + Iterable[core.nn.Parameter], Iterable[Dict[str, core.nn.Parameter]] + ], + ): + """ + Averages parameters or parameter groups of an optimizer if ``step`` is no less than ``warmup_steps``. + + Can be divided by ``period``, where ``step`` is increased by 1 + at each iteration in the training loop. + Args: + params: The parameters of a model or parameter groups of an optimizer. + + """ + if ( + self.step >= self.warmup_steps + and (self.step - self.warmup_steps) % self.period == 0 + ): + utils.average_parameters_or_parameter_groups( + params, _not_none(self.process_group) + ) + self.step += 1 diff --git a/mindnlp/core/distributed/algorithms/model_averaging/hierarchical_model_averager.py b/mindnlp/core/distributed/algorithms/model_averaging/hierarchical_model_averager.py new file mode 100644 index 000000000..9b31e7ee6 --- /dev/null +++ b/mindnlp/core/distributed/algorithms/model_averaging/hierarchical_model_averager.py @@ -0,0 +1,180 @@ +# mypy: allow-untyped-defs +# Copyright 2022 Cruise LLC +import logging +import warnings +from collections import OrderedDict +from typing import Dict, Iterable, Union + +from mindnlp import core +from mindnlp import core.distributed as dist +from mindnlp import core.distributed.algorithms.model_averaging.averagers as averagers +from mindnlp import core.distributed.algorithms.model_averaging.utils as utils + + +logger = logging.getLogger(__name__) + + +class HierarchicalModelAverager(averagers.ModelAverager): + r""" + Runs hierarchical model averaging (`hierarchical SGD `_). + + Process groups of different sizes are organized in a hierarchy, and they average parameters + by using different periods concurrently after the warm-up stage. + This is an extension of :class:`~core.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager` + that supports `post-local SGD `_, which essentially only supports + a two-level hierarchy: the intra-machine level and the global level, where the intra-machine + level is usually embedded in :meth:`~core.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook`. + Similarly, the process groups within this class do not have such an intra-machine process + subgroup, which should be embedded by the post-local SGD communication hook instead. + + Args: + period_group_size_dict: An ordered dict mapping keys of model averaging period to + process group size, used for initializing process groups of + different sizes in a hierarchy to average parameters concurrently. + Particularly, at each iteration, there will be at most a single + process group that runs averaging -- the period of such group should + have the largest period which the current step can be divided by. + For example, if the dict has three keys: 2, 4, and 8, + then this means totally three process groups will be created to + average parameters every 2, 4, and 8 iterations, respectively. + At the 4th iteration, only the second process group will run + averaging, because the first process group should be a + subset of the second process group, and no need to execute the first + process group redundantly. + On the other hand, the third process group can only be triggered + every 8 iterations, so it will not be triggered at the 4th iteration. + warmup_steps (int): The number of warm-up steps. During this stage, model averaging is skipped. + process_group (ProcessGroup, optional): The overall process group containing all the processes that runs model averaging. + If ``None``, the default process group, which is created + by :func:`core.distributed.init_process_group`, will be used. + (default: ``None``) + + Example:: + >>> # xdoctest: +SKIP('undefined rank') + >>> from collections import OrderedDict + >>> from mindnlp import core + >>> from mindnlp import core.distributed as dist + >>> from core.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import ( + >>> PostLocalSGDState, + >>> post_localSGD_hook, + >>> ) + >>> from mindnlp import core.distributed.algorithms.model_averaging.hierarchical_model_averager as hierarchicalSGD + >>> from mindnlp import core.nn as nn + >>> + >>> dist.init_process_group("nccl", rank=rank, world_size=16) + >>> core.cuda.set_device(rank) + >>> module = nn.Linear(1, 1, bias=False).to(rank) + >>> model = nn.parallel.DistributedDataParallel( + >>> module, device_ids=[rank], output_device=rank + >>> ) + >>> # Register a post-localSGD communication hook. + >>> # Assume that each machine has 4 GPUs, then each intra-machine subgroup has a size of 4. + >>> subgroup, _ = dist.new_subgroups() + >>> state = PostLocalSGDState(process_group=None, subgroup=subgroup, start_localSGD_iter=100) + >>> model.register_comm_hook(state, post_localSGD_hook) + >>> + >>> # Average parameters among each group of 8 processes every 4 iterations, and among all + >>> # the 16 processes every 16 iterations. + >>> averager = hierarchicalSGD.HierarchicalModelAverager( + >>> period_group_size_dict=OrderedDict([(4, 8), (16, 16)]), warmup_steps=100) + >>> # Note that ``warmup_steps`` must be the same as ``start_localSGD_iter`` used in ``PostLocalSGDState``. + >>> # In the first 100 steps, run global gradient averaging like normal DDP at every step. + >>> # After 100 steps, run model averaging at two levels. + >>> for step in range(0, 200): + >>> optimizer.zero_grad() + >>> loss = loss_fn(output, labels) + >>> loss.backward() + >>> optimizer.step() + >>> # Average parameters after ``optimizer.step()``. + >>> # Thus, the inter-node communication only occurs periodically after ``warmup_steps``. + >>> averager.average_parameters(model.parameters()) + + .. warning :: + The last group size in the dict must be the size of the provided ``process_group``, + which indicates model averaging at the highest level of the hierarchy. + If ``process_group`` is not provided, then the last group size should be equal to the world size. + + .. warning :: + `HierarchicalModelAverager` is experimental and subject to change. + """ + + def __init__(self, period_group_size_dict=None, warmup_steps=0, process_group=None): + super().__init__(process_group) + if not period_group_size_dict: + raise ValueError("Arg ``period_group_size_dict`` must not be empty.") + self._periods = list(period_group_size_dict.keys()) + if self._periods[0] <= 0: + raise ValueError( + "The minimum period in arg ``period_group_size_dict`` must be a positive value." + ) + elif self._periods[-1] == 1: + warnings.warn( + "When the maximum period in arg ``period_group_size_dict`` is 1, " + "no need to use model averaging because the communication cost " + "of all-reducing parameters will be no less than the cost of all-reducing gradients " + "by DistributedDataParallel in the backward pass. Therefore, only " + "DistributedDataParallel should be used for this case." + ) + overall_group_size = dist.get_world_size(group=self.process_group) + if list(period_group_size_dict.values())[-1] != overall_group_size: + raise ValueError( + f"The last value in arg ``period_process_group_dict`` {list(period_group_size_dict.values())[-1]} " + f"must be equal to the size of arg ``process_group`` {overall_group_size}." + ) + + self.period_process_group_dict = OrderedDict() + logger.info("Model averaging hierarchy:") + for period, group_size in period_group_size_dict.items(): + logger.info( + "\tEach group that has %s processes average parameters every %s iterations, " + "if no higher-level averaging.", + group_size, + period, + ) + if group_size != overall_group_size: + self.period_process_group_dict[period], _ = dist.new_subgroups( + group_size=group_size, group=self.process_group + ) + else: + self.period_process_group_dict[period] = self.process_group + + if warmup_steps < 0: + raise ValueError("Arg ``warmup_steps`` must be a non-negative number.") + self.warmup_steps = warmup_steps + + def _find_process_group(self): + """ + Return a process group as the value of an ``period_process_group_dict`` entry. + + If ``step`` can be divided by multiple periods in the keys of ``period_process_group_dict``, + then the returned process group is the one corresponding to the largest period, + since this process group will be used for averaging parameters at this ``step``. + Returns ``None`` if not found. + """ + for period in reversed(self._periods): + if self.step % period == 0: + return self.period_process_group_dict[period] + return None + + def average_parameters( + self, + params: Union[ + Iterable[core.nn.Parameter], Iterable[Dict[str, core.nn.Parameter]] + ], + ): + """ + Averages parameters or parameter groups of an optimizer. + + Averaging only occurs if ``step`` is no less than ``warmup_steps`` + and it can be divided by a period in the keys of ``period_process_group_dict``, + where ``step`` is increased by 1 at each iteration in the training loop. + If ``step`` can be divided by multiple periods in the keys of ``period_process_group_dict``, + only the largest period is used, and the corresponding process group is used for averaging parameters. + Args: + params: The parameters of a model or parameter groups of an optimizer. + """ + if self.step >= self.warmup_steps: + group = self._find_process_group() + if group is not None: + utils.average_parameters_or_parameter_groups(params, group) + self.step += 1 diff --git a/mindnlp/core/distributed/algorithms/model_averaging/utils.py b/mindnlp/core/distributed/algorithms/model_averaging/utils.py new file mode 100644 index 000000000..d37e50a39 --- /dev/null +++ b/mindnlp/core/distributed/algorithms/model_averaging/utils.py @@ -0,0 +1,89 @@ +# mypy: allow-untyped-defs +# flake8: noqa C101 +import itertools +from typing import Dict, Iterable, Iterator, Union + +from mindnlp import core +from mindnlp import core.distributed as dist + +# The two imports below are not always available depending on the +# USE_DISTRIBUTED compile flag. Make sure they raise import error +# if we're trying to use them. +from core.distributed import group, ProcessGroup + + +__all__ = [ + "average_parameters", + "get_params_to_average", + "average_parameters_or_parameter_groups", +] + + +def average_parameters( + params: Iterator[core.nn.Parameter], process_group: ProcessGroup +): + """ + Averages all the given parameters. + + For allreduce efficiency, all the parameters are flattened into a contiguous buffer. + Thus, it requires extra memory of the same size as the given parameters. + """ + group_to_use = process_group if process_group is not None else group.WORLD + # Do not update any parameter if not in the process group. + if dist._rank_not_in_group(group_to_use): + return + + params_it1, params_it2 = itertools.tee(params) + # If the input parameters have different data types, + # packing these parameters will trigger an implicit type up-casting. + # The original parameter data types will be restored during the subsequent unpacking. + flat_params = core.cat([p.data.reshape(-1) for p in params_it1]) + flat_params /= dist.get_world_size(group_to_use) + # Make sure the allreduce will not conflict with any other ongoing process group. + if core.accelerator.is_available(): + core.accelerator.synchronize() + dist.all_reduce(flat_params, group=group_to_use) + + offset = 0 + for p in params_it2: + p.data = flat_params[offset : offset + p.numel()].view_as(p).type_as(p) + offset += p.numel() + + +def get_params_to_average( + params: Union[Iterable[core.nn.Parameter], Iterable[Dict[str, core.nn.Parameter]]] +): + """ + Return a list of parameters that need to average. + + This filters out the parameters that do not contain any gradients. + Args: + params: The parameters of a model or parameter groups of an optimizer. + """ + filtered_params = [] + for param in params: + if isinstance(param, core.nn.Parameter): + # model.parameters() input + param_data = param + if param_data.grad is not None: + filtered_params.append(param_data) + elif isinstance(param, dict): + # optimizer.param_groups input + for param_data in param["params"]: + if param_data.grad is not None: + filtered_params.append(param_data) + else: + raise NotImplementedError( + f"Parameter input of type {type(param)} is not supported" + ) + return filtered_params + + +def average_parameters_or_parameter_groups( + params: Union[ + Iterable[core.nn.Parameter], Iterable[Dict[str, core.nn.Parameter]] + ], + process_group: ProcessGroup, +): + """Averages parameters of a model or parameter groups of an optimizer.""" + average_parameters(iter(get_params_to_average(params)), process_group) diff --git a/mindnlp/core/distributed/argparse_util.py b/mindnlp/core/distributed/argparse_util.py new file mode 100644 index 000000000..5a4030479 --- /dev/null +++ b/mindnlp/core/distributed/argparse_util.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import os +from argparse import Action + + +class env(Action): + """ + Get argument values from ``PET_{dest}`` before defaulting to the given ``default`` value. + + For flags (e.g. ``--standalone``) + use ``check_env`` instead. + + .. note:: when multiple option strings are specified, ``dest`` is + the longest option string (e.g. for ``"-f", "--foo"`` + the env var to set is ``PET_FOO`` not ``PET_F``) + + Example: + :: + + parser.add_argument("-f", "--foo", action=env, default="bar") + + ./program -> args.foo="bar" + ./program -f baz -> args.foo="baz" + ./program --foo baz -> args.foo="baz" + PET_FOO="env_bar" ./program -f baz -> args.foo="baz" + PET_FOO="env_bar" ./program --foo baz -> args.foo="baz" + PET_FOO="env_bar" ./program -> args.foo="env_bar" + + parser.add_argument("-f", "--foo", action=env, required=True) + + ./program -> fails + ./program -f baz -> args.foo="baz" + PET_FOO="env_bar" ./program -> args.foo="env_bar" + PET_FOO="env_bar" ./program -f baz -> args.foo="baz" + """ + + def __init__(self, dest, default=None, required=False, **kwargs) -> None: + env_name = f"PET_{dest.upper()}" + default = os.environ.get(env_name, default) + + # ``required`` means that it NEEDS to be present in the command-line args + # rather than "this option requires a value (either set explicitly or default" + # so if we found default then we don't "require" it to be in the command-line + # so set it to False + if default: + required = False + + super().__init__(dest=dest, default=default, required=required, **kwargs) + + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, values) + + +class check_env(Action): + """ + Check whether the env var ``PET_{dest}`` exists before defaulting to the given ``default`` value. + + Equivalent to + ``store_true`` argparse built-in action except that the argument can + be omitted from the commandline if the env var is present and has a + non-zero value. + + .. note:: it is redundant to pass ``default=True`` for arguments + that use this action because a flag should be ``True`` + when present and ``False`` otherwise. + + Example: + :: + + parser.add_argument("--verbose", action=check_env) + + ./program -> args.verbose=False + ./program --verbose -> args.verbose=True + PET_VERBOSE=1 ./program -> args.verbose=True + PET_VERBOSE=0 ./program -> args.verbose=False + PET_VERBOSE=0 ./program --verbose -> args.verbose=True + + Anti-pattern (don't do this): + + :: + + parser.add_argument("--verbose", action=check_env, default=True) + + ./program -> args.verbose=True + ./program --verbose -> args.verbose=True + PET_VERBOSE=1 ./program -> args.verbose=True + PET_VERBOSE=0 ./program -> args.verbose=False + + """ + + def __init__(self, dest, default=False, **kwargs) -> None: + env_name = f"PET_{dest.upper()}" + default = bool(int(os.environ.get(env_name, "1" if default else "0"))) + super().__init__(dest=dest, const=True, default=default, nargs=0, **kwargs) + + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, self.const) diff --git a/mindnlp/core/distributed/autograd/__init__.py b/mindnlp/core/distributed/autograd/__init__.py new file mode 100644 index 000000000..bb7963bdf --- /dev/null +++ b/mindnlp/core/distributed/autograd/__init__.py @@ -0,0 +1,53 @@ +# mypy: allow-untyped-defs + +from mindnlp import core + + +def is_available(): + return hasattr(core._C, "_dist_autograd_init") + + +if is_available() and not core._C._dist_autograd_init(): + raise RuntimeError("Failed to initialize core.distributed.autograd") + +if is_available(): + from core._C._distributed_autograd import ( + _current_context, + _get_debug_info, + _get_max_id, + _init, + _is_valid_context, + _new_context, + _release_context, + _retrieve_context, + backward, + DistAutogradContext, + get_gradients, + ) + + +class context: + """ + Context object to wrap forward and backward passes when using + distributed autograd. The ``context_id`` generated in the ``with`` + statement is required to uniquely identify a distributed backward pass + on all workers. Each worker stores metadata associated with this + ``context_id``, which is required to correctly execute a distributed + autograd pass. + + Example:: + >>> # xdoctest: +SKIP + >>> from mindnlp import core.distributed.autograd as dist_autograd + >>> with dist_autograd.context() as context_id: + >>> t1 = core.rand((3, 3), requires_grad=True) + >>> t2 = core.rand((3, 3), requires_grad=True) + >>> loss = rpc.rpc_sync("worker1", core.add, args=(t1, t2)).sum() + >>> dist_autograd.backward(context_id, [loss]) + """ + + def __enter__(self): + self.autograd_context = _new_context() + return self.autograd_context._context_id() + + def __exit__(self, type, value, traceback): + _release_context(self.autograd_context._context_id()) diff --git a/mindnlp/core/distributed/benchmarks/README.md b/mindnlp/core/distributed/benchmarks/README.md new file mode 100644 index 000000000..af0510b5c --- /dev/null +++ b/mindnlp/core/distributed/benchmarks/README.md @@ -0,0 +1,68 @@ +# Benchmark combining Distributed Data Parallel and Distributed RPC + +This Benchmark is used to measure distributed training iteration time. It combines Distributed Data Parallelism with Distributed Model Parallelism leveraging PyTorch DDP and the Distributed RPC Framework. The number of trainer nodes and parameter servers are configurable. The default is 8 trainers, 1 master node and 8 parameter servers. + +## Background + +There are different training paradigms where combining these two techniques might be useful. For example: +1) If we have a model with a sparse part (large embedding table) and a dense + part (FC layers), we might want to set the embedding table on a parameter + server and replicate the FC layer across multiple trainers using [DistributedDataParallel](https://pycore.org/docs/stable/nn.html#core.nn.parallel.DistributedDataParallel). The [Distributed RPC framework](https://pycore.org/docs/main/rpc.html) comes handy to perform embedding lookups on the parameter servers. +2) Enable hybrid parallelism as described in the [PipeDream](https://arxiv.org/abs/1806.03377) paper. We can use the [Distributed RPC framework](https://pycore.org/docs/main/rpc.html) to pipeline stages of the model across multiple workers and replicate each stage (if needed) using [DistributedDataParallel](https://pycore.org/docs/stable/nn.html#core.nn.parallel.DistributedDataParallel). + +## Training Process +This benchmark focuses on the first paradigm above. The training process is executed as follows: + +1) The master creates embedding tables on each of the 8 Parameter Servers and holds an [RRef](https://pycore.org/docs/main/rpc.html#rref) to it. +2) The master, then kicks off the training loop on the 8 trainers and passes the embedding table RRef to the trainers. +3) The trainers create a `HybridModel` which performs embedding lookups in all 8 Parameter Servers using the embedding table RRef provided by the master and then executes the FC layer which is wrapped and replicated via DDP (DistributedDataParallel). +4) The trainer executes the forward pass of the model and uses the loss to + execute the backward pass using [Distributed Autograd](https://pycore.org/docs/main/rpc.html#distributed-autograd-framework). +5) As part of the backward pass, the gradients for the FC layer are computed + first and synced to all trainers via allreduce in DDP. +6) Next, Distributed Autograd propagates the gradients to the parameter servers, + where the gradients for the embedding table are updated. +7) Finally, the [Distributed Optimizer](https://pycore.org/docs/main/rpc.html#module-core.distributed.optim) is used to update all parameters. + + +## Example Benchmark output: + +---------- Info --------- + +* PyTorch version: 1.7.0 +* CUDA version: 9.2.0 + +---------- nvidia-smi topo -m --------- + + GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7 CPU Affinity + GPU0 X NV2 NV1 NV2 NV1 NODE NODE NODE 0-19,40-59 + GPU1 NV2 X NV2 NV1 NODE NV1 NODE NODE 0-19,40-59 + GPU2 NV1 NV2 X NV1 NODE NODE NV2 NODE 0-19,40-59 + GPU3 NV2 NV1 NV1 X NODE NODE NODE NV2 0-19,40-59 + GPU4 NV1 NODE NODE NODE X NV2 NV1 NV2 0-19,40-59 + GPU5 NODE NV1 NODE NODE NV2 X NV2 NV1 0-19,40-59 + GPU6 NODE NODE NV2 NODE NV1 NV2 X NV1 0-19,40-59 + GPU7 NODE NODE NODE NV2 NV2 NV1 NV1 X 0-19,40-59 + +Legend: + + X = Self + SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI) + NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node + PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU) + PXB = Connection traversing multiple PCIe switches (without traversing the PCIe Host Bridge) + PIX = Connection traversing a single PCIe switch + NV# = Connection traversing a bonded set of # NVLinks + +------------------ PyTorch Distributed Benchmark (DDP and RPC) --------------------- + + sec/epoch epoch/sec sec/epoch epoch/sec sec/epoch epoch/sec sec/epoch epoch/sec + Trainer0: p50: 0.376s 185/s p75: 0.384s 182/s p90: 0.390s 179/s p95: 0.396s 176/s + Trainer1: p50: 0.377s 204/s p75: 0.384s 200/s p90: 0.389s 197/s p95: 0.393s 195/s + Trainer2: p50: 0.377s 175/s p75: 0.384s 172/s p90: 0.390s 169/s p95: 0.395s 166/s + Trainer3: p50: 0.377s 161/s p75: 0.384s 158/s p90: 0.390s 156/s p95: 0.393s 155/s + Trainer4: p50: 0.377s 172/s p75: 0.383s 169/s p90: 0.389s 166/s p95: 0.395s 164/s + Trainer5: p50: 0.377s 180/s p75: 0.383s 177/s p90: 0.389s 174/s p95: 0.395s 172/s + Trainer6: p50: 0.377s 204/s p75: 0.384s 200/s p90: 0.390s 197/s p95: 0.394s 195/s + Trainer7: p50: 0.377s 185/s p75: 0.384s 182/s p90: 0.389s 179/s p95: 0.394s 177/s + All: p50: 0.377s 1470/s p75: 0.384s 1443/s p90: 0.390s 1421/s p95: 0.396s 1398/s diff --git a/mindnlp/core/distributed/benchmarks/benchmark_ddp_rpc.py b/mindnlp/core/distributed/benchmarks/benchmark_ddp_rpc.py new file mode 100644 index 000000000..6b3fd5850 --- /dev/null +++ b/mindnlp/core/distributed/benchmarks/benchmark_ddp_rpc.py @@ -0,0 +1,363 @@ +# mypy: allow-untyped-defs + +# pyre-unsafe +import argparse +import io +import os +import random +import shlex +import subprocess +import time + +import numpy as np + +from mindnlp import core +from mindnlp import core.distributed as dist +from mindnlp import core.distributed.autograd as dist_autograd +from mindnlp import core.distributed.rpc as rpc +from mindnlp import core.multiprocessing as mp +from mindnlp import core.nn as nn +from mindnlp import core.optim as optim +from core.distributed.optim import DistributedOptimizer +from core.distributed.rpc import RRef, TensorPipeRpcBackendOptions +from core.distributed.rpc.backend_registry import BackendType +from core.nn.parallel import DistributedDataParallel as DDP + + +# Config +NUM_TRAINERS = 8 +NUM_PS = 8 + +NUM_EMBEDDINGS = 300 +EMBEDDING_DIM = 64 + +WARMUP_CYCLES = 5 + + +class HybridModel(core.nn.Module): + r""" + The model consists of a sparse part and a dense part. + + The dense part is an nn.Linear module that is replicated across all trainers using + DistributedDataParallel. The sparse part has nn.EmbeddingBags stored on multiple + parameter servers. + + The model holds a Remote Reference to the embedding tables on the parameter + servers. + """ + + def __init__(self, emb_rref_list, device): + super().__init__() + self.emb_rref_list = emb_rref_list + fc1 = core.nn.Linear(512, 256) + fc2 = core.nn.Linear(256, 128) + relu = core.nn.ReLU() + fc3 = core.nn.Linear(128, 64) + fc4 = core.nn.Linear(64, 32) + fc5 = core.nn.Linear(32, 8) + sec = nn.Sequential(fc1, fc2, relu, fc3, fc4, fc5) + self.ddp = DDP(sec.to(device), device_ids=[device]) + self.device = device + + def forward(self, indices, offsets): + emb_lookups = [] + + for emb_rref in self.emb_rref_list: + emb_lookups.append( + emb_rref.rpc_sync().forward( + indices, offsets + ) # embedding_sum(input, offsets) + ) + emb_lookups_cat = core.cat(emb_lookups, dim=1) + + # Make sure combined PS dimension is always bigger or equal than the FC input + assert NUM_PS * EMBEDDING_DIM >= 512 + dim_normalizer = int(NUM_PS * EMBEDDING_DIM / 512) + emb_lookups_reshaped = emb_lookups_cat.reshape( # type: ignore[possibly-undefined] + [emb_lookups_cat.shape[0] * dim_normalizer, 512] + ) + + return self.ddp(emb_lookups_reshaped) + + +def _retrieve_embedding_parameters(emb_rref): + return [RRef(p) for p in emb_rref.local_value().parameters()] + + +def _print_header(): + _print_cont("\n") + _print_cont("%10s" % "") + for _ in [50, 75, 90, 95]: + _print_cont("%14s%10s" % ("sec/epoch", "epoch/sec")) + _print_cont("\n") + + +def _print_benchmark(prefix, nelem, measurements): + measurements = sorted(measurements) + _print_cont("%8s:" % prefix) + for p in [50, 75, 90, 95]: + v = np.percentile(measurements, p) + _print_cont(" p%02d: %1.3fs %6d/s" % (p, v, nelem / v)) + _print_cont("\n") + + +def _print_cont(msg): + print(msg, end="", flush=True) + + +def _run_printable(cmd): + proc = subprocess.run(shlex.split(cmd), capture_output=True, check=False) # type: ignore[call-overload] + assert proc.returncode == 0 + + buffer = io.BytesIO() + core.save(proc.stdout.decode("utf-8"), buffer) + input_tensor = core.ByteTensor(list(buffer.getvalue())) + + output = [] + buffer = io.BytesIO(np.asarray(input_tensor).tobytes()) + output.append(core.load(buffer)) + return output + + +def _run_trainer(emb_rref_list, rank): + r""" + Each trainer runs a forward pass which involves an embedding lookup on the 8 parameter servers, + and running nn.Linear locally. + + During the backward pass, DDP is responsible for aggregating the gradients for the dense part + (nn.Linear) and distributed autograd ensures gradients updates are + propagated to the parameter servers. + """ + # Setup the model. + model = HybridModel(emb_rref_list, rank) + + # Retrieve all model parameters as rrefs for DistributedOptimizer. + + # Retrieve parameters from all embedding tables for the current trainer. + model_parameter_rrefs = [] + for ind, emb_rref in enumerate(emb_rref_list): + ps_name = f"ps{ind}" + model_parameter_rrefs.extend( + rpc.rpc_sync(ps_name, _retrieve_embedding_parameters, args=(emb_rref,)) + ) + + # model.parameters() only includes local parameters. + model_parameter_rrefs.extend(RRef(param) for param in model.parameters()) + + # Setup distributed optimizer + opt = DistributedOptimizer(optim.SGD, model_parameter_rrefs, lr=0.05) + + criterion = core.nn.CrossEntropyLoss() + + def get_next_batch(rank): + for _ in range(10): + num_indices = random.randint(20, 50) + indices = core.LongTensor(num_indices).random_(0, NUM_EMBEDDINGS) + + # Generate offsets. + offsets = [] + start = 0 + batch_size = 0 + + while start < num_indices: + offsets.append(start) + start += random.randint(1, 10) + batch_size += 1 + + offsets_tensor = core.LongTensor(offsets) + target = core.LongTensor(batch_size).random_(8).cuda(rank) + + yield indices, offsets_tensor, target + + measurements = [] + # Include warm-up cycles during training + for _ in range(100 + WARMUP_CYCLES): + start = time.time() + batch_size = 0 + + # create distributed autograd context + for indices, offsets, target in get_next_batch(rank): + batch_size += len(target) + + with dist_autograd.context() as context_id: + output = model(indices, offsets) + loss = criterion(output, target) + + # Run distributed backward pass + dist_autograd.backward(context_id, [loss]) + + # Run distributed optimizer. Gradients propagated all the way to the parameter servers + opt.step(context_id) + + # Not necessary to zero grads as each iteration creates a different + # distributed autograd context which hosts different grads + + measurements.append(time.time() - start) + # print("Training done for epoch {}".format(epoch)) + + # Throw away warm-up measurements + measurements = measurements[WARMUP_CYCLES:] + return rank, measurements, batch_size # type: ignore[possibly-undefined] + + +def run_worker(rank, world_size): + r""" + Initialize RPC, calls the function, and shuts down RPC. + """ + # Using different port numbers in TCP init_method for init_rpc and + # init_process_group to avoid port conflicts. + rpc_backend_options = TensorPipeRpcBackendOptions() + rpc_backend_options.init_method = "tcp://localhost:29500" + + # Rank 16. Master + if rank == (NUM_TRAINERS + NUM_PS): + rpc.init_rpc( + "master", + rank=rank, + backend=BackendType.TENSORPIPE, # type: ignore[attr-defined] + world_size=world_size, + ) + + # Build the Embedding tables on the Parameter Servers. + emb_rref_list = [] + index = 0 + while index < NUM_PS: + ps_name = f"ps{index}" + emb_rref = rpc.remote( + ps_name, + core.nn.EmbeddingBag, + args=(NUM_EMBEDDINGS, EMBEDDING_DIM), + kwargs={"mode": "sum"}, + ) + emb_rref_list.append(emb_rref) + index += 1 + + # Run training loop on the trainers. + futs = [] + for trainer_rank in range(NUM_TRAINERS): + trainer_name = f"trainer{trainer_rank}" + fut = rpc.rpc_async( + trainer_name, _run_trainer, args=(emb_rref_list, trainer_rank) + ) + futs.append(fut) + + _print_header() + + measurements_all_trainers = [] + batch_size_all_trainers = 0 + # Wait for all training to finish. + for fut in futs: + rank, measurements, batch_size = fut.wait() + _print_benchmark(f"Trainer{rank}", batch_size, measurements) + batch_size_all_trainers += batch_size + measurements_all_trainers.append(measurements) + + _print_benchmark("All", batch_size_all_trainers, measurements_all_trainers) + + # Rank 0-7. Trainers + elif rank >= 0 and rank < NUM_PS: + # Initialize process group for Distributed DataParallel on trainers. + dist.init_process_group( + backend=dist.Backend.GLOO, + rank=rank, + world_size=NUM_TRAINERS, + init_method="tcp://localhost:29501", + ) + + # Initialize RPC. Trainer just waits for RPCs from master. + trainer_name = f"trainer{rank}" + rpc.init_rpc( + trainer_name, + rank=rank, + world_size=world_size, + rpc_backend_options=rpc_backend_options, + ) + + # Rank 8-15. Parameter Servers + elif rank >= NUM_TRAINERS and rank < NUM_TRAINERS + NUM_PS: + ps_name = f"ps{rank - NUM_TRAINERS}" + rpc.init_rpc( + ps_name, + rank=rank, + world_size=world_size, + backend=BackendType.TENSORPIPE, # type: ignore[attr-defined] + rpc_backend_options=rpc_backend_options, + ) + # parameter server do nothing + + # block until all rpcs finish + rpc.shutdown() + + +if __name__ == "__main__": + """Initializing the distributed environment.""" + + output = _run_printable("nvidia-smi topo -m") + print("-------------------------------------------") + print(" Info ") + print("-------------------------------------------") + print() + print(f"* PyTorch version: {core.__version__}") + print(f"* CUDA version: {core.version.cuda}") + print() + print("------------ nvidia-smi topo -m -----------") + print() + print(output[0]) + print("-------------------------------------------") + print("PyTorch Distributed Benchmark (DDP and RPC)") + print("-------------------------------------------") + + # Cmd arguments to enable automated runs (e.g. Chronos, SSH, etc). + parser = argparse.ArgumentParser(description="PyTorch DDP and RPC Benchmark") + parser.add_argument( + "--master-addr", type=str, default="localhost", help="Address of master node." + ) + parser.add_argument("--master-port", type=str, default="29500", help="Master port.") + + parser.add_argument( + "--number-trainers", + type=int, + default=NUM_TRAINERS, + help="Number of Trainer Nodes.", + ) + parser.add_argument( + "--number-ps", type=int, default=NUM_PS, help="Number of Parameter Servers." + ) + parser.add_argument( + "--number-embeddings", + type=int, + default=NUM_EMBEDDINGS, + help="Number of test embeddings to be generated.", + ) + parser.add_argument( + "--embedding-dim", + type=int, + default=EMBEDDING_DIM, + help="Number of embedding dimensions.", + ) + parser.add_argument( + "--warmup-cycles", + type=int, + default=WARMUP_CYCLES, + help="Number of cycles to warm-up each process before running the benchmark.", + ) + + args = parser.parse_args() + + os.environ["MASTER_ADDR"] = args.master_addr + os.environ["MASTER_PORT"] = args.master_port + + NUM_TRAINERS = args.number_trainers + NUM_PS = args.number_ps + + NUM_EMBEDDINGS = args.number_embeddings + EMBEDDING_DIM = args.embedding_dim + + WARMUP_CYCLES = args.warmup_cycles + + # Defaults: + # 8 trainers (rank 0-7), + # 8 parameter servers (rank 8-15), + # 1 master (rank 16). + world_size = NUM_TRAINERS + NUM_PS + 1 # Trainers + PS + Master + mp.spawn(run_worker, args=(world_size,), nprocs=world_size, join=True) diff --git a/mindnlp/core/distributed/c10d/__init__.py b/mindnlp/core/distributed/c10d/__init__.py new file mode 100644 index 000000000..4549853f2 --- /dev/null +++ b/mindnlp/core/distributed/c10d/__init__.py @@ -0,0 +1,6 @@ +from .store import Store +from .prefix_store import PrefixStore +from .types import * +from .process_group import ProcessGroup +from .work import Work +from .backend import Backend diff --git a/mindnlp/core/distributed/c10d/backend.py b/mindnlp/core/distributed/c10d/backend.py new file mode 100644 index 000000000..2059de8f8 --- /dev/null +++ b/mindnlp/core/distributed/c10d/backend.py @@ -0,0 +1,183 @@ +from mindnlp import core +from typing import List, Optional, Callable, Any +from enum import Enum +import time + +# Enum for Backend operations +class OpType(Enum): + BROADCAST = 0 + ALLREDUCE = 1 + ALLREDUCE_COALESCED = 2 + REDUCE = 3 + ALLGATHER = 4 + _ALLGATHER_BASE = 5 + ALLGATHER_COALESCED = 6 + GATHER = 7 + SCATTER = 8 + REDUCE_SCATTER = 9 + ALLTOALL_BASE = 10 + ALLTOALL = 11 + SEND = 12 + RECV = 13 + RECVANYSOURCE = 14 + BARRIER = 15 + _REDUCE_SCATTER_BASE = 16 + COALESCED = 17 + _ALLREDUCE_SPARSE = 18 + UNKNOWN = 100 + +kBackendDefaultTimeout = 30 * 60 * 1000 # Default timeout in milliseconds + +class Backend: + + class Options: + def __init__(self, backend: str, timeout: int = kBackendDefaultTimeout): + self.timeout = timeout + self.backend = backend + + def __init__(self, rank: int, size: int): + self.rank_ = rank + self.size_ = size + self.pg_uid_ = "" + self.pg_desc_ = "" + self.dist_debug_level_ = "Off" + self.bound_device_id_ = None + + def get_rank(self) -> int: + return self.rank_ + + def get_size(self) -> int: + return self.size_ + + def get_id(self) -> int: + return id(self) + + def supports_splitting(self) -> bool: + return False + + def start_coalescing(self): + raise NotImplementedError(f"Backend {self.get_backend_name()} does not implement startCoalescing.") + + def end_coalescing(self): + raise NotImplementedError(f"Backend {self.get_backend_name()} does not implement endCoalescing.") + + def get_backend_name(self) -> str: + raise NotImplementedError("getBackendName is not implemented.") + + def broadcast(self, tensors: List[core.Tensor], opts: Optional[Any] = None): + raise NotImplementedError(f"Backend {self.get_backend_name()} does not support broadcast.") + + def allreduce(self, tensors: List[core.Tensor], opts: Optional[Any] = None): + raise NotImplementedError(f"Backend {self.get_backend_name()} does not support allreduce.") + + def allreduce_sparse(self, tensors: List[core.Tensor], opts: Optional[Any] = None): + raise NotImplementedError(f"Backend {self.get_backend_name()} does not support allreduce sparse.") + + def allreduce_coalesced(self, tensors: List[core.Tensor], opts: Optional[Any] = None): + raise NotImplementedError(f"Backend {self.get_backend_name()} does not support allreduce_coalesced.") + + def reduce(self, tensors: List[core.Tensor], opts: Optional[Any] = None): + raise NotImplementedError(f"Backend {self.get_backend_name()} does not support reduce.") + + def allgather(self, output_tensors: List[List[core.Tensor]], input_tensors: List[core.Tensor], opts: Optional[Any] = None): + raise NotImplementedError(f"Backend {self.get_backend_name()} does not support allgather.") + + def _allgather_base(self, output_buffer: core.Tensor, input_buffer: core.Tensor, opts: Optional[Any] = None): + raise NotImplementedError(f"Backend {self.get_backend_name()} does not support _allgather_base.") + + def allgather_coalesced(self, output_tensor_lists: List[List[core.Tensor]], input_tensors: List[core.Tensor], opts: Optional[Any] = None): + raise NotImplementedError(f"Backend {self.get_backend_name()} does not support allgather_coalesced.") + + def allgather_into_tensor_coalesced(self, outputs: List[core.Tensor], inputs: List[core.Tensor], opts: Optional[Any] = None): + raise NotImplementedError(f"Backend {self.get_backend_name()} does not support allgather_into_tensor_coalesced.") + + def gather(self, output_tensors: List[List[core.Tensor]], input_tensors: List[core.Tensor], opts: Optional[Any] = None): + raise NotImplementedError(f"Backend {self.get_backend_name()} does not support gather.") + + def scatter(self, output_tensors: List[core.Tensor], input_tensors: List[List[core.Tensor]], opts: Optional[Any] = None): + raise NotImplementedError(f"Backend {self.get_backend_name()} does not support scatter.") + + def reduce_scatter(self, output_tensors: List[core.Tensor], input_tensors: List[List[core.Tensor]], opts: Optional[Any] = None): + raise NotImplementedError(f"Backend {self.get_backend_name()} does not support reduce_scatter.") + + def _reduce_scatter_base(self, output_buffer: core.Tensor, input_buffer: core.Tensor, opts: Optional[Any] = None): + raise NotImplementedError(f"Backend {self.get_backend_name()} does not support _reduce_scatter_base.") + + def reduce_scatter_tensor_coalesced(self, outputs: List[core.Tensor], inputs: List[core.Tensor], opts: Optional[Any] = None): + raise NotImplementedError(f"Backend {self.get_backend_name()} does not support reduce_scatter_tensor_coalesced.") + + def alltoall_base(self, output_buffer: core.Tensor, input_buffer: core.Tensor, output_split_sizes: List[int], input_split_sizes: List[int], opts: Optional[Any] = None): + raise NotImplementedError(f"Backend {self.get_backend_name()} does not support alltoall_base.") + + def alltoall(self, output_tensors: List[core.Tensor], input_tensors: List[core.Tensor], opts: Optional[Any] = None): + raise NotImplementedError(f"Backend {self.get_backend_name()} does not support alltoall.") + + def monitored_barrier(self, opts: Optional[Any] = None, wait_all_ranks=False): + raise NotImplementedError(f"Backend {self.get_backend_name()} does not support monitoredBarrier, only GLOO supports monitored barrier.") + + def set_sequence_number_for_group(self): + raise NotImplementedError(f"Backend {self.get_backend_name()} does not yet support sequence numbers.") + + def get_sequence_number_for_group(self) -> int: + raise NotImplementedError(f"Backend {self.get_backend_name()} does not yet support sequence numbers.") + + def send(self, tensors: List[core.Tensor], dst_rank: int, tag: int): + raise NotImplementedError(f"Backend {self.get_backend_name()} does not support send.") + + def recv(self, tensors: List[core.Tensor], src_rank: int, tag: int): + raise NotImplementedError(f"Backend {self.get_backend_name()} does not support recv.") + + def recv_anysource(self, tensors: List[core.Tensor], tag: int): + raise NotImplementedError(f"Backend {self.get_backend_name()} does not support recvAnysource.") + + def barrier(self, opts: Optional[Any] = None): + raise NotImplementedError(f"Backend {self.get_backend_name()} does not support barrier.") + + def register_on_completion_hook(self, hook: Callable): + raise NotImplementedError(f"Only ProcessGroupNCCL supports onCompletion hook, but got {self.get_backend_name()} backend.") + + def wait_for_pending_works(self): + raise NotImplementedError(f"Only ProcessGroupNCCL supports waitForPendingWorks, but got {self.get_backend_name()} backend.") + + def enable_collectives_timing(self): + raise NotImplementedError(f"Backend {self.get_backend_name()} is missing implementation of enableCollectivesTiming.") + + def has_hooks(self) -> bool: + return self.on_completion_hook_ is not None + + def set_group_uid(self, pg_uid: str): + self.pg_uid_ = pg_uid + + def get_group_uid(self) -> str: + return self.pg_uid_ + + def set_group_desc(self, desc: str): + self.pg_desc_ = desc + + def get_group_desc(self) -> str: + return self.pg_desc_ + + def get_bound_device_id(self) -> Optional[core.device]: + return self.bound_device_id_ + + def eager_connect_single_device(self, device: core.device): + pass + + def set_bound_device_id(self, device: Optional[core.device]): + if device: + assert device.index is not None, "setBoundDeviceId must have an index" + self.bound_device_id_ = device + +# Example subclass implementation (e.g., for NCCL, GLOO) +class NCCLBackend(Backend): + def __init__(self, rank: int, size: int): + super().__init__(rank, size) + + def get_backend_name(self) -> str: + return "NCCL" + + def start_coalescing(self): + pass + + def end_coalescing(self): + pass diff --git a/mindnlp/core/distributed/c10d/prefix_store.py b/mindnlp/core/distributed/c10d/prefix_store.py new file mode 100644 index 000000000..4e80a5a40 --- /dev/null +++ b/mindnlp/core/distributed/c10d/prefix_store.py @@ -0,0 +1,69 @@ +from typing import List +from .store import Store + +class PrefixStore(Store): + def __init__(self, prefix: str, store: Store): + self.prefix_ = prefix + self.store_ = store + + def join_key(self, key: str) -> str: + return f"{self.prefix_}/{key}" + + def join_keys(self, keys: List[str]) -> List[str]: + return [self.join_key(key) for key in keys] + + def set(self, key: str, value: List[int]): + self.store_.set(self.join_key(key), value) + + def compare_set(self, key: str, expected_value: List[int], desired_value: List[int]) -> List[int]: + return self.store_.compare_set(self.join_key(key), expected_value, desired_value) + + def get(self, key: str) -> List[int]: + return self.store_.get(self.join_key(key)) + + def add(self, key: str, value: int) -> int: + return self.store_.add(self.join_key(key), value) + + def delete_key(self, key: str) -> bool: + return self.store_.delete_key(self.join_key(key)) + + def get_num_keys(self) -> int: + return self.store_.get_num_keys() + + def check(self, keys: List[str]) -> bool: + return self.store_.check(self.join_keys(keys)) + + def wait(self, keys: List[str]): + self.store_.wait(self.join_keys(keys)) + + def wait_with_timeout(self, keys: List[str], timeout: int): + self.store_.wait(self.join_keys(keys), timeout) + + def get_timeout(self) -> int: + return self.store_.get_timeout() + + def set_timeout(self, timeout: int): + self.store_.set_timeout(timeout) + + def append(self, key: str, value: List[int]): + self.store_.append(self.join_key(key), value) + + def multi_get(self, keys: List[str]) -> List[List[int]]: + return self.store_.multi_get(self.join_keys(keys)) + + def multi_set(self, keys: List[str], values: List[List[int]]): + self.store_.multi_set(self.join_keys(keys), values) + + def has_extended_api(self) -> bool: + return self.store_.has_extended_api() + + def get_underlying_store(self) -> Store: + return self.store_ + + def get_underlying_non_prefix_store(self) -> Store: + store = self.store_ + while isinstance(store, PrefixStore): + store = store.get_underlying_store() + if store is None: + raise ValueError("Underlying Non-PrefixStore shouldn't be null.") + return store \ No newline at end of file diff --git a/mindnlp/core/distributed/c10d/process_group.py b/mindnlp/core/distributed/c10d/process_group.py new file mode 100644 index 000000000..047a37b4d --- /dev/null +++ b/mindnlp/core/distributed/c10d/process_group.py @@ -0,0 +1,293 @@ +from mindnlp import core +from mindnlp.core import Tensor +from typing import List, Optional, Dict, Any +from enum import Enum + +from mindnlp.core.executor import execute + + +class BackendType(Enum): + UNDEFINED = 0 + GLOO = 1 + NCCL = 2 + UCC = 3 + MPI = 4 + CUSTOM = 5 + + +def backend_type_to_string(backend_type: BackendType) -> str: + if backend_type == BackendType.GLOO: + return "gloo" + elif backend_type == BackendType.NCCL: + return "nccl" + elif backend_type == BackendType.UCC: + return "ucc" + elif backend_type == BackendType.MPI: + return "mpi" + elif backend_type == BackendType.UNDEFINED: + return "undefined" + elif backend_type == BackendType.CUSTOM: + return "custom" + else: + raise ValueError("Unknown backend type!") + + +def str_to_backend_type(backend: str) -> BackendType: + if backend == "undefined": + return BackendType.UNDEFINED + elif backend == "gloo": + return BackendType.GLOO + elif backend == "nccl": + return BackendType.NCCL + elif backend == "ucc": + return BackendType.UCC + elif backend == "mpi": + return BackendType.MPI + else: + return BackendType.CUSTOM + + +class ProcessGroup: + class BackendType(Enum): + UNDEFINED = 0 + GLOO = 1 + NCCL = 2 + UCC = 3 + MPI = 4 + CUSTOM = 5 + + def __init__(self, store: Optional[Any] = None, rank: int = 0, size: int = 0): + self.store = store + self._name = self.store.prefix_[:-1] + self._rank = rank + self._size = size + self.backend_type = BackendType.UNDEFINED + self.device_type_to_backend = {} + self.backend_type_to_backend = {} + self.device_types = set() + self.pg_desc = "" + self.dist_debug_level = "Off" + + def rank(self): + return self._rank + + def size(self): + return self._size + + def get_rank(self) -> int: + return self._rank + + def get_size(self) -> int: + return self._size + + def get_backend_name(self) -> str: + return backend_type_to_string(self.backend_type) + + def set_backend(self, device_type, backend_type: BackendType, backend: Optional[Any] = None): + self.device_type_to_backend[device_type] = backend_type + self.device_types.add(device_type) + + if backend_type in self.backend_type_to_backend: + existing_backend = self.backend_type_to_backend[backend_type] + self.device_type_to_backend[device_type] = existing_backend + else: + if backend: + self.device_type_to_backend[device_type] = backend + self.backend_type_to_backend[backend_type] = backend + + def get_backend(self, device_type) -> Any: + if device_type in self.device_type_to_backend: + return self.device_type_to_backend[device_type] + else: + raise ValueError(f"No backend found for device type {device_type}") + + def start_coalescing(self, device_type): + backend = self.get_backend(device_type) + backend.start_coalescing() + + def end_coalescing(self, device_type): + backend = self.get_backend(device_type) + return backend.end_coalescing() + + def broadcast(self, tensors: List[Tensor], opts: Any) -> Any: + tensor = tensors[0] + _, work = execute('dist_comm_broadcast', tensor, opts.rootRank, self._name, device=self.device) + return work + + def allreduce(self, tensors: List[Tensor], opts: Any) -> Any: + tensor = tensors[0] + _, handle = execute('dist_comm_all_reduce', tensor, opts.reduceOp, self._name, device=self.device) + return handle + + def _allgather_base(self, output_tensor: Tensor, input_tensor: Tensor, opts: Any=None): + input_size = (-1,) + output_rank = output_tensor.ndim - 1 + if output_rank > 0: + input_size = input_size + input_tensor.shape[input_tensor.ndim - output_rank:] + _, handle = execute('dist_comm_all_gather_into_tensor', output_tensor, input_tensor.view(input_size), self._size, self._name, device=self.device) + return handle + + def allgather(self, output_tensors: List[List[Tensor]], input_tensors: List[Tensor], opts: Any=None) -> Any: + tensor_list = output_tensors[0] + tensor = input_tensors[0] + _, handle = execute('dist_comm_all_gather', tensor_list, tensor, self._size, self._name, device=self.device) + return handle + + def reduce(self, tensors: List[Tensor], opts: Any) -> Any: + out = reduce(tensors[0], opts.rootRank, opts.reduceOp, self._name) + return out + + def gather(self, output_tensors, input_tensors, opts): + # # do not use mindspore.communication.gather because not support uint8 + tensor = input_tensors[0] + gather_list = output_tensors[0] + + _, work = execute('dist_comm_gather', tensor, gather_list, self._size, opts.rootRank, self._rank, self._name, device=self.device) + return work + + def scatter(self, output_tensors: List[Tensor], input_tensors: List[List[Tensor]], opts: Any) -> Any: + tensor = output_tensors[0] + scatter_list = input_tensors[0] + _, work = execute('dist_comm_scatter', tensor, scatter_list, self._size, opts.rootRank, self._rank, self._name, device=self.device) + return work + + def reduce_scatter(self, output_tensors: List[Tensor], input_tensors: List[List[Tensor]], opts: Any) -> Any: + output = output_tensors[0] + input_list = input_tensors[0] + _, work = execute('dist_comm_reduce_scatter', output, input_list, self._size, opts.reduceOp, self._name, device=self.device) + if allow_inflight_collective_as_graph_input(): + for tensor in output_tensors: + register_work(tensor, work) + return work + + def _reduce_scatter_base(self, output_tensor, input_tensor, opts: Any): + _, work = execute('dist_comm_reduce_scatter_tensor', output_tensor, input_tensor, self._size, opts.reduceOp, self._name, device=self.device) + return work + + def barrier(self, opts: Any) -> Any: + _, work = execute('dist_comm_barrier', self._name, device=self.device) + return work + + def recv(self, tensors, srcRank, tag): + tensor = tensors[0] + _, work = execute('dist_comm_irecv', tensor, tag, srcRank, self._name, device=self.device) + return work + + def send(self, tensors: List[Tensor], dstRank: int, tag: int): + tensor = tensors[0] + _, handle = execute('dist_comm_isend', tensor, dstRank, self._name, tag, device=self.device) + return handle + + def get_device_types(self) -> List[Any]: + return list(self.device_types) + + def set_group_name(self, name: str): + for backend in self.device_type_to_backend.values(): + backend.set_group_uid(name) + + def get_group_name(self) -> str: + return self.device_type_to_backend[next(iter(self.device_type_to_backend))].get_group_uid() + + def set_group_desc(self, desc: str): + self.pg_desc = desc + for backend in self.device_type_to_backend.values(): + backend.set_group_desc(desc) + + def enable_collectives_timing(self): + for backend in self.device_type_to_backend.values(): + backend.enable_collectives_timing() + + def release_resources(self): + self.device_type_to_backend.clear() + self.backend_type_to_backend.clear() + self.store = None + + def _register_backend(self, device, backend_type, backend_class): + self.device = device + +class WorkRegistry: + def __init__(self): + self.registry = {} + self.allow_inflight_collective_as_graph_input = False + + def register_work(self, tensor: Tensor, work: Any): + if not tensor.has_storage(): + print(f"Warning: Tensor {tensor} has no storage!") + return + storage = tensor.storage().getWeakStorageImpl() + if storage not in self.registry: + self.registry[storage] = [work] + else: + if work not in self.registry[storage]: + self.registry[storage].append(work) + + def pop_works(self, tensor: Tensor): + storage = tensor.storage().getWeakStorageImpl() + if storage in self.registry: + works = self.registry.pop(storage) + return works + return [] + + def unregister_work(self, work: Any): + for storage, works in list(self.registry.items()): + self.registry[storage] = [w for w in works if w != work] + if not self.registry[storage]: + del self.registry[storage] + + def get_work_registry_size(self): + return sum(len(works) for works in self.registry.values()) + + def set_allow_inflight_collective_as_graph_input(self, value: bool): + self.allow_inflight_collective_as_graph_input = value + + # @property + # def allow_inflight_collective_as_graph_input(self) -> bool: + # return self.allow_inflight_collective_as_graph_input + + def __del__(self): + if self.get_work_registry_size() > 0: + print("Warning: Some work objects were not awaited!") + + +# Global WorkRegistry +process_registry = WorkRegistry() + +# Helper functions +def register_work(tensor: Tensor, work: Any): + process_registry.register_work(tensor, work) + + +def wait_tensor(tensor: Tensor) -> Tensor: + works = process_registry.pop_works(tensor) + for work in works: + work.wait() + return tensor + + +def unregister_work(work: Any): + process_registry.unregister_work(work) + + +def get_work_registry_size() -> int: + return process_registry.get_work_registry_size() + + +def set_allow_inflight_collective_as_graph_input(value: bool): + process_registry.set_allow_inflight_collective_as_graph_input(value) + + +def allow_inflight_collective_as_graph_input() -> bool: + return process_registry.allow_inflight_collective_as_graph_input() + + +def create_tensor(device: Optional[Any] = None) -> Tensor: + # Placeholder function for tensor creation + if device: + return core.empty([1], device=device) + return core.empty([1]) + + +def get_backend_op(name: str): + # Placeholder for fetching backend operation + # Would need to map to actual dispatcher + pass diff --git a/mindnlp/core/distributed/c10d/store.py b/mindnlp/core/distributed/c10d/store.py new file mode 100644 index 000000000..7f5a20a77 --- /dev/null +++ b/mindnlp/core/distributed/c10d/store.py @@ -0,0 +1,100 @@ +import time +from typing import List, Optional, Callable +from abc import ABC, abstractmethod + +class Store: + kDefaultTimeout = 300 # in seconds + kNoTimeout = 0 # No timeout + + def __init__(self, timeout: Optional[int] = kDefaultTimeout): + self.timeout_ = timeout + + def set(self, key: str, value: str): + self.set_bytes(key, value.encode()) + + @abstractmethod + def set_bytes(self, key: str, value: List[int]): + pass + + def compare_set(self, key: str, current_value: str, new_value: str) -> str: + current_bytes = current_value.encode() + new_bytes = new_value.encode() + value = self.compare_set_bytes(key, current_bytes, new_bytes) + return value.decode() + + @abstractmethod + def compare_set_bytes(self, key: str, current_value: List[int], new_value: List[int]) -> List[int]: + pass + + def get_to_str(self, key: str) -> str: + value = self.get(key) + return bytes(value).decode() + + @abstractmethod + def get(self, key: str) -> List[int]: + pass + + @abstractmethod + def add(self, key: str, value: int) -> int: + pass + + @abstractmethod + def delete_key(self, key: str) -> bool: + pass + + @abstractmethod + def check(self, keys: List[str]) -> bool: + pass + + @abstractmethod + def get_num_keys(self) -> int: + pass + + @abstractmethod + def wait(self, keys: List[str], timeout: Optional[int] = None): + pass + + def get_timeout(self) -> int: + return self.timeout_ + + def set_timeout(self, timeout: int): + self.timeout_ = timeout + + def watch_key(self, key: str, callback: Callable[[Optional[str], Optional[str]], None]): + raise NotImplementedError("watchKey is deprecated, no implementation supports it.") + + def append(self, key: str, value: List[int]): + expected = value + current = [] + current = self.compare_set_bytes(key, current, expected) + while current != expected: + expected = current + value + current = self.compare_set_bytes(key, current, expected) + + def multi_get(self, keys: List[str]) -> List[List[int]]: + result = [] + for key in keys: + result.append(self.get(key)) + return result + + def multi_set(self, keys: List[str], values: List[List[int]]): + for i in range(len(keys)): + self.set_bytes(keys[i], values[i]) + + def has_extended_api(self) -> bool: + return False + +class StoreTimeoutGuard: + def __init__(self, store: Store, timeout: int): + self.store_ = store + self.old_timeout_ = store.get_timeout() + store.set_timeout(timeout) + + def __del__(self): + self.store_.set_timeout(self.old_timeout_) + + def __copy__(self): + raise NotImplementedError("Copying not allowed") + + def __move__(self): + raise NotImplementedError("Moving not allowed") diff --git a/mindnlp/core/distributed/c10d/types.py b/mindnlp/core/distributed/c10d/types.py new file mode 100644 index 000000000..048e7be91 --- /dev/null +++ b/mindnlp/core/distributed/c10d/types.py @@ -0,0 +1,143 @@ +from mindnlp import core +from mindnlp.core import Tensor +from typing import Optional, List +from datetime import timedelta + + +class _SupplementBase: + def __del__(self): + pass + + +class NCCLPreMulSumSupplement(_SupplementBase): + def __init__(self, factor): + if isinstance(factor, float): + self.double_factor = factor + self.tensor_factor = None + elif isinstance(factor, Tensor): + self.double_factor = 0.0 + assert factor.numel() == 1, "Tensor must have exactly one element." + self.tensor_factor = factor + else: + raise ValueError("factor must be either a float or a Tensor") + + +class ReduceOp: + SUM = 'sum' + AVG = 1 + PRODUCT = 'prod' + MIN = 'min' + MAX = 'max' + BAND = 5 # Bitwise AND + BOR = 6 # Bitwise OR + BXOR = 7 # Bitwise XOR + PREMUL_SUM = 8 # Multiply by a user-supplied constant before summing. + UNUSED = 9 + + def __init__(self, op: Optional[int] = None, supplement: Optional[_SupplementBase] = None): + if op is None: + self.op = self.SUM + else: + if op == self.PREMUL_SUM: + raise ValueError("Use `make_ncc_premul_sum` to create an instance of ReduceOp with PREMUL_SUM") + self.op = op + + if supplement: + self.supplement = supplement + else: + self.supplement = None + + def __eq__(self, other): + if isinstance(other, int): + return self.op == other + elif isinstance(other, ReduceOp): + return self.op == other.op + return False + + def __int__(self): + return self.op + + +def make_ncc_premul_sum(factor): + rop = ReduceOp() + rop.op = ReduceOp.PREMUL_SUM + rop.supplement = NCCLPreMulSumSupplement(factor) + return rop + + +kUnsetTimeout = timedelta(milliseconds=-1) + + +class BroadcastOptions: + def __init__(self): + self.rootRank = 0 + self.rootTensor = 0 + self.timeout = kUnsetTimeout + self.asyncOp = True + + +class AllreduceOptions: + def __init__(self): + self.reduceOp = ReduceOp.SUM + self.timeout = kUnsetTimeout + self.sparseIndices = None + + +class AllreduceCoalescedOptions(AllreduceOptions): + pass + + +class ReduceOptions: + def __init__(self): + self.reduceOp = ReduceOp.SUM + self.rootRank = 0 + self.rootTensor = 0 + self.timeout = kUnsetTimeout + + +class AllgatherOptions: + def __init__(self): + self.timeout = kUnsetTimeout + self.asyncOp = True + + +class GatherOptions: + def __init__(self): + self.rootRank = 0 + self.timeout = kUnsetTimeout + + +class ScatterOptions: + def __init__(self): + self.rootRank = 0 + self.timeout = kUnsetTimeout + self.asyncOp = True + + +class ReduceScatterOptions: + def __init__(self): + self.reduceOp = ReduceOp.SUM + self.timeout = kUnsetTimeout + self.asyncOp = True + + +class AllToAllOptions: + def __init__(self): + self.timeout = kUnsetTimeout + + +class BarrierOptions: + def __init__(self): + self.device_ids = [] + self.timeout = kUnsetTimeout + self.device = None + + +class DistributedBackendOptions: + def __init__(self, store, group_rank, group_size, timeout, group_id, global_ranks_in_group): + self.store = store + self.group_rank = group_rank + self.group_size = group_size + self.timeout = timeout + self.group_id = group_id + self.global_ranks_in_group = global_ranks_in_group diff --git a/mindnlp/core/distributed/c10d/work.py b/mindnlp/core/distributed/c10d/work.py new file mode 100644 index 000000000..89e0d73cf --- /dev/null +++ b/mindnlp/core/distributed/c10d/work.py @@ -0,0 +1,198 @@ +from mindnlp import core +from mindnlp.core import Tensor +from enum import Enum +from typing import List, Optional, Callable +import time +import threading + + +class OpType(Enum): + BROADCAST = 0 + ALLREDUCE = 1 + ALLREDUCE_COALESCED = 2 + REDUCE = 3 + ALLGATHER = 4 + _ALLGATHER_BASE = 5 + ALLGATHER_COALESCED = 6 + GATHER = 7 + SCATTER = 8 + REDUCE_SCATTER = 9 + ALLTOALL_BASE = 10 + ALLTOALL = 11 + SEND = 12 + RECV = 13 + RECVANYSOURCE = 14 + BARRIER = 15 + _REDUCE_SCATTER_BASE = 16 + COALESCED = 17 + _ALLREDUCE_SPARSE = 18 + UNKNOWN = 100 + + +class WorkResult(Enum): + SUCCESS = 0 + TIMEOUT = 1 + COMM_ERROR = 2 + UNKNOWN = 100 + + +kNoTimeout = 0 # Default to 0 for no timeout + + +def op_type_to_string(op_type: OpType) -> str: + """Converts OpType to human-readable string.""" + return op_type.name + + +def is_p2p_op(op_type: OpType, batch_p2p=False) -> bool: + """Determines if the operation is point-to-point.""" + if batch_p2p: + return False + return op_type in {OpType.SEND, OpType.RECV, OpType.RECVANYSOURCE} + + +class Work: + def __init__(self, rank=-1, op_type=OpType.UNKNOWN, profiling_title=None, input_tensors=None): + self.rank = rank + self.op_type = op_type + self.completed = False + self.exception = None + self.cv = threading.Condition() + self.record_function_end_callback = None + self.input_tensors = input_tensors + + if profiling_title is not None: + # Simulate the profiling functionality in Python (simplified) + self.record_function_end_callback = lambda: print(f"Profiling {profiling_title} ended.") + + def is_completed(self) -> bool: + """Non-blocking check for completion.""" + with self.cv: + return self.completed + + def is_success(self) -> bool: + """Returns True if work is successful.""" + with self.cv: + return self.exception is None + + def exception(self): + """Returns the exception if work was unsuccessful.""" + with self.cv: + return self.exception + + def source_rank(self) -> int: + """Returns source rank for recv-from-any work.""" + raise NotImplementedError("sourceRank() is only available for recv or recv-from-any operations.") + + def result(self) -> List[Tensor]: + """Returns result tensors.""" + raise NotImplementedError("Result not implemented for this operation.") + + def synchronize(self): + """Ensure synchronization of operations on output tensors.""" + if self.is_completed() and self.is_success(): + if self.record_function_end_callback: + self.record_function_end_callback() + + def wait(self, timeout=kNoTimeout) -> bool: + """Wait for completion of the work.""" + with self.cv: + if timeout == kNoTimeout: + self.cv.wait_for(lambda: self.completed) + else: + self.cv.wait(timeout) + if not self.completed: + raise TimeoutError("Operation timed out!") + if self.exception: + raise self.exception + return self.is_success() + + def abort(self): + """Aborts the work.""" + raise NotImplementedError("Abort is not implemented.") + + def get_future(self): + """Returns a Future object associated with the work.""" + raise NotImplementedError("getFuture is not implemented.") + + def get_future_result(self): + """Returns a Future object that marks success or failure.""" + raise NotImplementedError("getFutureResult is not implemented.") + + def finish(self, exception=None): + """Complete the work and notify waiting threads.""" + with self.cv: + self.completed = True + self.exception = exception + if self.record_function_end_callback: + self.record_function_end_callback() + self.cv.notify_all() + + def finish_and_throw(self, exception): + """Finish work and throw exception.""" + with self.cv: + self.completed = True + self.exception = exception + if self.record_function_end_callback: + self.record_function_end_callback() + if self.exception: + raise self.exception + + def get_duration(self) -> float: + """Get the duration of the work.""" + raise NotImplementedError("This backend doesn't support getDuration.") + + def get_sequence_number(self) -> int: + """Get the sequence number for the work.""" + raise NotImplementedError("This backend doesn't support getSequenceNumber.") + + @staticmethod + def create_from_future(future): + """Create a Work object from a Future.""" + return FutureWrappingWork(future) + + +class FutureWrappingWork(Work): + def __init__(self, fut): + super().__init__() + self._fut = fut + + def is_completed(self) -> bool: + """Checks if the future is completed.""" + return self._fut.completed() + + def is_success(self) -> bool: + """Checks if the future has succeeded.""" + return self._fut.has_value() + + def exception(self): + """Returns exception if any.""" + return self._fut.exception_ptr() + + def source_rank(self) -> int: + raise NotImplementedError("FutureWrappingWork::sourceRank() not implemented") + + def result(self) -> List[Tensor]: + return self._fut.value().to_py_object_holder().extract_tensors() + + def wait(self, timeout=kNoTimeout) -> bool: + """Waits for the future to complete.""" + if timeout != kNoTimeout: + raise NotImplementedError("Timeout handling not implemented for FutureWrappingWork.") + self._fut.wait() + return True + + def abort(self): + raise NotImplementedError("abort not implemented for FutureWrappingWork.") + + def get_future(self): + return self._fut + + +class WorkInfo: + def __init__(self, op_type: OpType, seq: int, time_started, time_finished, active_duration): + self.op_type = op_type + self.seq = seq + self.time_started = time_started + self.time_finished = time_finished + self.active_duration = active_duration diff --git a/mindnlp/core/distributed/c10d_logger.py b/mindnlp/core/distributed/c10d_logger.py new file mode 100644 index 000000000..6ccc6d934 --- /dev/null +++ b/mindnlp/core/distributed/c10d_logger.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import functools +import logging +from typing import Any, Callable, Dict, List, Tuple, TypeVar +from typing_extensions import ParamSpec + +from mindnlp import core +from mindnlp.core import distributed as dist +from mindnlp.core.distributed.logging_handlers import _log_handlers +# from core.monitor import _WaitCounter + + +__all__: List[str] = [] + +_DEFAULT_DESTINATION = "default" + + +def _get_or_create_logger(destination: str = _DEFAULT_DESTINATION) -> logging.Logger: + logging_handler, log_handler_name = _get_logging_handler(destination) + logger = logging.getLogger(f"c10d-{log_handler_name}") + logger.setLevel(logging.DEBUG) + formatter = logging.Formatter( + "%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s" + ) + logging_handler.setFormatter(formatter) + logger.propagate = False + logger.addHandler(logging_handler) + return logger + + +def _get_logging_handler( + destination: str = _DEFAULT_DESTINATION, +) -> Tuple[logging.Handler, str]: + log_handler = _log_handlers[destination] + log_handler_name = f"{type(log_handler).__name__}-{destination}" + return (log_handler, log_handler_name) + + +global _c10d_logger +_c10d_logger = _get_or_create_logger() + + +def _get_msg_dict(func_name, *args, **kwargs) -> Dict[str, Any]: + if dist.is_initialized(): + group = kwargs.get("group") or kwargs.get("process_group") + msg_dict = { + "func_name": f"{func_name}", + "pg_name": f"{dist._get_process_group_name(kwargs.get('pg'))}", # type: ignore[arg-type] + "backend": f"{dist.get_backend(group)}", + "world_size": f"{dist.get_world_size()}", + "group_size": f"{dist.get_world_size(group)}", + "global_rank": f"{dist.get_rank()}", + "local_rank": f"{dist.get_rank(group)}", + } + if msg_dict["backend"] == "nccl": + nccl_version = core.cuda.nccl.version() + msg_dict["nccl_version"] = ".".join(str(v) for v in nccl_version) + else: + msg_dict = { + "func_name": f"{func_name}", + } + return msg_dict + + +_T = TypeVar("_T") +_P = ParamSpec("_P") + + +def _exception_logger(func: Callable[_P, _T]) -> Callable[_P, _T]: + @functools.wraps(func) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: + try: + return func(*args, **kwargs) + except Exception as error: + msg_dict = _get_msg_dict(func.__name__, *args, **kwargs) + msg_dict["error"] = f"{error}" + _c10d_logger.debug(msg_dict) + raise + + return wrapper + + +def _time_logger(func: Callable[_P, _T]) -> Callable[_P, _T]: + @functools.wraps(func) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: + # with _WaitCounter(f"pycore.wait_counter.c10d.{func.__name__}").guard(): + func_return = func(*args, **kwargs) + return func_return + + return wrapper diff --git a/mindnlp/core/distributed/checkpoint/__init__.py b/mindnlp/core/distributed/checkpoint/__init__.py new file mode 100644 index 000000000..3b63a5b7f --- /dev/null +++ b/mindnlp/core/distributed/checkpoint/__init__.py @@ -0,0 +1,14 @@ +from .api import CheckpointException +from .default_planner import DefaultLoadPlanner, DefaultSavePlanner +from .filesystem import FileSystemReader, FileSystemWriter +from .metadata import ( + BytesStorageMetadata, + ChunkStorageMetadata, + Metadata, + TensorStorageMetadata, +) +# from .optimizer import load_sharded_optimizer_state_dict +from .planner import LoadPlan, LoadPlanner, ReadItem, SavePlan, SavePlanner, WriteItem +from .state_dict_loader import load, load_state_dict +# from .state_dict_saver import async_save, save, save_state_dict +# from .storage import StorageReader, StorageWriter diff --git a/mindnlp/core/distributed/checkpoint/_checkpointer.py b/mindnlp/core/distributed/checkpoint/_checkpointer.py new file mode 100644 index 000000000..3d6728769 --- /dev/null +++ b/mindnlp/core/distributed/checkpoint/_checkpointer.py @@ -0,0 +1,100 @@ +from concurrent.futures import Future +from typing import Any, Dict, List, Optional + +from mindnlp import core.distributed as dist +from mindnlp import core.distributed.checkpoint.state_dict_loader as loader +from mindnlp import core.distributed.checkpoint.state_dict_saver as saver +from core.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE +from core.distributed.checkpoint.storage import ( + LoadPlanner, + SavePlanner, + StorageReader, + StorageWriter, +) + + +__all__: List[str] = [] + + +class _Checkpointer: + """This base class specefies a high level API for saving and loading + distributed `state_dict` 's. It provides an abstraction over the low-level APIs + provided by :py:mod:`core.distributed.checkpoint.storage`, essentially calling + :py:meth: `core.distributed.state_dict_saver.save` and + :py:meth: `core.distributed.state_dict_loader.load` with the provided storage + readers and writers. + + .. warning:: + This feature is experimental and subject to removal/change. + + """ + + def __init__( + self, + storage_writer: StorageWriter, + storage_reader: StorageReader, + *, + process_group: Optional[dist.ProcessGroup] = None, + coordinator_rank: int = 0, + no_dist: bool = False, + load_planner: Optional[LoadPlanner] = None, + save_planner: Optional[SavePlanner] = None, + ): + """Initializes the Checkpointer instance. + + Args: + storage_writer: Instance of StorageWrite use to perform writes. + storage_reader: StorageReader used to load data from. + process_group: ProcessGroup to be used for cross-rank synchronization. + coordinator_rank: Rank to use to coordinate the checkpoint. rank0 is used by default. + no_dist: If ``True``, distributed checkpoint will not load in SPMD style. (Default: ``False``) + loader_planner: Instance of LoadPlanner to use when loading. + save_planner: Instance of SavePlanner to use when saving. + """ + self.storage_writer = storage_writer + self.storage_reader = storage_reader + self.process_group = process_group + self.coordinator_rank = coordinator_rank + self.no_dist = no_dist + self.load_planner = load_planner + self.save_planner = save_planner + + def save( + self, + state_dict: STATE_DICT_TYPE, + ) -> Metadata: + """Calls :py:meth: `core.distributed.state_dict_saver.save`. Utilizing values passed during initialization.""" + return saver.save( + state_dict, + self.storage_writer, + process_group=self.process_group, + coordinator_rank=self.coordinator_rank, + no_dist=self.no_dist, + planner=self.save_planner, + ) + + def async_save( + self, + state_dict: STATE_DICT_TYPE, + ) -> Future: + """ + Calls :py:meth: `core.distributed.state_dict_saver._async_save`. Utilizing values passed during initialization. + + Returns: + Future: A future holding the resultant Metadata object from `save`. + """ + return saver.async_save( + state_dict, + storage_writer=self.storage_writer, + process_group=self.process_group, + planner=self.save_planner, + ) + + def load(self, state_dict: Dict[str, Any]) -> None: + """Calls :py:meth: `core.distributed.state_dict_loader.load`. Utilizing values passed during initialization.""" + loader.load( + state_dict, + storage_reader=self.storage_reader, + process_group=self.process_group, + planner=self.load_planner, + ) diff --git a/mindnlp/core/distributed/checkpoint/_dedup_save_plans.py b/mindnlp/core/distributed/checkpoint/_dedup_save_plans.py new file mode 100644 index 000000000..674f2d60e --- /dev/null +++ b/mindnlp/core/distributed/checkpoint/_dedup_save_plans.py @@ -0,0 +1,60 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +import dataclasses +from collections import defaultdict +from typing import Dict, List, Set, TYPE_CHECKING + +from core.distributed.checkpoint.planner import SavePlan, WriteItem + + +if TYPE_CHECKING: + from core.distributed.checkpoint.metadata import MetadataIndex + +__all__ = ["dedup_save_plans"] + + +def dedup_save_plans( + all_plans: List[SavePlan], + save_to_lowest_rank: bool = False, +) -> List[SavePlan]: + """ + Removes duplicate entries from appearing on multiple SavePlans. For each duplicate across + a set of SavePlans, only the smallest SavePlan in terms of planned storage keeps the entry. + """ + + write_item_to_plan_indices: Dict[MetadataIndex, Set[int]] = defaultdict(set) + write_item_idx_to_write_item: Dict[MetadataIndex, WriteItem] = {} + for plan_idx, plan in enumerate(all_plans): + for write_item in plan.items: + # map each write item to its plan + write_item_to_plan_indices[write_item.index].add(plan_idx) + write_item_idx_to_write_item[write_item.index] = write_item + + # put item in the plan with the smallest size and remove it from the other plan_indices + to_remove: List[Set] = [set() for _ in range(len(all_plans))] + plan_to_size = [0] * len(all_plans) + for write_item_idx, plan_indices in write_item_to_plan_indices.items(): + if save_to_lowest_rank: + select_plan_idx = min(plan_indices) + else: + select_plan_idx = min( + plan_indices, key=lambda plan_idx: plan_to_size[plan_idx] + ) + + write_item = write_item_idx_to_write_item[write_item_idx] + # essentially ignores the storage size of anything that is not a tensor, since + # we don't know how much storage they represent + plan_to_size[select_plan_idx] += write_item.tensor_storage_size() or 1 + + plan_indices.remove(select_plan_idx) + for plan_idx in plan_indices: + to_remove[plan_idx].add(write_item_idx) + + for plan_idx, remove_set in enumerate(to_remove): + new_items = [ + write_item + for write_item in all_plans[plan_idx].items + if write_item.index not in remove_set + ] + all_plans[plan_idx] = dataclasses.replace(all_plans[plan_idx], items=new_items) + + return all_plans diff --git a/mindnlp/core/distributed/checkpoint/_dedup_tensors.py b/mindnlp/core/distributed/checkpoint/_dedup_tensors.py new file mode 100644 index 000000000..f3acb3024 --- /dev/null +++ b/mindnlp/core/distributed/checkpoint/_dedup_tensors.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +import dataclasses +import logging +from typing import Dict, List, TYPE_CHECKING + +from core.distributed.checkpoint.planner import SavePlan + + +if TYPE_CHECKING: + from core.distributed.checkpoint.metadata import MetadataIndex + +__all__ = ["dedup_tensors"] + + +def init_logger() -> logging.Logger: + logger = logging.getLogger(__name__) + level = logging.INFO + logger.setLevel(level) + console = logging.StreamHandler() + formatter = logging.Formatter( + "%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s" + ) + console.setFormatter(formatter) + console.setLevel(level) + logger.addHandler(console) + logger.propagate = False + return logger + + +logger = init_logger() + + +# TODO add docstring for dedup_tensors +def dedup_tensors(all_plans: List[SavePlan]) -> List[SavePlan]: + all_plans = list(all_plans) + key_to_plan: Dict[MetadataIndex, List[int]] = {} + for plan_idx, plan in enumerate(all_plans): + for write_item in plan.items: + key_to_plan.setdefault(write_item.index, []).append(plan_idx) + + replicated_items = {k: v for k, v in key_to_plan.items() if len(v) > 1} + + # Remove duplicates by always keeping the first entry. + # Compute the per-rank remove set. + plan_to_keys: Dict[int, List[MetadataIndex]] = {} + for key, plans in replicated_items.items(): + for plan_idx in plans[1:]: + plan_to_keys.setdefault(plan_idx, []).append(key) + if len(plan_to_keys) > 0: + logger.info("Duplicate keys to remove: %s", plan_to_keys) + + for plan_idx, keys in plan_to_keys.items(): + key_set = set(keys) + # rewrite items and remove elements + new_items = [ + write_item + for write_item in all_plans[plan_idx].items + if write_item.index not in key_set + ] + all_plans[plan_idx] = dataclasses.replace(all_plans[plan_idx], items=new_items) + + return all_plans diff --git a/mindnlp/core/distributed/checkpoint/_fsspec_filesystem.py b/mindnlp/core/distributed/checkpoint/_fsspec_filesystem.py new file mode 100644 index 000000000..b635cddda --- /dev/null +++ b/mindnlp/core/distributed/checkpoint/_fsspec_filesystem.py @@ -0,0 +1,151 @@ +# Mypy will not try inferring the types of any 3rd party libraries installed. +# mypy: ignore-errors + +import io +import os +from contextlib import contextmanager +from pathlib import Path +from typing import Generator, Optional, TYPE_CHECKING, Union + +from fsspec.core import url_to_fs + +from core.distributed.checkpoint.filesystem import ( + FileSystemBase, + FileSystemReader, + FileSystemWriter, +) + + +if TYPE_CHECKING: + from fsspec import AbstractFileSystem + + +__all__ = [ + "FsspecWriter", + "FsspecReader", +] + + +class FileSystem(FileSystemBase): + def __init__(self) -> None: + self.fs: Optional[AbstractFileSystem] = None + + @contextmanager + def create_stream( + self, path: Union[str, os.PathLike], mode: str + ) -> Generator[io.IOBase, None, None]: + assert self.fs is not None + path = os.fspath(path) + + # fsspec does not support concurrent transactions, and not all + # AbstractFileSystem have working rollback implementations, so + # just manually delete the file if necessary on errors. + with self.fs.open(path, mode) as stream: + try: + yield stream + except: # noqa: B001,E722 + if "w" or "+" or "a" in mode: # cleanup file if not read-only + try: + self.rm_file(path) + except: # noqa: B001,E722 + pass + raise + + def concat_path( + self, path: Union[str, os.PathLike], suffix: str + ) -> Union[str, os.PathLike]: + return os.path.join(path, suffix) + + def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]: + self.fs, _ = url_to_fs(path) + return path + + def rename( + self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike] + ) -> None: + self.fs.rename(path, new_path) + + def mkdir(self, path: Union[str, os.PathLike]) -> None: + self.fs.makedirs(path, exist_ok=True) + + @classmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + if isinstance(checkpoint_id, Path): + return False + + try: + url_to_fs(checkpoint_id) + except ValueError: + return False + + return True + + def exists(self, path: Union[str, os.PathLike]) -> bool: + return self.fs.exists(path) + + def rm_file(self, path: Union[str, os.PathLike]) -> None: + self.fs.rm(path) + + +# TODO: add the dcp.async_save mixin +class FsspecWriter(FileSystemWriter): + """ + Basic implementation of StorageWriter using FFspec. + + This implementation makes the following assumptions and simplifications: + + * The checkpoint path is an empty or non-existing directory. + * File creation is atomic + + The checkpoint consist of one file per write request plus + a `.metadata` file with the serialized metadata. + + """ + + def __init__( + self, + path: Union[str, os.PathLike], + single_file_per_rank: bool = True, + sync_files: bool = True, + thread_count: int = 1, + per_thread_copy_ahead: int = 10_000_000, + overwrite: bool = True, + ) -> None: + """ + Initialize the writer pointing to `path`. + + Args: + path: directory where the checkpoint will be written to. + single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Default to True. + sync_files : force files to be synced to permanent storage. Default to True. + thread_count: Number of IO threads to use to write. Default to 1. + per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving then. Default 10Mb. + overwrite: Whether to allow overwriting existing checkpoints. Defaults to True. + + N. B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure. + """ + super().__init__( + path, + single_file_per_rank, + sync_files, + thread_count, + per_thread_copy_ahead, + overwrite=overwrite, + ) + self.fs = FileSystem() + self.path = self.fs.init_path(path) + + @classmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + return FileSystem.validate_checkpoint_id(checkpoint_id) + + +class FsspecReader(FileSystemReader): + def __init__(self, path: Union[str, os.PathLike]) -> None: + super().__init__(path) + self.fs = FileSystem() + self.path = self.fs.init_path(path) + + @classmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + return FileSystem.validate_checkpoint_id(checkpoint_id) diff --git a/mindnlp/core/distributed/checkpoint/_nested_dict.py b/mindnlp/core/distributed/checkpoint/_nested_dict.py new file mode 100644 index 000000000..ea6b84947 --- /dev/null +++ b/mindnlp/core/distributed/checkpoint/_nested_dict.py @@ -0,0 +1,70 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +from typing import Dict, Tuple + +from core.distributed.checkpoint.metadata import STATE_DICT_TYPE + +from . import _version +from ._traverse import ( + OBJ_PATH, + set_element, + STATE_DICT_ITEM, + traverse_state_dict, + traverse_state_dict_v_2_3, +) + + +""" +TODO: +Need to add ability to handle tuple, OrderedDict, NamedTuple. +Update mappings from dict to a class. +Change set_element to recreate the right type for tuple, OrderedDict, and NamedTuple. +""" + + +FLATTEN_MAPPING = Dict[str, OBJ_PATH] + + +# TODO: Update Docstring for nested_dict.py +def flatten_state_dict( + state_dict: STATE_DICT_TYPE, +) -> Tuple[STATE_DICT_TYPE, FLATTEN_MAPPING]: + """ + Flatten ``state_dict`` made of nested dicts and lists into a top level dictionary. + + Use ``unflatten_state_dict`` to revert this process. + Returns: + A tuple with the flatten state_dict and a mapping from original to new state_dict. + N.B. The new keys are derived from the object paths, joined by dot. + For example: ``{ 'a': {'b':...}}`` results in the key `a.b`. + """ + flattened: STATE_DICT_TYPE = {} + mappings: FLATTEN_MAPPING = {} + + def flat_copy(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None: + new_fqn = ".".join(map(str, path)) + if new_fqn in flattened: + raise ValueError(f"duplicated flatten key {new_fqn}") + flattened[new_fqn] = value + mappings[new_fqn] = path + + # We started to flatten dictionary since v2.4. But in order to not break + # the checkpoints that were saved before v2.4, we need to keep the old + # traversal so that we can reconstruct those checkpoints. + use_v_2_3 = ( + _version._derived_version is not None and _version._derived_version == "2_3" + ) + if use_v_2_3: + traverse_state_dict_v_2_3(state_dict, flat_copy) + else: + traverse_state_dict(state_dict, flat_copy) + return flattened, mappings + + +def unflatten_state_dict( + state_dict: STATE_DICT_TYPE, mapping: FLATTEN_MAPPING +) -> STATE_DICT_TYPE: + """Restore the original nested state_dict according to ``mapping`` and the flattened ``state_dict``.""" + nested: STATE_DICT_TYPE = {} + for key, value in state_dict.items(): + set_element(nested, mapping[key], value) + return nested diff --git a/mindnlp/core/distributed/checkpoint/_sharded_tensor_utils.py b/mindnlp/core/distributed/checkpoint/_sharded_tensor_utils.py new file mode 100644 index 000000000..0879a9544 --- /dev/null +++ b/mindnlp/core/distributed/checkpoint/_sharded_tensor_utils.py @@ -0,0 +1,107 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +import copy +from typing import TYPE_CHECKING + +from mindnlp import core.distributed as dist +from core.distributed._shard.sharded_tensor import Shard, ShardedTensor, ShardMetadata +from core.distributed.checkpoint.metadata import STATE_DICT_TYPE +from core.distributed.remote_device import _remote_device + +from ._traverse import OBJ_PATH, set_element, STATE_DICT_ITEM, traverse_state_dict +from .utils import _element_wise_add, _normalize_device_info + + +if TYPE_CHECKING: + from core.distributed._shard.sharded_tensor.metadata import ShardedTensorMetadata + + +# TODO: We need to refactor this code. +def _flatten_sharded_tensors(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE: + r""" + Transform ``state_dict`` by flattening all nested ShardedTensor instances found. + + The resulting ShardedTensor instances are only correct regarding the local shard and + MUST not be used for any other purpose but checkpointing, as no operator will work with them. + + This function should be used in conjunction with a state_dict produced by FSDP's + StateDictType.SHARDED_STATE_DICT methods. + """ + new_state_dict: STATE_DICT_TYPE = {} + + def rewrite_dict(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None: + if not isinstance(value, ShardedTensor): + set_element(new_state_dict, path, value) + return + shards = value.local_shards() + + if len(shards) == 0: + return + if len(shards) != 1: + set_element(new_state_dict, path, value) + return + + outer_shard = shards[0] + + inner_st = outer_shard.tensor + if not isinstance(inner_st, ShardedTensor): + set_element(new_state_dict, path, value) + return + + if len(inner_st.local_shards()) != 1: + raise ValueError("Cannot handle inner tensor with more than 1 shard") + inner_shard = inner_st.local_shards()[0] + + local_shards = [ + Shard( + tensor=inner_shard.tensor, + metadata=ShardMetadata( + shard_offsets=_element_wise_add( + outer_shard.metadata.shard_offsets, + inner_shard.metadata.shard_offsets, + ), + shard_sizes=inner_shard.metadata.shard_sizes, + placement=f"rank:{dist.get_rank()}/{inner_shard.tensor.device}", + ), + ) + ] + + st_meta: ShardedTensorMetadata = copy.deepcopy(value.metadata()) + other_rank = 0 if dist.get_rank() > 0 else 1 + device_info = _normalize_device_info(inner_shard.tensor.device.type, 0) + + # Remove the outer ST shard the inner ST covers + for i, shard_md in enumerate(st_meta.shards_metadata): + if shard_md.shard_offsets == outer_shard.metadata.shard_offsets: + st_meta.shards_metadata.pop(i) + break + + # Attribute other rank for the other shards + for shard_md in st_meta.shards_metadata: + shard_md.placement = _remote_device(f"rank:{other_rank}/{device_info}") + + # Add other inner shards from the inner tensor + for inner_md in inner_st.metadata().shards_metadata: + if inner_md.shard_offsets != inner_shard.metadata.shard_offsets: + st_meta.shards_metadata.append( + ShardMetadata( + shard_offsets=_element_wise_add( + outer_shard.metadata.shard_offsets, + inner_md.shard_offsets, + ), + shard_sizes=inner_md.shard_sizes, + placement=f"rank:{other_rank}/{device_info}", + ) + ) + + # Finally add this shard + st_meta.shards_metadata.append(local_shards[0].metadata) + + st = ShardedTensor._init_from_local_shards_and_global_metadata( + local_shards=local_shards, + sharded_tensor_metadata=st_meta, + ) + set_element(new_state_dict, path, st) + + traverse_state_dict(state_dict, rewrite_dict) + return new_state_dict diff --git a/mindnlp/core/distributed/checkpoint/_storage_utils.py b/mindnlp/core/distributed/checkpoint/_storage_utils.py new file mode 100644 index 000000000..a9261a5fc --- /dev/null +++ b/mindnlp/core/distributed/checkpoint/_storage_utils.py @@ -0,0 +1,49 @@ +import os +from typing import List, Type, Union + +from .filesystem import FileSystemReader, FileSystemWriter +from .storage import StorageReader, StorageWriter + + +def _storage_setup( + storage: Union[StorageReader, StorageWriter, None], + checkpoint_id: Union[str, os.PathLike, None], + reader: bool = False, +) -> Union[None, StorageReader, StorageWriter]: + if storage: + if checkpoint_id is not None: + storage.reset(checkpoint_id) + return storage + + if not checkpoint_id: + raise RuntimeError( + "`checkpoint_id` must be specificed if " + "storage_reader/storage_writer is None." + ) + + targets: List[Type[Union[StorageReader, StorageWriter]]] = [] + if reader: + targets = [ + FileSystemReader, + ] + else: + targets = [ + FileSystemWriter, + ] + try: + from ._fsspec_filesystem import FsspecReader, FsspecWriter + + targets.append(FsspecReader if reader else FsspecWriter) + except Exception: + pass + + for target in targets: + if target.validate_checkpoint_id(checkpoint_id): + storage = target(checkpoint_id) # type: ignore[call-arg] + storage.reset(checkpoint_id) + return storage + + raise RuntimeError( + "Cannot detect which StorageReader or StorageWriter to use. " + "Please specify the storage_reader/storage_writer." + ) diff --git a/mindnlp/core/distributed/checkpoint/_traverse.py b/mindnlp/core/distributed/checkpoint/_traverse.py new file mode 100644 index 000000000..cf6d11c6e --- /dev/null +++ b/mindnlp/core/distributed/checkpoint/_traverse.py @@ -0,0 +1,208 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +from typing import ( + Callable, + cast, + Collection, + List, + Mapping, + MutableMapping, + Optional, + Tuple, + TypeVar, + Union, +) + +from mindnlp import core +from core.distributed._shard.sharded_tensor.api import ShardedTensor +from core.distributed.checkpoint.metadata import STATE_DICT_TYPE +# from core.distributed.tensor import DTensor + + +PATH_ITEM = Union[str, int] +OBJ_PATH = Tuple[PATH_ITEM, ...] +T = TypeVar("T") + +STATE_DICT_ITEM = object +CONTAINER_TYPE = MutableMapping[PATH_ITEM, STATE_DICT_ITEM] + +__all__ = ["traverse_state_dict", "set_element", "get_element", "print_tensor"] + + +def _keep_visiting_tensors(value: STATE_DICT_ITEM) -> bool: + return isinstance(value, core.Tensor) + + +# TODO: update docstring for traverse.py +def traverse_state_dict( + state_dict: STATE_DICT_TYPE, + visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None], + keep_traversing: Callable[[STATE_DICT_ITEM], bool] = _keep_visiting_tensors, +) -> None: + """ + Invoke ``visitor`` for each value recursively in ``state_dict``. + Mapping will be traversed and ``visitor`` will be applied to the leaf elements. + ``visitor`` will only be applied to elements in a list or a tuple, if the + container contains tensors or mappings. + """ + + def _is_terminal(value: STATE_DICT_ITEM) -> bool: + values: Collection[STATE_DICT_ITEM] + if isinstance(value, Mapping): + return False + elif isinstance(value, list): + values = value + else: + return True + + for entry in values: + if isinstance(entry, (Mapping, list)) and not _is_terminal(entry): + return False + if keep_traversing is not None and keep_traversing(entry): + return False + return True + + def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None: + if isinstance(value, Mapping): + for k, v in value.items(): + _traverse_obj(path + (str(k),), v) + elif _is_terminal(value): + visitor(path, value) + elif isinstance(value, (list, tuple)): + for i, v in enumerate(value): + _traverse_obj(path + (i,), v) + + for key, value in state_dict.items(): + _traverse_obj((str(key),), value) + + +def traverse_state_dict_v_2_3( + state_dict: STATE_DICT_TYPE, + visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None], + keep_traversing: Callable[[STATE_DICT_ITEM], bool] = _keep_visiting_tensors, +) -> None: + """ + Traversal is short-circuited when if finds a collection for which ``keep_visiting_tensors`` evaluates + to false for all elements. + By default, all collections with at least one ``core.Tensor`` element are traversed. + Visitor takes a path argument that is a tuple of the keys used to reach it. + """ + + # a value is terminal if it has no other containers values inside it + def _is_terminal(value: STATE_DICT_ITEM) -> bool: + values: Collection[STATE_DICT_ITEM] + if isinstance(value, Mapping): + values = value.values() + elif isinstance(value, list): + values = value + else: + return True + + for entry in values: + if isinstance(entry, (Mapping, list)) and not _is_terminal(entry): + return False + if keep_traversing is not None and keep_traversing(entry): + return False + return True + + def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None: + if _is_terminal(value): + visitor(path, value) + elif isinstance(value, Mapping): + for k, v in value.items(): + _traverse_obj(path + (str(k),), v) + elif isinstance(value, list): + for i, v in enumerate(value): + _traverse_obj(path + (i,), v) + + for key, value in state_dict.items(): + _traverse_obj((str(key),), value) + + +def set_element( + root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: STATE_DICT_ITEM +) -> None: + """Set ``value`` in ``root_dict`` along the ``path`` object path.""" + cur_container = cast(CONTAINER_TYPE, root_dict) + + def extend_list(lst: List[STATE_DICT_ITEM], idx: int) -> None: + while len(lst) <= idx: + lst.append(None) + + for i in range(1, len(path)): + prev_key = path[i - 1] + key = path[i] + def_val = cast(STATE_DICT_ITEM, {} if type(key) == str else []) + + if isinstance(cur_container, Mapping): + cur_container = cast( + CONTAINER_TYPE, cur_container.setdefault(prev_key, def_val) + ) + else: + extend_list(cur_container, prev_key) + if cur_container[prev_key] is None: + cur_container[prev_key] = def_val + cur_container = cur_container[prev_key] + + key = path[-1] + if type(key) == int: + extend_list(cast(List[STATE_DICT_ITEM], cur_container), key) + + cur_container[key] = value + + +def get_element( + root_dict: STATE_DICT_TYPE, + path: OBJ_PATH, + default_value: Optional[T] = None, +) -> Optional[T]: + """Retrieve the value at ``path``from ``root_dict``, returning ``default_value`` if not found.""" + cur_value = cast(CONTAINER_TYPE, root_dict) + for part in path: + if type(part) is int: + if not isinstance(cur_value, list) or len(cur_value) < part: + return default_value + elif not isinstance(cur_value, Mapping) or part not in cur_value: + return default_value + + cur_value = cast(CONTAINER_TYPE, cur_value[part]) + return cast(Optional[T], cur_value) + + +def _print_nested( + value: STATE_DICT_ITEM, + prefix: str = "", + print_fun: Callable[[str], None] = print, +) -> None: + if type(value) is ShardedTensor: + print_fun(f"{prefix} ShardedTensor size: {value.size()}") + for shard in value.local_shards(): + _print_nested( + shard.tensor, + f"{shard.metadata.shard_offsets} ", + print_fun=print_fun, + ) + # elif type(value) is (DTensor): + # print_fun(f"{prefix} DistributedTensor size: {value.size()}") + # # TODO: add local offset for _local_tensor in print_nested. + # _print_nested( + # value._local_tensor, + # print_fun=print_fun, + # ) + elif isinstance(value, core.Tensor): + print_fun(f"{prefix} Tensor size: {value.size()}") + else: + print_fun(f"{prefix} Type: {type(value)}") + + +def print_tensor( + path: OBJ_PATH, + value: STATE_DICT_ITEM, + print_fun: Callable[[str], None] = print, +) -> None: + """ + Use this callback with traverse_state_dict to print its content. + + By default the content is printed using the builtin ``print`` but this can + be change by passing a different ``print_fun` callable. + """ + _print_nested(value, prefix=str(path), print_fun=print_fun) diff --git a/mindnlp/core/distributed/checkpoint/_version.py b/mindnlp/core/distributed/checkpoint/_version.py new file mode 100644 index 000000000..a1bca2949 --- /dev/null +++ b/mindnlp/core/distributed/checkpoint/_version.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +from typing import Optional + + +_derived_version: Optional[str] = None diff --git a/mindnlp/core/distributed/checkpoint/api.py b/mindnlp/core/distributed/checkpoint/api.py new file mode 100644 index 000000000..9b6b6f64e --- /dev/null +++ b/mindnlp/core/distributed/checkpoint/api.py @@ -0,0 +1,43 @@ +# mypy: allow-untyped-defs +import traceback as tb +from typing import Any, Dict, Tuple + + +WRAPPED_EXCEPTION = Tuple[BaseException, tb.StackSummary] + +__all__ = ["CheckpointException"] + + +def _wrap_exception(exc: BaseException) -> WRAPPED_EXCEPTION: + return (exc, tb.extract_tb(exc.__traceback__)) + + +def _is_wrapped_exception(obj: Any) -> bool: + if not isinstance(obj, tuple): + return False + if len(obj) != 2: + return False + return isinstance(obj[0], BaseException) and isinstance(obj[1], tb.StackSummary) + + +class CheckpointException(BaseException): + """Exception raised if failure was detected as part of a checkpoint load or save.""" + + def __init__(self, msg: str, failures: Dict[int, WRAPPED_EXCEPTION]): + super().__init__(msg, failures) + self._failures = failures + + @property + def failures(self) -> Dict[int, WRAPPED_EXCEPTION]: + """Return a dictionary mapping node ranks to their associated exceptions in case of failure.""" + return self._failures + + def __str__(self): + str = f"CheckpointException ranks:{self._failures.keys()}\n" + for rank, exc_pair in self._failures.items(): + exc, trace = exc_pair + str += f"Traceback (most recent call last): (RANK {rank})\n" + if trace is not None: + str += "".join(tb.format_list(trace)) + str += "".join(tb.format_exception_only(type(exc), value=exc)) + return str diff --git a/mindnlp/core/distributed/checkpoint/default_planner.py b/mindnlp/core/distributed/checkpoint/default_planner.py new file mode 100644 index 000000000..46c3be1b6 --- /dev/null +++ b/mindnlp/core/distributed/checkpoint/default_planner.py @@ -0,0 +1,546 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates + +import dataclasses +import io +import logging +import operator +from collections import ChainMap +from functools import reduce +from typing import Any, cast, Dict, List, Optional, Tuple, Union + +from mindnlp import core +from core.distributed._shard._utils import narrow_tensor_by_index +from core.distributed.checkpoint._dedup_save_plans import dedup_save_plans +from core.distributed.checkpoint._nested_dict import ( + FLATTEN_MAPPING, + flatten_state_dict, +) +from core.distributed.checkpoint._sharded_tensor_utils import _flatten_sharded_tensors +from core.distributed.checkpoint._traverse import set_element +from core.distributed.checkpoint.metadata import ( + BytesStorageMetadata, + ChunkStorageMetadata, + Metadata, + MetadataIndex, + STATE_DICT_TYPE, + STORAGE_TYPES, + StorageMeta, + TensorStorageMetadata, +) +from core.distributed.checkpoint.planner import ( + LoadPlan, + LoadPlanner, + ReadItem, + SavePlan, + SavePlanner, + WriteItem, + WriteItemType, +) +from core.distributed.checkpoint.planner_helpers import ( + _create_default_metadata_only_plan, + _create_read_items, + _create_write_items, + _init_state_dict, +) +from core.distributed.checkpoint.utils import find_state_dict_object +# from core.distributed.tensor import DTensor + +from . import _version + + +logger: logging.Logger = logging.getLogger(__name__) + + +__all__ = [ + "DefaultSavePlanner", + "DefaultLoadPlanner", + "create_default_local_load_plan", + "create_default_global_load_plan", + "create_default_local_save_plan", + "create_default_global_save_plan", +] + + +# TODO: Update docstrings for default_planner.py +class DefaultSavePlanner(SavePlanner): + mappings: FLATTEN_MAPPING + + def __init__( + self, + flatten_state_dict: bool = True, + flatten_sharded_tensors: bool = True, + dedup_replicated_tensors: Optional[bool] = None, + dedup_save_to_lowest_rank: bool = False, + ) -> None: + self.flatten_state_dict = flatten_state_dict + self.flatten_sharded_tensors = flatten_sharded_tensors + self.mappings = {} + self.dedup_save_to_lowest_rank = dedup_save_to_lowest_rank + if dedup_replicated_tensors is not None: + logger.warning( + "DefaultSavePlanner's `dedup_replicated_tensors` argument is being " + "deprecated, and no longer has any effect. Please remove this argument " + "from your call." + ) + + def set_up_planner( + self, + state_dict: STATE_DICT_TYPE, + storage_meta: Optional[StorageMeta] = None, + is_coordinator: bool = False, + ) -> None: + if self.flatten_state_dict: + state_dict, self.mappings = flatten_state_dict(state_dict) + if self.flatten_sharded_tensors: + state_dict = _flatten_sharded_tensors(state_dict) + self.state_dict = state_dict + self.is_coordinator = is_coordinator + + def create_local_plan(self) -> SavePlan: + plan = create_default_local_save_plan(self.state_dict, self.is_coordinator) + if self.flatten_state_dict: + plan = dataclasses.replace(plan, planner_data=self.mappings) + self.plan = plan + + return self.plan + + def create_global_plan( + self, all_plans: List[SavePlan] + ) -> Tuple[List[SavePlan], Metadata]: + all_plans = dedup_save_plans(all_plans, self.dedup_save_to_lowest_rank) + + global_plan, metadata = create_default_global_save_plan(all_plans) + + if self.flatten_state_dict: + # | does not work for Python 3.8 or older version. + # merged_mappings = reduce( + # lambda x, y: x | y, (p.planner_data for p in global_plan) + # ) + planner_data_dict = [p.planner_data for p in global_plan] + merged_mappings = dict(ChainMap(*planner_data_dict)) + metadata = dataclasses.replace(metadata, planner_data=merged_mappings) + + if not _validate_global_plan(global_plan, metadata): + raise ValueError("Failed to validate global plan") + + self.global_plan = global_plan + self.metadata = metadata + + return self.global_plan, self.metadata + + def finish_plan(self, new_plan: SavePlan) -> SavePlan: + self.plan = new_plan + return new_plan + + def resolve_data(self, write_item: WriteItem) -> Union[core.Tensor, io.BytesIO]: + object = self.lookup_object(write_item.index) + return self.transform_object(write_item, object) + + def lookup_object(self, index: MetadataIndex) -> Any: + """Extension from the planner interface to make it easy to extend the default planner.""" + return find_state_dict_object(self.state_dict, index) + + def transform_object(self, write_item: WriteItem, object: Any): + """Extension from the planner interface to make it easy to extend the default planner.""" + if write_item.type == WriteItemType.BYTE_IO: + bytes = io.BytesIO() + core.save(object, bytes) + object = bytes + return object + + +class DefaultLoadPlanner(LoadPlanner): + """ + DefaultLoadPlanner that adds multiple features on top of LoadPlanner. + + In particular it adds the following: + + flatten_state_dict: Handle state_dict with nested dicts + flatten_sharded_tensors: For FSDP in 2D parallel mode + allow_partial_load: If False, will raise a runtime error if a key is present in state_dict, but not in the checkpoint. + """ + + original_state_dict: STATE_DICT_TYPE + mappings: FLATTEN_MAPPING + + def __init__( + self, + flatten_state_dict: bool = True, + flatten_sharded_tensors: bool = True, + allow_partial_load: bool = False, + ) -> None: + self.flatten_state_dict = flatten_state_dict + self.flatten_sharded_tensors = flatten_sharded_tensors + self.original_state_dict = {} + self.mappings = {} + self.allow_partial_load = allow_partial_load + + def set_up_planner( + self, + state_dict: STATE_DICT_TYPE, + metadata: Optional[Metadata] = None, + is_coordinator: bool = False, + ) -> None: + _init_state_dict(state_dict) + self.original_state_dict = state_dict + + if self.flatten_sharded_tensors: + state_dict = _flatten_sharded_tensors(state_dict) + + if self.flatten_state_dict: + state_dict, self.mappings = flatten_state_dict(state_dict) + + self.state_dict = state_dict + self.metadata = metadata + self.is_coordinator = is_coordinator + + def create_local_plan(self) -> LoadPlan: + assert self.metadata is not None + if self.flatten_state_dict: + # To support checkpoints that are saved before v2.4, we have to + # differentiate if the missing keys are due to old checkpoints. + # The contracts are: + # 1. There are 3 cases when we found a missing key. + # 1.1 Actual missing key, but allow_partial_load is False + # 1.2 Actual missing key, but allow_partial load is True + # 1.3 Old checkpoint, but allow_partial_load is False + # 1.4 Old checkpoint, but allow_partial_load is True + # 2. If we found a missing key, we first convert the keys back to + # the key format of v2.3 + # 3. If the previous missing keys are in the v2.3 keys, we assume + # this is a old checkpoint. + # 4. Pass the state_dict to `create_default_local_load_plan()`, + # which has the logic to check missing for allow_partial_load. + # So for 1.2 and 1.4 cases, we delegate allow_partial_load check to + # `create_default_local_load_plan()`. The logic here is to determine + # whether the checkpoint belong to 2.3 (or before) or 2.4 (or after). + current_keys = set(self.state_dict.keys()) + load_keys = set(self.metadata.state_dict_metadata.keys()) + missing_keys = load_keys - current_keys + if missing_keys: + _version._derived_version = "2_3" + old_state_dict, old_mappings = flatten_state_dict( + self.original_state_dict + ) + old_keys = set(old_state_dict.keys()) + if old_keys & missing_keys: + self.state_dict, self.mappings = old_state_dict, old_mappings + # _derived_version is only used by flatten_state_dict now. + # Set it back to None so that later we can save to a new version. + _version._derived_version = None + + return create_default_local_load_plan( + self.state_dict, self.metadata, not self.allow_partial_load + ) + + def create_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]: + return create_default_global_load_plan(global_plan) + + def finish_plan(self, new_plan: LoadPlan) -> LoadPlan: + return new_plan + + def load_bytes(self, read_item: ReadItem, value: io.BytesIO) -> None: + if self.flatten_state_dict: + set_element( + self.original_state_dict, + self.mappings[read_item.dest_index.fqn], + core.load(value, weights_only=False), + ) + else: + self.state_dict[read_item.dest_index.fqn] = core.load( + value, weights_only=False + ) + + def resolve_tensor(self, read_item: ReadItem): + tensor = self.lookup_tensor(read_item.dest_index) + return self.transform_tensor(read_item, tensor) + + def commit_tensor(self, read_item: ReadItem, tensor: core.Tensor) -> None: + pass + + def lookup_tensor(self, index: MetadataIndex) -> core.Tensor: + """Extension from the planner interface to make it easy to extend the default planner.""" + return find_state_dict_object(self.state_dict, index) + + def transform_tensor(self, read_item: ReadItem, tensor: core.Tensor): + """Extension from the planner interface to make it easy to extend the default planner.""" + return narrow_tensor_by_index(tensor, read_item.dest_offsets, read_item.lengths) + + +class _EmptyStateDictLoadPlanner(DefaultLoadPlanner): + """ + Extension of DefaultLoadPlanner, which rebuilds state_dict from the saved metadata. + Useful for loading in state_dict without first initializing a model, such as + when converting a DCP checkpoint into a Torch save file. + + . N.B. `state_dict` must be an empty dictionary when used with this LoadPlanner + + .. warning:: + Because the entire state dict is initialized, It's recommended to only utilize + this LoadPlanner on a single rank or process to avoid OOM. + + """ + + def __init__(self, keys=None, *args, **kwargs): + self.keys = keys + super().__init__(*args, **kwargs) + + def _should_include_key(self, key: str, metadata: Metadata) -> bool: + if self.keys is None: + return True + + if key in self.keys: + True + + unflattened_keys: List[str] = [] + planner_data = metadata.planner_data.get(key) + for unflattened_key in planner_data: + if unflattened_keys: + unflattened_keys.append( + ".".join([unflattened_keys[-1], str(unflattened_key)]) + ) + + else: + unflattened_keys.append(unflattened_key) + + if any(unflattened_key in self.keys for unflattened_key in unflattened_keys): + return True + + return False + + def set_up_planner( + self, + state_dict: STATE_DICT_TYPE, + metadata: Optional[Metadata] = None, + is_coordinator: bool = False, + ) -> None: + assert not state_dict + assert metadata is not None + + # rebuild the state dict from the metadata + for k, v in metadata.state_dict_metadata.items(): + if not self._should_include_key(k, metadata): + continue + + if isinstance(v, TensorStorageMetadata): + v = core.empty(v.size, dtype=v.properties.dtype) # type: ignore[assignment] + if k in metadata.planner_data: + set_element(state_dict, metadata.planner_data[k], v) + else: + state_dict[k] = v + + super().set_up_planner(state_dict, metadata, is_coordinator) + + +def create_default_local_load_plan( + state_dict: Dict[str, Any], metadata: Metadata, strict: bool = True +) -> LoadPlan: + requests = [] + """ + Create the ``LoadPlan`` used by DefaultLoadPlanner. + + It produces one read item per value in ``state_dict`` using the metadata in ``metadata``. + + The default behavior is to match key exactly between state_dict and metadata. + It handles resharding by issuing multiple read requests against storage in order to match + load requirements. + """ + + for fqn, obj in state_dict.items(): + # ignore state_dict keys which do not exist in `state_dict` if strict=False + if fqn not in metadata.state_dict_metadata: + if strict: + raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.") + else: + continue + + md = metadata.state_dict_metadata[fqn] + # Since DTensor supports submesh, adding extra check to ensure _create_read_items() + # gets called only when the current rank is part of the mesh for the corresponding DTensor. + # if isinstance(obj, DTensor): + # if obj.device_mesh.get_coordinate() is not None: + # requests += _create_read_items(fqn, md, obj) + # else: + requests += _create_read_items(fqn, md, obj) + + return LoadPlan(requests) + + +def create_default_global_load_plan( + all_plans: List[LoadPlan], +) -> List[LoadPlan]: + """ + Create global load plan used by DefaultLoadPlanner. + + The default load behavior involved no global coordination and this function + currently doesn't change the local plans. + """ + return all_plans + + +def create_default_local_save_plan( + state_dict: Dict[str, Any], is_coordinator: bool +) -> SavePlan: + """ + Create the ``SavePlan`` used by DefaultSavePlanner. + + On non-coordinator ranks, this function ignores tensors and non-tensor objects, + only producing writes for ShardedTensor objects. + + On the coordinator rank, produce writes for all values. + """ + requests = [] + for fqn, obj in state_dict.items(): + # Since DTensor supports submesh, adding extra check to ensure _create_write_items() + # gets called only when the current rank is part of the mesh for the corresponding DTensor. + # if isinstance(obj, DTensor): + # if obj.device_mesh.get_coordinate() is not None: + # requests += _create_write_items(fqn, obj) + # else: + # For the plain tensor and non-tensor values, add the request for all + # the ranks. Coordinator will decides whether to deduplicate the + # values based on the keys. + requests += _create_write_items(fqn, obj) + + return SavePlan(requests) + + +def create_default_global_save_plan( + all_plans: List[SavePlan], + rewrite_index_hints: bool = True, +) -> Tuple[List[SavePlan], Metadata]: + """ + Create the global plan and metadata used by DefaultSavePlanner. + + Metadata is produced by concatenating the metadata of all ``WriteItem`` from the supplied plans. + + The only global planning change is to update index hints in all ``MetadataIndex`` objects if + ``rewrite_index_hints`` is True. + """ + md: Dict[str, STORAGE_TYPES] = {} + new_plans = [] + for plan in all_plans: + new_items = [] + for item in plan.items: + # if not item.type == WriteItemType.SHARD: + # assert item.index.fqn not in md + + if item.type == WriteItemType.BYTE_IO: + md[item.index.fqn] = BytesStorageMetadata() + new_items.append(item) + else: + assert item.tensor_data is not None + tensor_md = cast( + TensorStorageMetadata, + md.setdefault( + item.index.fqn, + TensorStorageMetadata( + properties=item.tensor_data.properties, + size=item.tensor_data.size, + chunks=[], + ), + ), + ) + new_item = item + if rewrite_index_hints: + new_index = dataclasses.replace( + item.index, index=len(tensor_md.chunks) + ) + new_item = dataclasses.replace(item, index=new_index) + new_items.append(new_item) + + assert ( + item.tensor_data.chunk is not None + ), f""" + Cannot create MD for tensor without bounds. + FQN: {item.index.fqn} + """ + tensor_md.chunks.append(item.tensor_data.chunk) + new_plans.append(dataclasses.replace(plan, items=new_items)) + return (new_plans, Metadata(md)) + + +def _create_default_local_metadata(state_dict: STATE_DICT_TYPE) -> Metadata: + """Return the ``Metadata`` if DefaultSavePlanner was used to checkpoint ``state_dict``.""" + plan = _create_default_metadata_only_plan(state_dict) + _, md = create_default_global_save_plan([plan]) + return md + + +def _check_box_overlap(box0: ChunkStorageMetadata, box1: ChunkStorageMetadata) -> bool: + """Check if two boxes overlap. Tuples are (offset, lengths).""" + # For each dim of each shard, check if one shard resides on the other + # end of second shard with respect to that dim. As an example for a 2D + # shard, we would check if one shard is above or on the left of the + # other shard. + ndims = len(box0.offsets) + for i in range(ndims): + if box0.offsets[i] >= box1.offsets[i] + box1.sizes[i]: + return False + if box1.offsets[i] >= box0.offsets[i] + box0.sizes[i]: + return False + + return True + + +def _check_box_bounds( + outer_box_size: core.Size, inner_box: ChunkStorageMetadata +) -> bool: + for i in range(len(outer_box_size)): + if inner_box.offsets[i] < 0: + return False + if inner_box.sizes[i] < 0: + return False + if inner_box.offsets[i] + inner_box.sizes[i] > outer_box_size[i]: + return False + + return True + + +def _validate_global_plan(global_plan: List[SavePlan], metadata: Metadata) -> bool: + all_good = True + for key, value in metadata.state_dict_metadata.items(): + if isinstance(value, BytesStorageMetadata): + continue + if len(value.size) == 0: + continue + chunks_volume = 0 + for chunk_idx, chunk0 in enumerate(value.chunks): + # Compute the volume + if not _check_box_bounds(value.size, chunk0): + logger.warning( + """ + key:%s has out of bounds chunk: + tensor-size:%s chunk: %s + """, + key, + value.size, + chunk0, + ) + all_good = False + chunks_volume += reduce(operator.mul, chunk0.sizes, 1) + + # Check for overlap + for chunk1 in value.chunks[chunk_idx + 1 :]: + if _check_box_overlap(chunk0, chunk1): + logger.warning( + "key:%s has overlapping chunks: %s %s", key, chunk0, chunk1 + ) + all_good = False + + # Check whether combined chunk cover the whole tensor + tensor_volume = reduce(operator.mul, value.size, 1) + if chunks_volume != tensor_volume: + logger.warning( + """ + key:%s invalid fill tensor-volume: + %s chunks-volume: %s + """, + key, + tensor_volume, + chunks_volume, + ) + all_good = False + + return all_good diff --git a/mindnlp/core/distributed/checkpoint/examples/async_checkpointing_example.py b/mindnlp/core/distributed/checkpoint/examples/async_checkpointing_example.py new file mode 100644 index 000000000..4c727cf0b --- /dev/null +++ b/mindnlp/core/distributed/checkpoint/examples/async_checkpointing_example.py @@ -0,0 +1,139 @@ +# mypy: allow-untyped-defs +# Owner(s): ["oncall: distributed"] + +import os +import shutil +import traceback + +from mindnlp import core +from mindnlp import core.distributed as dist +from mindnlp import core.distributed.checkpoint as dcp +from mindnlp import core.multiprocessing as mp +from mindnlp import core.nn as nn +from mindnlp import core.nn.functional as F +from core.distributed.checkpoint.state_dict import ( + _patch_model_state_dict, + _patch_optimizer_state_dict, +) +from core.distributed.fsdp import FullyShardedDataParallel as FSDP +from core.distributed.tensor.device_mesh import init_device_mesh + + +DEVICE = "cuda" +NUM_EPOCHS = 1000 +SAVE_PERIOD = 10 +FAULT_PERIOD = 25 +CHECKPOINT_DIR = f"~/{os.environ.get('LOGNAME', '')}/checkpoint" + + +class InjectedException(Exception): + pass + + +class Model(core.nn.Module): + def __init__(self) -> None: + super().__init__() + self.net1 = nn.Linear(8, 32) + self.net2 = nn.Linear(32, 128) + self.net3 = nn.Linear(128, 64) + self.net4 = nn.Linear(64, 8) + self.net5 = nn.Linear(8, 1) + + def forward(self, x): + x = F.relu(self.net1(x)) + x = F.relu(self.net2(x)) + x = F.relu(self.net3(x)) + x = F.relu(self.net4(x)) + x = F.sigmoid(self.net5(x)) + return x + + +def _init_model(rank, world_size): + device_mesh = init_device_mesh(DEVICE, (world_size,)) + + # Create a dummy model and wrap it in FSDP + model = Model().cuda() + device_mesh = init_device_mesh(DEVICE, (world_size,)) + model = FSDP(model, device_mesh=device_mesh, use_orig_params=True) + + optim = core.optim.Adam(model.parameters(), lr=0.0001) + + _patch_model_state_dict(model) + _patch_optimizer_state_dict(model, optimizers=optim) + + return model, optim + + +def _print(msg): + if dist.get_rank() == 0: + print(msg) + + +def _input(): + x = core.rand(128, 8, device="cuda") + y = core.zeros(128, 1, device="cuda") + + y[core.sum(x, dim=1) >= 4] = 1.0 + + return x, y + + +def run(rank, world_size): + # Set up world pg + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + + dist.init_process_group("cpu:gloo,cuda:nccl", rank=rank, world_size=world_size) + core.cuda.set_device(rank) + + model, optim = _init_model(rank, world_size) + state_dict = {"model": model, "optim": optim} + loss_calc = core.nn.BCELoss() + + f = None + for epoch in range(NUM_EPOCHS): + try: + core.manual_seed(epoch) + x, y = _input() + + loss = loss_calc(model(x), y) + + _print(f"{epoch=} {loss=}") + + loss.backward() + optim.step() + optim.zero_grad() + + if epoch % SAVE_PERIOD == 0: + if f is not None: + f.result() + f = dcp.state_dict_saver.async_save( + state_dict, checkpoint_id=CHECKPOINT_DIR + ) + + if FAULT_PERIOD > 0 and epoch % FAULT_PERIOD == 0: + raise InjectedException("Fault injection!") + + except InjectedException as e: + dist.barrier() + + _print("Trainer encountered exception:") + traceback.print_tb(e.__traceback__) + + _print("Reloading model from last checkpoint!") + if f is not None: + f.result() + dcp.load(state_dict) + + +if __name__ == "__main__": + world_size = core.cuda.device_count() + print(f"Running an example of Async Checkpointing on {world_size} devices.") + shutil.rmtree(CHECKPOINT_DIR, ignore_errors=True) + + mp.spawn( + run, + args=(world_size,), + nprocs=world_size, + join=True, + ) diff --git a/mindnlp/core/distributed/checkpoint/examples/fsdp_checkpoint_example.py b/mindnlp/core/distributed/checkpoint/examples/fsdp_checkpoint_example.py new file mode 100644 index 000000000..2971114be --- /dev/null +++ b/mindnlp/core/distributed/checkpoint/examples/fsdp_checkpoint_example.py @@ -0,0 +1,131 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates + +""" +The following example demonstrates how to use Pytorch Distributed Checkpoint to save a FSDP model. + +This is the current recommended way to checkpoint FSDP. +core.save() and core.load() is not recommended when checkpointing sharded models. +""" + +import os +import shutil + +from mindnlp import core +from mindnlp import core.distributed as dist +from mindnlp import core.distributed.checkpoint as dist_cp +from mindnlp import core.multiprocessing as mp +from core.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict +from core.distributed.fsdp import FullyShardedDataParallel as FSDP +from core.distributed.fsdp.fully_sharded_data_parallel import StateDictType + + +CHECKPOINT_DIR = f"/scratch/{os.environ['LOGNAME']}/checkpoint" + + +def opt_at(opt, idx): + return list(opt.state.values())[idx] + + +def init_model(): + model = FSDP(core.nn.Linear(4, 4).cuda(dist.get_rank())) + optim = core.optim.Adam(model.parameters(), lr=0.1) + model(core.rand(4, 4)).sum().backward() + optim.step() + + return model, optim + + +def print_params(stage, model_1, model_2, optim_1, optim_2): + with FSDP.summon_full_params(model_1): + with FSDP.summon_full_params(model_2): + print( + f"{stage} --- rank: {dist.get_rank()}\n" + f"model.weight: {model_1.weight}\n" + f"model_2.weight:{model_2.weight}\n" + f"model.bias: {model_1.bias}\n" + f"model_2.bias: {model_2.bias}\n" + ) + + print( + f"{stage} --- rank: {dist.get_rank()}\n" + f"optim exp_avg:{opt_at(optim_1, 0)['exp_avg']}\n" + f"optim_2 exp_avg:{opt_at(optim_2, 0)['exp_avg']}\n" + f"optim exp_avg_sq:{opt_at(optim_1, 0)['exp_avg_sq']}\n" + f"optim_2 exp_avg_sq:{opt_at(optim_2, 0)['exp_avg_sq']}\n" + ) + + +def run_fsdp_checkpoint_example(rank, world_size): + # Set up world pg + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + + # Initialize the process group + dist.init_process_group("cpu:gloo,cuda:nccl", rank=rank, world_size=world_size) + core.cuda.set_device(rank) + + # Create a model + model_1, optim_1 = init_model() + + # Save the model to CHECKPOINT_DIR + with FSDP.state_dict_type(model_1, StateDictType.SHARDED_STATE_DICT): + state_dict = { + "model": model_1.state_dict(), + "optim": FSDP.optim_state_dict(model_1, optim_1), + } + + dist_cp.save_state_dict( + state_dict=state_dict, + storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR), + ) + + # Create a second model + model_2, optim_2 = init_model() + + # Print the model parameters for both models. + # Before loading, the parameters should be different. + print_params("Before loading", model_1, model_2, optim_1, optim_2) + + # Load model_2 with parameters saved in CHECKPOINT_DIR + with FSDP.state_dict_type(model_2, StateDictType.SHARDED_STATE_DICT): + state_dict = { + "model": model_2.state_dict(), + # cannot load the optimizer state_dict together with the model state_dict + } + + dist_cp.load_state_dict( + state_dict=state_dict, + storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR), + ) + model_2.load_state_dict(state_dict["model"]) + + optim_state = load_sharded_optimizer_state_dict( + model_state_dict=state_dict["model"], + optimizer_key="optim", + storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR), + ) + + flattened_osd = FSDP.optim_state_dict_to_load( + model_2, optim_2, optim_state["optim"] + ) + optim_2.load_state_dict(flattened_osd) + + # Print the model parameters for both models. + # After loading, the parameters should be the same. + print_params("After loading", model_1, model_2, optim_1, optim_2) + + # Shut down world pg + dist.destroy_process_group() + + +if __name__ == "__main__": + world_size = core.cuda.device_count() + print(f"Running fsdp checkpoint example on {world_size} devices.") + shutil.rmtree(CHECKPOINT_DIR, ignore_errors=True) + mp.spawn( + run_fsdp_checkpoint_example, + args=(world_size,), + nprocs=world_size, + join=True, + ) diff --git a/mindnlp/core/distributed/checkpoint/examples/stateful_example.py b/mindnlp/core/distributed/checkpoint/examples/stateful_example.py new file mode 100644 index 000000000..95297507c --- /dev/null +++ b/mindnlp/core/distributed/checkpoint/examples/stateful_example.py @@ -0,0 +1,107 @@ +# mypy: allow-untyped-defs +# Owner(s): ["oncall: distributed"] + +# pyre-unsafe + + +import os +import shutil + +from mindnlp import core +from mindnlp import core.distributed as dist +from mindnlp import core.distributed.checkpoint as dcp +from mindnlp import core.multiprocessing as mp +from mindnlp import core.nn as nn +from core.distributed.checkpoint.state_dict import ( + _patch_model_state_dict, + _patch_optimizer_state_dict, +) +from core.distributed.device_mesh import init_device_mesh +from core.distributed.fsdp import FullyShardedDataParallel as FSDP + + +CHECKPOINT_DIR = f"~/{os.environ['LOGNAME']}/checkpoint" + + +class Model(core.nn.Module): + def __init__(self) -> None: + super().__init__() + core.manual_seed(0) + self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU()) + self.net2 = nn.Sequential(nn.Linear(16, 32), nn.ReLU()) + self.net3 = nn.Linear(32, 64) + self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(64, 8)) + + def forward(self, x): + return self.net4(self.net3(self.net2(self.net1(x)))) + + def get_input(self): + return core.rand(8, 8, device="cuda") + + +def _make_stateful(model, optim): + _patch_model_state_dict(model) + _patch_optimizer_state_dict(model, optimizers=optim) + + +def _train(model, optim, train_steps=1): + core.manual_seed(0) + loss = None + for _ in range(train_steps): + loss = model(model.get_input()).sum() + loss.backward() + optim.step() + optim.zero_grad() + + return loss + + +def _init_model(device, world_size): + device_mesh = init_device_mesh(device, (world_size,)) + model = Model().cuda() + model = FSDP( + model, + device_mesh=device_mesh, + use_orig_params=True, + ) + optim = core.optim.Adam(model.parameters(), lr=0.1) + _make_stateful(model, optim) + + return model, optim + + +def run(rank, world_size, device="cuda"): + # Set up world pg + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + + dist.init_process_group("cpu:gloo,cuda:nccl", rank=rank, world_size=world_size) + core.cuda.set_device(rank) + + model, optim = _init_model(device, world_size) + _train(model, optim, train_steps=2) + + dcp.save( + state_dict={"model": model, "optimizer": optim}, + checkpoint_id=CHECKPOINT_DIR, + ) + + # presumably do something else + model, optim = _init_model(device, world_size) + dcp.load( + state_dict={"model": model, "optimizer": optim}, + checkpoint_id=CHECKPOINT_DIR, + ) + _train(model, optim, train_steps=2) + + +if __name__ == "__main__": + world_size = core.cuda.device_count() + print(f"Running stateful checkpoint example on {world_size} devices.") + shutil.rmtree(CHECKPOINT_DIR, ignore_errors=True) + mp.spawn( + run, + args=(world_size,), + nprocs=world_size, + join=True, + ) diff --git a/mindnlp/core/distributed/checkpoint/filesystem.py b/mindnlp/core/distributed/checkpoint/filesystem.py new file mode 100644 index 000000000..2bc6a5ece --- /dev/null +++ b/mindnlp/core/distributed/checkpoint/filesystem.py @@ -0,0 +1,768 @@ +# mypy: allow-untyped-defs +import collections +import dataclasses +import io +import operator +import os +import pickle +import queue +import threading +import uuid +import warnings +from abc import ABC, abstractmethod +from contextlib import contextmanager +from dataclasses import dataclass +from pathlib import Path +from typing import ( + Any, + Callable, + cast, + Dict, + Generator, + IO, + Iterable, + Iterator, + List, + Optional, + Tuple, + Union, +) + +from mindnlp import core +from mindnlp.core import Tensor +# from core._utils import _get_available_device_type, _get_device_module +from core.distributed._shard._utils import narrow_tensor_by_index +from core.distributed.checkpoint.metadata import ( + Metadata, + MetadataIndex, + STATE_DICT_TYPE, + StorageMeta, +) +from core.distributed.checkpoint.planner import ( + LoadItemType, + LoadPlan, + LoadPlanner, + ReadItem, + SavePlan, + SavePlanner, + WriteItem, + WriteItemType, +) +from core.distributed.checkpoint.staging import BlockingAsyncStager +from core.distributed.checkpoint.storage import ( + StorageReader, + StorageWriter, + WriteResult, +) +from core.distributed.checkpoint.utils import _create_file_view + + +__all__ = ["FileSystemWriter", "FileSystemReader", "FileSystem", "FileSystemBase"] + +_metadata_fn: str = ".metadata" + + +@dataclass +class _StorageInfo: + """This is the per entry storage info.""" + + relative_path: str + offset: int + length: int + + +@dataclass +class _StoragePrefix: + prefix: str + + +DEFAULT_SUFFIX = ".distcp" + + +def _generate_uuid() -> str: + return str(uuid.uuid4()) + + +class _TensorLoader(ABC): + @abstractmethod + def add(self, size: int, obj: object) -> None: + pass + + @abstractmethod + def start_loading(self) -> None: + pass + + @abstractmethod + def values(self) -> Iterator[Tuple[core.Tensor, object]]: + pass + + +class _SerialCpuLoader(_TensorLoader): + def __init__(self, resolve_fun: Callable) -> None: + self.resolve_fun = resolve_fun + self.items: List[Tuple[int, object]] = [] + + def add(self, size: int, obj: object) -> None: + self.items.append((size, obj)) + + def start_loading(self) -> None: + pass + + def values(self) -> Iterator[Tuple[core.Tensor, object]]: + for _, obj in self.items: + tensor = self.resolve_fun(obj).detach() + tensor = tensor.cpu() + if tensor.storage().size() != tensor.numel(): + tensor = tensor.clone() + yield ( + tensor, + obj, + ) + + +class _OverlappingCpuLoader(_TensorLoader): + def __init__( + self, + resolve_fun: Callable, + stream: Optional[core.Stream] = None, + inflight_threshhold: int = 1_000_000, + ) -> None: + self.resolve_fun = resolve_fun + self.items: List[Tuple[int, object]] = [] + self.inflight_threshhold = inflight_threshhold + self.in_flight_data = 0 + self.current_items: collections.deque = collections.deque() + self.idx = 0 + self.started = False + self.device_type = ( + stream.device_type if stream else _get_available_device_type() + ) + self.device_module = _get_device_module(self.device_type) + self.stream = cast( + core.cuda.Stream, stream or self.device_module.current_stream() + ) + if self.stream != self.device_module.current_stream(): + self.stream.wait_stream(self.device_module.current_stream()) + + @property + def _done(self) -> bool: + return self.idx >= len(self.items) + + def _drain(self) -> List[Tuple[core.Tensor, object]]: + drained = [] + if self.in_flight_data >= self.inflight_threshhold: + self.stream.synchronize() + while self.in_flight_data >= self.inflight_threshhold: + val = self.current_items.popleft() + self.in_flight_data -= val[0].numel() * val[0].element_size() + drained.append(val) + return drained + + def _refill(self) -> None: + with self.device_module.stream(self.stream): + while not self._done and self.in_flight_data < self.inflight_threshhold: + _, obj = self.items[self.idx] + self.idx += 1 + tensor = self.resolve_fun(obj).detach() + if tensor.device.type == self.device_type: + tensor = tensor.to(device="cpu", non_blocking=True) + elif tensor.device == core.device("cpu"): + if ( + tensor.untyped_storage().size() + != tensor.numel() * tensor.itemsize + ): + # this forces the tensor to be both contiguous and with minimal storage + tensor = tensor.clone() + + self.current_items.append( + ( + tensor, + obj, + ) + ) + self.in_flight_data += tensor.numel() * tensor.element_size() + + def _finish(self) -> Iterable[Tuple[core.Tensor, object]]: + assert self._done + if len(self.current_items) > 0: + self.stream.synchronize() + return self.current_items + + def add(self, size: int, obj: object) -> None: + if self.started: + raise RuntimeError("cannot add items after loading started") + self.items.append((size, obj)) + + def start_loading(self) -> None: + if self.started: + return + self.started = True + self.items.sort(key=operator.itemgetter(0)) + self._refill() + + def values(self) -> Iterator[Tuple[core.Tensor, object]]: + self.start_loading() + while not self._done: + drained = self._drain() + self._refill() + yield from drained + + yield from self._finish() + + +def _item_size(item: WriteItem) -> int: + size = 1 + assert item.tensor_data is not None + # can't use math.prod as PT needs to support older python + for s in item.tensor_data.size: + size *= s + + dtype = item.tensor_data.properties.dtype + return size * core._utils._element_size(dtype) + + +def _split_by_size_and_type(bins: int, items: List[WriteItem]) -> List[List[WriteItem]]: + if bins == 1: + return [items] + + bytes_w = [wi for wi in items if wi.type == WriteItemType.BYTE_IO] + tensor_w = [wi for wi in items if wi.type != WriteItemType.BYTE_IO] + + buckets: List[List[WriteItem]] = [[] for _ in range(bins)] + bucket_sizes = [0 for _ in range(bins)] + + tensor_w.sort(key=_item_size, reverse=True) + + for i, wi in enumerate(bytes_w): + buckets[i % bins].append(wi) + + for wi in tensor_w: + # TODO replace with headq + idx = min(enumerate(bucket_sizes), key=operator.itemgetter(1))[0] + buckets[idx].append(wi) + bucket_sizes[idx] += _item_size(wi) + + return buckets + + +def _write_item( + stream: io.IOBase, + data: Union[io.BytesIO, core.Tensor], + write_item: WriteItem, + storage_key: str, +) -> WriteResult: + offset = stream.tell() + + if write_item.type == WriteItemType.BYTE_IO: + assert isinstance(data, io.BytesIO) + stream.write(data.getbuffer()) + else: + assert isinstance(data, core.Tensor) + assert data.device == core.device("cpu") + core.save(data, cast(IO[bytes], stream)) + length = stream.tell() - offset + return WriteResult( + index=write_item.index, + size_in_bytes=length, + storage_data=_StorageInfo(storage_key, offset, length), + ) + + +def _write_files_from_queue( + create_stream: Callable, + file_queue: queue.Queue, + result_queue: queue.Queue, + planner: SavePlanner, + inflight_threshhold: int, + use_fsync: bool, + thread_count: int, +) -> None: + try: + while True: + file_name, storage_key, write_items = file_queue.get_nowait() + loader: _TensorLoader + + custom_backend_name = core._C._get_privateuse1_backend_name() + custom_device_mod = getattr(torch, custom_backend_name, None) + + # TODO: Using the OverlappingCpuLoader with multiple threads creates significant + # performance degredation, observed as being related to cuda stream syncs. We + # should try to fix this and use _OverlappingCpuLoader for all threaded cases + if ( + thread_count == 1 + and ( + core.cuda.is_available() + or (custom_device_mod and custom_device_mod.is_available()) + ) + and inflight_threshhold > 0 + ): + loader = _OverlappingCpuLoader( + planner.resolve_data, + inflight_threshhold=inflight_threshhold, + ) + else: + loader = _SerialCpuLoader( + planner.resolve_data, + ) + + tensor_w = [wi for wi in write_items if wi.type != WriteItemType.BYTE_IO] + for write_item in tensor_w: + loader.add(_item_size(write_item), write_item) + loader.start_loading() + + bytes_w = [wi for wi in write_items if wi.type == WriteItemType.BYTE_IO] + write_results = [] + + with create_stream(file_name, "wb") as stream: + for write_item in bytes_w: + data = planner.resolve_data(write_item) + write_results.append( + _write_item(stream, data, write_item, storage_key) + ) + + for tensor, write_item in loader.values(): + # assert tensor.is_cpu + write_results.append( + _write_item(stream, tensor, write_item, storage_key) + ) + + if use_fsync: + try: + os.fsync(stream.fileno()) + except AttributeError: + os.sync() + result_queue.put(write_results) + except queue.Empty: + pass + + +class FileSystemBase(ABC): + @contextmanager + @abstractmethod + def create_stream( + self, path: Union[str, os.PathLike], mode: str + ) -> Generator[io.IOBase, None, None]: + ... + + @abstractmethod + def concat_path( + self, path: Union[str, os.PathLike], suffix: str + ) -> Union[str, os.PathLike]: + ... + + @abstractmethod + def rename( + self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike] + ) -> None: + ... + + @abstractmethod + def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]: + ... + + @abstractmethod + def mkdir(self, path: Union[str, os.PathLike]) -> None: + ... + + @classmethod + @abstractmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + ... + + @abstractmethod + def exists(self, path: Union[str, os.PathLike]) -> bool: + ... + + @abstractmethod + def rm_file(self, path: Union[str, os.PathLike]) -> None: + ... + + +class FileSystem(FileSystemBase): + @contextmanager + def create_stream( + self, path: Union[str, os.PathLike], mode: str + ) -> Generator[io.IOBase, None, None]: + with cast(Path, path).open(mode) as stream: + yield cast(io.IOBase, stream) + + def concat_path( + self, path: Union[str, os.PathLike], suffix: str + ) -> Union[str, os.PathLike]: + return cast(Path, path) / suffix + + def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]: + if not isinstance(path, Path): + path = Path(path) + return path + + def rename( + self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike] + ) -> None: + cast(Path, path).rename(cast(Path, new_path)) + + def mkdir(self, path: Union[str, os.PathLike]) -> None: + cast(Path, path).mkdir(parents=True, exist_ok=True) + + @classmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + if isinstance(checkpoint_id, Path): + return True + + if "://" in str(checkpoint_id): + return False + + for p in Path(checkpoint_id).parents: + if p.exists() and os.access(str(p), os.W_OK): + return True + + return False + + def exists(self, path: Union[str, os.PathLike]) -> bool: + return cast(Path, path).exists() + + def rm_file(self, path: Union[str, os.PathLike]) -> None: + cast(Path, path).unlink() + + +class _FileSystemWriter(StorageWriter): + """ + Basic implementation of StorageWriter using file IO. + + This implementation makes the following assumptions and simplifications: + + * The checkpoint path is an empty or non-existing directory. + * File creation is atomic + + The checkpoint consist of one file per write request plus + a `.metadata` file with the serialized metadata. + + """ + + def __init__( + self, + path: Union[str, os.PathLike], + single_file_per_rank: bool = True, + sync_files: bool = True, + thread_count: int = 1, + per_thread_copy_ahead: int = 10_000_000, + overwrite: bool = True, + *args: Any, + **kwargs: Any, + ) -> None: + """ + Initialize the writer pointing to `path`. + + Args: + path: directory where the checkpoint will be written to. + single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Default to True. + sync_files : force files to be synced to permanent storage. Default to True. + thread_count: Number of IO threads to use to write. Default to 1. + per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving then. Default 10Mb. + overwrite: Whether to allow overwriting existing checkpoints. Defaults to True. + + N. B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure. + """ + super().__init__() + self.fs = FileSystem() + self.path = self.fs.init_path(path) + self.single_file_per_rank = single_file_per_rank + self.sync_files = sync_files + self.thread_count = thread_count + self.per_thread_copy_ahead = per_thread_copy_ahead + self.save_id = _generate_uuid() + self.overwrite = overwrite + + def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None: + if checkpoint_id: + self.path = self.fs.init_path(checkpoint_id) + self.save_id = _generate_uuid() + + def set_up_storage_writer(self, is_coordinator: bool) -> None: + pass + + def prepare_local_plan(self, plan: SavePlan) -> SavePlan: + self.fs.mkdir(self.path) + if self.fs.exists(self.metadata_path): + if self.overwrite: + warnings.warn( + f"Detected an existing checkpoint in {self.metadata_path}, overwriting since {self.overwrite=}." + " Past version 2.5 of PyTorch, `overwrite` will default to False. Set this variable to True to" + " maintain this functionality or False to raise when an existing checkpoint is found." + ) + else: + raise RuntimeError(f"Checkpoint already exists and {self.overwrite=}.") + + return plan + + def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]: + new_plans = [ + dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{i}_")) + for i, plan in enumerate(plans) + ] + return new_plans + + def write_data( + self, + plan: SavePlan, + planner: SavePlanner, + ): + storage_plan: _StoragePrefix = plan.storage_data + file_count = 0 + + def gen_file(): + nonlocal file_count + file_name = f"{storage_plan.prefix}{file_count}{DEFAULT_SUFFIX}" + file_count += 1 + return file_name + + file_queue: queue.Queue = queue.Queue() + if self.single_file_per_rank: + for bucket in _split_by_size_and_type(self.thread_count, plan.items): + file_name = gen_file() + path = self.fs.concat_path(self.path, file_name) + file_queue.put((path, file_name, bucket)) + else: + for item in plan.items: + file_name = gen_file() + path = self.fs.concat_path(self.path, file_name) + file_queue.put((path, file_name, [item])) + + result_queue: queue.Queue = queue.Queue() + + threads = [] + for _ in range(1, self.thread_count): + t = threading.Thread( + target=_write_files_from_queue, + args=( + self.fs.create_stream, + file_queue, + result_queue, + planner, + self.per_thread_copy_ahead, + self.sync_files, + self.thread_count, + ), + ) + t.start() + threads.append(t) + + _write_files_from_queue( + create_stream=self.fs.create_stream, + file_queue=file_queue, + result_queue=result_queue, + planner=planner, + inflight_threshhold=self.per_thread_copy_ahead, + use_fsync=self.sync_files, + thread_count=self.thread_count, + ) + + for t in threads: + t.join() + + res = [] + try: + while True: + res += result_queue.get_nowait() + except queue.Empty: + fut: Future[List[WriteResult]] = Future() + fut.set_result(res) + return fut + + def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None: + storage_md = {} + for wr_list in results: + storage_md.update({wr.index: wr.storage_data for wr in wr_list}) + metadata.storage_data = storage_md + + metadata.storage_meta = self.storage_meta() + + tmp_path = cast(Path, self.fs.concat_path(self.path, f"{_metadata_fn}.tmp")) + with self.fs.create_stream(tmp_path, "wb") as metadata_file: + pickle.dump(metadata, metadata_file) + if self.sync_files: + try: + os.fsync(metadata_file.fileno()) + except AttributeError: + os.sync() + + # delete in-case other checkpoints were present. + if self.fs.exists(self.metadata_path): + self.fs.rm_file(self.metadata_path) + + self.fs.rename(tmp_path, self.metadata_path) + + def storage_meta(self) -> Optional[StorageMeta]: + return StorageMeta(checkpoint_id=self.checkpoint_id, save_id=self.save_id) + + @property + def metadata_path(self) -> Union[str, os.PathLike]: + return cast(Path, self.fs.concat_path(self.path, _metadata_fn)) + + @property + def checkpoint_id(self) -> Union[str, os.PathLike]: + """ + return the checkpoint_id that will be used to save the checkpoint. + """ + return self.path + + @classmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + return FileSystem.validate_checkpoint_id(checkpoint_id) + + +class FileSystemReader(StorageReader): + def __init__(self, path: Union[str, os.PathLike]) -> None: + super().__init__() + self.fs = FileSystem() + self.path = self.fs.init_path(path) + self.storage_data: Dict[MetadataIndex, _StorageInfo] = {} + self.load_id = _generate_uuid() + + def _slice_file(self, file, sinfo: _StorageInfo) -> io.IOBase: + return _create_file_view(file, sinfo.offset, sinfo.length) + + def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None: + self.storage_data = {} + if checkpoint_id: + self.path = self.fs.init_path(checkpoint_id) + self.load_id = _generate_uuid() + + def read_data(self, plan: LoadPlan, planner: LoadPlanner): + # group requests by file + per_file: Dict[str, List[ReadItem]] = {} + for read_item in plan.items: + item_md = self.storage_data[read_item.storage_index] + path = item_md.relative_path + per_file.setdefault(path, []).append(read_item) + + for relative_path, reqs in per_file.items(): + new_path = self.fs.concat_path(self.path, relative_path) + with self.fs.create_stream(new_path, "rb") as stream: + # TODO sort by offset and cache the reading + for req in reqs: + item_md = self.storage_data[req.storage_index] + file_slice = self._slice_file(stream, item_md) + if req.type == LoadItemType.BYTE_IO: + read_bytes = io.BytesIO(file_slice.read(item_md.length)) + read_bytes.seek(0) + planner.load_bytes(req, read_bytes) + else: + tensor = cast( + Tensor, + core.load( + cast(IO[bytes], file_slice), + map_location="cpu", + weights_only=True, + ), + ) + tensor = narrow_tensor_by_index( + tensor, req.storage_offsets, req.lengths + ) + target_tensor = planner.resolve_tensor(req).detach() + + assert ( + target_tensor.size() == tensor.size() + ), f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}" + target_tensor.copy_(tensor) + planner.commit_tensor(req, target_tensor) + + # fut: Future = Future() + # fut.set_result(None) + # return fut + return None + + # Implementing the abstract function in StorageReader + def read_metadata(self) -> Metadata: + path = self.fs.concat_path(self.path, ".metadata") + with self.fs.create_stream(path, "rb") as metadata_file: + metadata = pickle.load(metadata_file) + + if getattr(metadata, "storage_meta", None) is None: + metadata.storage_meta = StorageMeta() + metadata.storage_meta.load_id = self.load_id + + return metadata + + def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None: + self.storage_data = metadata.storage_data + assert self.storage_data is not None + + def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan: + return plan + + def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]: + return plans + + @property + def checkpoint_id(self) -> Union[str, os.PathLike]: + """ + return the checkpoint_id that will be used to load the checkpoint. + """ + return self.path + + @classmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + return FileSystem.validate_checkpoint_id(checkpoint_id) + + +class FileSystemWriter(_FileSystemWriter, BlockingAsyncStager): + """ + Basic implementation of StorageWriter using file IO. + + This implementation makes the following assumptions and simplifications: + + * The checkpoint path is an empty or non-existing directory. + * File creation is atomic + + The checkpoint consist of one file per write request plus + a `.metadata` file with the serialized metadata. + + """ + + def __init__( + self, + path: Union[str, os.PathLike], + single_file_per_rank: bool = True, + sync_files: bool = True, + thread_count: int = 1, + per_thread_copy_ahead: int = 10_000_000, + cache_staged_state_dict: bool = False, + overwrite: bool = True, + ) -> None: + """ + Initialize the writer pointing to `path`. + + Args: + path: directory where the checkpoint will be written to. + single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Default to True. + sync_files : force files to be synced to permanent storage. Default to True. + thread_count: Number of IO threads to use to write. Default to 1. + per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving then. Default 10Mb. + cache_staged_state_dict: Whether to cache the staged state_dict. This option decreases staging latency + at the cost of increases memory usage. Additionally, if this parameter is set to True, it's the expectation + that the stager is maintained and re-used for multiple dcp.async_save calls. Default to False. + overwrite: Whether to allow overwriting existing checkpoints. Defaults to True. + + N. B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure. + """ + _FileSystemWriter.__init__( + self, + path=path, + single_file_per_rank=single_file_per_rank, + sync_files=sync_files, + thread_count=thread_count, + per_thread_copy_ahead=per_thread_copy_ahead, + overwrite=overwrite, + ) + BlockingAsyncStager.__init__( + self, + cache_staged_state_dict=cache_staged_state_dict, + ) + + def stage(self, state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE: + """Override of AsyncStager.stage""" + # in the async case, the state dict is already on CPU, so maintaining this + # buffer makes no sense + self.per_thread_copy_ahead = 0 + return super().stage(state_dict) diff --git a/mindnlp/core/distributed/checkpoint/format_utils.py b/mindnlp/core/distributed/checkpoint/format_utils.py new file mode 100644 index 000000000..bada79297 --- /dev/null +++ b/mindnlp/core/distributed/checkpoint/format_utils.py @@ -0,0 +1,280 @@ +# mypy: allow-untyped-defs +import argparse +import os +from enum import Enum +from typing import cast, Dict, List, Optional, Union + +from mindnlp import core +from mindnlp import core.distributed as dist +from core.distributed._shard._utils import narrow_tensor_by_index +from core.distributed.checkpoint import FileSystemReader, FileSystemWriter +from core.distributed.checkpoint._nested_dict import flatten_state_dict +from core.distributed.checkpoint.default_planner import ( + _EmptyStateDictLoadPlanner, + DefaultLoadPlanner, +) +from core.distributed.checkpoint.metadata import ( + Metadata, + STATE_DICT_TYPE, + STORAGE_TYPES, + TensorProperties, + TensorStorageMetadata, +) +from core.distributed.checkpoint.planner import LoadItemType, LoadPlan, LoadPlanner +from core.distributed.checkpoint.planner_helpers import _create_chunk_list +from core.distributed.checkpoint.state_dict_loader import _load_state_dict +from core.distributed.checkpoint.state_dict_saver import _save_state_dict +from core.distributed.checkpoint.storage import StorageReader +from core.futures import Future + + +__all__ = [ + "dcp_to_torch_save", + "torch_save_to_dcp", + "BroadcastingTorchSaveReader", + "DynamicMetaLoadPlanner", +] + + +class BroadcastingTorchSaveReader(StorageReader): + """ + StorageReader for reading a Torch Save file. This reader will read the entire checkpoint + on the coordinator rank, and then broadcast and shard each tensor to all ranks. + + . N.B. Intended to be used with DynamicMetaLoadPlanner + + .. warning:: + Current implementation only supports loading Tensors. + + >>> # xdoctest: +SKIP("undefined vars") + >>> sd = {"mode": model} + >>> dcp.load( + >>> sd, + >>> storage_reader=BroadcastingTorchSaveReader(), + >>> planner=DynamicMetaLoadPlanner(), + >>> checkpoint_id="path_to_model.pt" + >>> ) + """ + + def __init__( + self, + checkpoint_id: Optional[Union[str, os.PathLike]] = None, + coordinator_rank: int = 0, + ) -> None: + self.checkpoint_id = checkpoint_id + self.coordinator_rank = coordinator_rank + + def read_metadata(self) -> Metadata: + """Extends the default StorageReader to support building the metadata file""" + # Metadata is built in planner.set_up_planner, since we are not actually reading metadata from + # the disk + return Metadata(state_dict_metadata={}) + + def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: + """ + Reads torch save data on the coordinator rank, and broadcast afterwards + this incurrs a communication cost, but avoids having to load + the entire checkpoint on each rank, hopefully preventing OOM issues + """ + planner = cast(DefaultLoadPlanner, planner) + + # data is read in on the coordinator rank, and broadcast afterwards + # this incurrs a communication cost, but it avoids having to load + # the entire checkpoint on each rank, hopefully preventing OOM issues + # TODO: read on each host, instead of only the coordinator + if self.is_coordinator: + assert self.checkpoint_id is not None + torch_state_dict = core.load( + self.checkpoint_id, map_location="cpu", weights_only=False + ) + if planner.flatten_state_dict: + torch_state_dict, _ = flatten_state_dict(torch_state_dict) + else: + torch_state_dict = None + + for req in plan.items: + if req.type == LoadItemType.BYTE_IO: + raise RuntimeError( + f"Non-tensor value identified at {req.storage_index.fqn}. " + f"At this time {type(self).__name__} only supports loading Tensors." + ) + + # Broadcast the tensor from the coordinator rank + if self.is_coordinator: + pg_device = dist.distributed_c10d._get_pg_default_device() + tensor = torch_state_dict[req.storage_index.fqn].to(pg_device) + else: + tensor = core.empty_like(planner.state_dict[req.storage_index.fqn]) + + dist.broadcast(tensor, src=self.coordinator_rank, async_op=False) + + tensor = narrow_tensor_by_index(tensor, req.storage_offsets, req.lengths) + target_tensor = planner.resolve_tensor(req).detach() + assert target_tensor.size() == tensor.size(), ( + f"req {req.storage_index} mismatch sizes, " + f"{target_tensor.size()} vs {tensor.size()}" + ) + target_tensor.copy_(tensor) + planner.commit_tensor(req, target_tensor) + + fut: Future = Future() + fut.set_result(None) + return fut + + def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None: + """Implementation of the StorageReader method""" + self.is_coordinator = is_coordinator + if self.is_coordinator: + assert dist.get_rank() == self.coordinator_rank + + assert self.checkpoint_id is not None + + def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan: + """Implementation of the StorageReader method""" + return plan + + def prepare_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]: + """Implementation of the StorageReader method""" + return global_plan + + def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None: + """Implementation of the StorageReader method""" + self.checkpoint_id = checkpoint_id + + @classmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + """Implementation of the StorageReader method""" + return os.path.isfile(checkpoint_id) + + +class DynamicMetaLoadPlanner(DefaultLoadPlanner): + """ + Extension of DefaultLoadPlanner, which creates a new Metadata object based on the passed in state dict, + avoiding the need to read metadata from disk. This is useful when reading formats which don't have a + metadata file, like Torch Save files. + + . N.B. Intended to be used with BroadcastingTorchSaveReader + + .. warning:: + Current implementation only supports loading Tensors. + + >>> # xdoctest: +SKIP("undefined vars") + >>> sd = {"mode": model} + >>> dcp.load( + >>> sd, + >>> storage_reader=BroadcastingTorchSaveReader(), + >>> planner=DynamicMetaLoadPlanner(), + >>> checkpoint_id="path_to_model.pt" + >>> ) + """ + + def set_up_planner( + self, + state_dict: STATE_DICT_TYPE, + metadata: Optional[Metadata] = None, + is_coordinator: bool = False, + ) -> None: + """Setups of the planner, extnding default behavior by creating the Metadata object from the state dict""" + super().set_up_planner(state_dict, metadata, is_coordinator) + + state_dict_metadata: Dict[str, STORAGE_TYPES] = {} + for key, tensor in self.state_dict.items(): + if not core.is_tensor(tensor): + raise RuntimeError( + f"Non-tensor value identified at {key}. " + f"At this time {type(self).__name__} only supports loading Tensors." + ) + + state_dict_metadata[key] = TensorStorageMetadata( + TensorProperties(dtype=tensor.dtype), + tensor.size(), + _create_chunk_list(tensor), + ) + self.metadata = Metadata(state_dict_metadata=state_dict_metadata) + + +def dcp_to_torch_save( + dcp_checkpoint_dir: Union[str, os.PathLike], + torch_save_path: Union[str, os.PathLike], +): + """ + Given a directory containing a DCP checkpoint, this function will convert it into a + Torch save file. + + Args: + dcp_checkpoint_dir: Directory containing the DCP checkpoint. + torch_save_path: Filename to store the converted Torch save file. + + .. warning:: + To avoid OOM, it's recommended to only run this function on a single rank. + """ + sd: STATE_DICT_TYPE = {} + _load_state_dict( + sd, + storage_reader=FileSystemReader(dcp_checkpoint_dir), + planner=_EmptyStateDictLoadPlanner(), + no_dist=True, + ) + core.save(sd, torch_save_path) + + +def torch_save_to_dcp( + torch_save_path: Union[str, os.PathLike], + dcp_checkpoint_dir: Union[str, os.PathLike], +): + """ + Given the location of a torch save file, converts it into a DCP checkpoint. + + Args: + torch_save_path: Filename of the Torch save file. + dcp_checkpoint_dir: Directory to store the DCP checkpoint. + + .. warning:: + To avoid OOM, it's recommended to only run this function on a single rank. + """ + + state_dict = core.load(torch_save_path, weights_only=False) + # we don't need stateful behavior here because the expectation is anything loaded by + # core.load would not contain stateful objects. + _save_state_dict( + state_dict, storage_writer=FileSystemWriter(dcp_checkpoint_dir), no_dist=True + ) + + +if __name__ == "__main__": + + class FormatMode(Enum): + TORCH_TO_DCP = "torch_to_dcp" + DCP_TO_TORCH = "dcp_to_torch" + + # Parse command-line arguments + parser = argparse.ArgumentParser() + parser.add_argument( + "mode", + type=str, + help="Conversion mode", + choices=[m.value for m in FormatMode], + default=FormatMode.TORCH_TO_DCP, + ) + parser.add_argument("src", type=str, help="Path to the source model") + parser.add_argument("dst", type=str, help="Path to the destination model") + args = parser.parse_args() + + print( + f"Converting checkpoint from {args.src} to {args.dst} using method: '{args.mode}'" + ) + checkpoint_missing_warning = ( + f"No checkpoint found at {args.src}. Skipping conversion." + ) + if args.mode == FormatMode.TORCH_TO_DCP.value: + if os.path.isfile(args.src): + torch_save_to_dcp(args.src, args.dst) + else: + print(checkpoint_missing_warning) + elif args.mode == FormatMode.DCP_TO_TORCH.value: + if os.path.isdir(args.src): + dcp_to_torch_save(args.src, args.dst) + else: + print(checkpoint_missing_warning) + else: + raise ValueError(f"Unknown conversion mode: {args.mode}") diff --git a/mindnlp/core/distributed/checkpoint/logger.py b/mindnlp/core/distributed/checkpoint/logger.py new file mode 100644 index 000000000..0964b43fc --- /dev/null +++ b/mindnlp/core/distributed/checkpoint/logger.py @@ -0,0 +1,103 @@ +# mypy: allow-untyped-defs +import functools +import time +from typing import Any, Callable, Dict, List, TypeVar +from typing_extensions import ParamSpec +from uuid import uuid4 + +from mindnlp import core.distributed.c10d_logger as c10d_logger +from core.distributed.checkpoint.logging_handlers import DCP_LOGGER_NAME + + +__all__: List[str] = [] + +global _dcp_logger +_dcp_logger = c10d_logger._get_or_create_logger(DCP_LOGGER_NAME) + +_T = TypeVar("_T") +_P = ParamSpec("_P") + + +def _msg_dict_from_dcp_method_args(*args, **kwargs) -> Dict[str, Any]: + """ + Extracts log data from dcp method args + """ + msg_dict = {} + + # checkpoint ID can be passed in through the serializer or through the checkpoint id directly + storage_writer = kwargs.get("storage_writer", None) + storage_reader = kwargs.get("storage_reader", None) + planner = kwargs.get("planner", None) + + checkpoint_id = kwargs.get("checkpoint_id", None) + if not checkpoint_id and (serializer := storage_writer or storage_reader): + checkpoint_id = getattr(serializer, "checkpoint_id", None) + + msg_dict["checkpoint_id"] = ( + str(checkpoint_id) if checkpoint_id is not None else checkpoint_id + ) + + # Uniquely identify a _dcp_method_logger wrapped function call. + msg_dict["uuid"] = str(uuid4().int) + + if storage_writer: + msg_dict["storage_writer"] = storage_writer.__class__.__name__ + + if storage_reader: + msg_dict["storage_reader"] = storage_reader.__class__.__name__ + + if planner: + msg_dict["planner"] = planner.__class__.__name__ + + return msg_dict + + +def _get_msg_dict(func_name, *args, **kwargs) -> Dict[str, Any]: + msg_dict = _msg_dict_from_dcp_method_args(*args, **kwargs) + msg_dict.update(c10d_logger._get_msg_dict(func_name, *args, **kwargs)) + + return msg_dict + + +def _dcp_method_logger( + log_exceptions: bool = False, **wrapper_kwargs: Any +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: # pyre-ignore + """This method decorator logs the start, end, and exception of wrapped events.""" + + def decorator(func: Callable[_P, _T]): + @functools.wraps(func) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: + msg_dict = _get_msg_dict( + func.__name__, *args, **{**wrapper_kwargs, **kwargs} + ) + + # log start event + msg_dict["event"] = "start" + t0 = time.time_ns() + msg_dict["time"] = t0 + msg_dict["log_exceptions"] = log_exceptions + _dcp_logger.debug(msg_dict) + + # exceptions + try: + result = func(*args, **kwargs) + except BaseException as error: + if log_exceptions: + msg_dict["event"] = "exception" + msg_dict["error"] = f"{error}" + msg_dict["time"] = time.time_ns() + _dcp_logger.error(msg_dict) + raise + + # end event + msg_dict["event"] = "end" + t1 = time.time_ns() + msg_dict["time"] = time.time_ns() + msg_dict["times_spent"] = t1 - t0 + _dcp_logger.debug(msg_dict) + + return result + + return wrapper + + return decorator diff --git a/mindnlp/core/distributed/checkpoint/logging_handlers.py b/mindnlp/core/distributed/checkpoint/logging_handlers.py new file mode 100644 index 000000000..678315976 --- /dev/null +++ b/mindnlp/core/distributed/checkpoint/logging_handlers.py @@ -0,0 +1,15 @@ +import logging +from typing import List + +from core.distributed.logging_handlers import _log_handlers + + +__all__: List[str] = [] + +DCP_LOGGER_NAME = "dcp_logger" + +_log_handlers.update( + { + DCP_LOGGER_NAME: logging.NullHandler(), + } +) diff --git a/mindnlp/core/distributed/checkpoint/metadata.py b/mindnlp/core/distributed/checkpoint/metadata.py new file mode 100644 index 000000000..555656c26 --- /dev/null +++ b/mindnlp/core/distributed/checkpoint/metadata.py @@ -0,0 +1,182 @@ +# mypy: allow-untyped-defs +import os +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Dict, List, Optional, Sequence, Union + +from mindnlp import core +from core.distributed.checkpoint.stateful import StatefulT + + +__all__ = [ + "ChunkStorageMetadata", + "TensorStorageMetadata", + "BytesStorageMetadata", + "Metadata", + "MetadataIndex", + "TensorProperties", + "StorageMeta", +] + + +@dataclass +class ChunkStorageMetadata: + """ + Each chunk is expected to have the same properties of the TensorStorageMetadata + that includes it. + """ + + offsets: core.Size + sizes: core.Size + + +class _MEM_FORMAT_ENCODING(Enum): + """Describe the memory format of a tensor.""" + + TORCH_CONTIGUOUS_FORMAT = 0 + TORCH_CHANNELS_LAST = 1 + TORCH_PRESERVE_FORMAT = 2 + + +@dataclass +class TensorProperties: + """Properties used to create :class:`Tensor`""" + + # Regular tensor fields + dtype: core.dtype = field(default_factory=core.get_default_dtype) + # This field is deprecated. + # layout: core.layout = field(default=core.strided) + # This field is deprecated. + requires_grad: bool = False + # This field is deprecated. + # memory_format: core.memory_format = field(default=core.contiguous_format) + # This field is deprecated. + pin_memory: bool = False + + def __getstate__(self): + # Since core.memory_format cannot be pickled! + # memory_format = self.memory_format + # if memory_format == core.contiguous_format: + # mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT + # elif memory_format == core.channels_last: + # mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST + # elif memory_format == core.preserve_format: + # mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT + # else: + # raise RuntimeError(f"Invalid core.memory_format: {memory_format}") + + return ( + self.dtype, + # self.layout, + self.requires_grad, + # mem_format_encoding, + self.pin_memory, + ) + + def __setstate__( + self, + state, + ): + ( + self.dtype, + # self.layout, + self.requires_grad, + # mem_format_encoding, + self.pin_memory, + ) = state + + # if mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT: + # memory_format = core.contiguous_format + # elif mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST: + # memory_format = core.channels_last + # elif mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT: + # memory_format = core.preserve_format + # else: + # raise RuntimeError( + # f"Invalid core.memory_format encoding: {mem_format_encoding}" + # ) + + # self.memory_format = memory_format + + @staticmethod + def create_from_tensor(tensor: core.Tensor) -> "TensorProperties": + return TensorProperties( + dtype=tensor.dtype, + # layout=tensor.layout, + requires_grad=tensor.requires_grad, + # memory_format=core.contiguous_format, + # pin_memory=tensor.is_pinned(), + ) + + +@dataclass +class TensorStorageMetadata: + properties: TensorProperties + size: core.Size + chunks: List[ChunkStorageMetadata] + + +@dataclass +class BytesStorageMetadata: + pass + + +STORAGE_TYPES = Union[TensorStorageMetadata, BytesStorageMetadata] +STATE_DICT_TYPE = Dict[str, Union[StatefulT, Any]] + + +@dataclass +class StorageMeta: + checkpoint_id: Union[str, os.PathLike, None] = None + save_id: Optional[str] = None + load_id: Optional[str] = None + + +@dataclass +class Metadata: + """This class represents the metadata of the checkpoint.""" + + # Keys are the same from the `state_dict` used. + state_dict_metadata: Dict[str, STORAGE_TYPES] + # It is the responsibility of the planner and storage plugins to ensure + # backward compatibility of the planner_data and storage_data. DCP will + # also ensure the backward compatibility of the metadata in this file and + # the metadata of the built-in planner and storage plugins. + planner_data: Any = None + storage_data: Any = None + storage_meta: Optional[StorageMeta] = None + + +@dataclass(frozen=True) +class MetadataIndex: + """This class represents a lookup key for items in a state dict or Metadata.""" + + fqn: str + """Fully Qualified Name of the object""" + + offset: Optional[core.Size] = None + """If the object is a tensor, offset into the tensor we're looking for""" + + index: Optional[int] = field(hash=False, compare=False, default=None) + """ + Index hint when searching for tensor chunk to speedup lookups (optional) + + A common representation of a sharded tensor is as a list of chunks so to + find the index in such a list you need to linear search it. + + When constructing an instance of MetadataIndex that points to that list, + one can provide the index as a hint and it will be probed first before + the linear search and thus making it significantly faster. + """ + + def __init__( + self, + fqn: str, + offset: Optional[Sequence[int]] = None, + index: Optional[int] = None, + ): + # We must use object.__setattr__ due to frozen=True + object.__setattr__(self, "fqn", fqn) + object.__setattr__(self, "index", index) + if offset is not None: + object.__setattr__(self, "offset", core.Size(offset)) diff --git a/mindnlp/core/distributed/checkpoint/optimizer.py b/mindnlp/core/distributed/checkpoint/optimizer.py new file mode 100644 index 000000000..abf2430db --- /dev/null +++ b/mindnlp/core/distributed/checkpoint/optimizer.py @@ -0,0 +1,356 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +import dataclasses +from typing import cast, Dict, List, Optional, Sequence, Tuple, Union + +from mindnlp import core +from mindnlp import core.distributed as dist +from core._utils import _get_device_module +from core.distributed._shard.sharded_tensor.api import ShardedTensor +from core.distributed._shard.sharded_tensor.metadata import ( + TensorProperties as ShardTensorProperties, +) +from core.distributed._shard.sharded_tensor.shard import Shard +from core.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec +from core.distributed.checkpoint._nested_dict import unflatten_state_dict +from core.distributed.checkpoint.default_planner import DefaultLoadPlanner +from core.distributed.checkpoint.metadata import ( + BytesStorageMetadata, + ChunkStorageMetadata, + Metadata, + MetadataIndex, + STATE_DICT_TYPE, + TensorProperties, + TensorStorageMetadata, +) +from core.distributed.checkpoint.planner import LoadPlan, LoadPlanner +from core.distributed.checkpoint.planner_helpers import ( + _create_read_items, + create_read_items_for_chunk_list, +) +from core.distributed.checkpoint.state_dict_loader import load_state_dict +from core.distributed.checkpoint.storage import StorageReader +from core.distributed.checkpoint.utils import ( + _element_wise_add, + _element_wise_sub, + _normalize_device_info, +) +from core.distributed.distributed_c10d import _get_default_group +from core.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor +from core.distributed.remote_device import _remote_device +from core.distributed.tensor import DTensor + + +STATE_DICT_2D_LAYOUT = Dict[str, Tuple[Optional[Sequence[int]], Sequence[int]]] + + +# TODO: Update docstrings for optimizer.py +__all__ = [ + "load_sharded_optimizer_state_dict", +] + + +def _gen_rank_device(global_rank: int, device_type: str = "cuda") -> str: + if device_type == "cpu": + return "cpu" + device_module = _get_device_module(device_type) + if device_module.is_available(): + return _normalize_device_info( + device_type, global_rank % device_module.device_count() + ) + return "cpu" + + +def _create_colwise_spec( + pg: Optional[dist.ProcessGroup] = None, +) -> ChunkShardingSpec: + pg_device_type = dist.distributed_c10d._get_pg_default_device(pg).type + if pg is None: + placements = [ + f"rank:{idx}/{_gen_rank_device(idx, pg_device_type)}" + for idx in range(dist.get_world_size()) + ] + else: + placements = [ + f"rank:{idx}/{_gen_rank_device(dist.get_global_rank(pg, idx), pg_device_type)}" + for idx in range(pg.size()) + ] + return ChunkShardingSpec( + dim=0, + placements=cast(List[Union[_remote_device, str]], placements), + ) + + +def _is_nested_tensor(val: core.Tensor) -> bool: + if type(val) is ShardedTensor: + if len(val.local_shards()) == 0: + return False + if type(val.local_shards()[0].tensor) is ShardedTensor: + return True + if type(val.local_shards()[0].tensor) is DTensor: + raise ValueError("Cannot handle DTensor nested insided ShardedTensor") + elif type(val) is DTensor and ( + type(val._local_tensor) is DTensor or type(val._local_tensor) is ShardedTensor + ): + raise ValueError("Cannot handle nested DTensor") + return False + + +def _alloc_tensor( + props: TensorProperties, size: Sequence[int], device_type: str = "cuda" +) -> core.Tensor: + if device_type == "cpu": + device = cast(core.device, _get_device_module(device_type).current_device()) + else: + device = core.device( + device_type, _get_device_module(device_type).current_device() + ) + + return core.empty( + size=size, + dtype=props.dtype, + layout=props.layout, + requires_grad=props.requires_grad, + pin_memory=props.pin_memory, + device=device, + ) + + +def _get_state_dict_2d_layout( + state_dict: STATE_DICT_TYPE, +) -> Tuple[STATE_DICT_2D_LAYOUT, Optional[dist.ProcessGroup]]: + """ + Load the right TP slice of the optimizer state. + + This is not easy since the per-tensor slicing can't be inferred from checkpoint metadata. + We take advantage of the model state_dict producing a sliced ST to figure out what we need to load. + This is pretty fragile and it might be easier for FSDP to compute this info for us. + Returns a dictionary where keys are the same of the state_dict and the value is a tuple of + (offset, size) for the current rank TP slice. + N.B. The state_dict *MUST* come from FSDP.sharded_state_dict. + """ + specs: STATE_DICT_2D_LAYOUT = {} + dp_pg: Optional[dist.ProcessGroup] = None + for key, value in state_dict.items(): + specs[key] = (None, value.size()) + if _is_nested_tensor(value): + assert ( + len(value.local_shards()) == 1 + ), "Cannot handle ST with multiple shards" + assert isinstance( + value, ShardedTensor + ), "Can only handle nested ShardedTensor" + shard = value.local_shards()[0] + specs[key] = ( + shard.metadata.shard_offsets, + shard.metadata.shard_sizes, + ) + dp_pg = shard.tensor._process_group # type: ignore[attr-defined] + + return ( + specs, + dp_pg, + ) + + +class _ReaderWithOffset(DefaultLoadPlanner): + translation: Dict[MetadataIndex, MetadataIndex] + state_dict: STATE_DICT_TYPE + metadata: Metadata + + def __init__(self, fqn_to_offset: Dict[str, Sequence[int]]) -> None: + super().__init__() + self.fqn_to_offset = fqn_to_offset + self.metadata = Metadata({}) + self.state_dict = {} + self.translation = {} + + def create_local_plan(self) -> LoadPlan: + requests = [] + self.translation = {} + for fqn, obj in self.state_dict.items(): + md = self.metadata.state_dict_metadata[fqn] + if not isinstance(obj, ShardedTensor): + requests += _create_read_items(fqn, md, obj) + continue + + if fqn not in self.fqn_to_offset: + requests += _create_read_items(fqn, md, obj) + continue + + offset = self.fqn_to_offset[fqn] + + assert len(obj.local_shards()) == 1 + original_shard = obj.local_shards()[0] + local_chunks = [ + ChunkStorageMetadata( + offsets=core.Size( + _element_wise_add(original_shard.metadata.shard_offsets, offset) + ), + sizes=core.Size(original_shard.metadata.shard_sizes), + ) + ] + + reqs = create_read_items_for_chunk_list( + fqn, cast(TensorStorageMetadata, md), local_chunks + ) + # TODO: The ReadItems will have a displaced MetadataIndex, fix it. + # TODO: we should change _create_sharded_read_items to have more ergonomic API + for ri in reqs: + assert ri.dest_index.offset is not None + original_offset = _element_wise_sub(ri.dest_index.offset, offset) + original_index = dataclasses.replace( + ri.dest_index, offset=core.Size(original_offset) + ) + self.translation[ri.dest_index] = original_index + + requests += reqs + return LoadPlan(requests) + + def lookup_tensor(self, index: MetadataIndex) -> core.Tensor: + return super().lookup_tensor(self.translation.get(index, index)) + + +def load_sharded_optimizer_state_dict( + model_state_dict: STATE_DICT_TYPE, + optimizer_key: str, + storage_reader: StorageReader, + planner: Optional[LoadPlanner] = None, +) -> STATE_DICT_TYPE: + """ + Load a state_dict in conjunction with FSDP sharded optimizer state. + + This is the current recommended way to checkpoint FSDP. + >>> # xdoctest: +SKIP + >>> from mindnlp import core.distributed.checkpoint as dist_cp + >>> # Save + >>> model: core.nn.Model + >>> optim_params = model.parameters() + >>> optim = core.optim.SGD(optim_params, lr=0.01) + >>> # Save + >>> with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): + >>> state_dict = { + >>> "optimizer": FSDP.optim_state_dict(model, optim), + >>> "model": model.state_dict() + >>> } + >>> dist_cp.save_state_dict( + >>> state_dict=optim_state, + >>> storage_writer=dist_cp.FileSystemWriter("checkpoint"), + >>> planner=dist_cp.DefaultSavePlanner(), + >>> ) + >>> + >>> # Load + >>> with FSDP.state_dict_type(model_tp, StateDictType.SHARDED_STATE_DICT): + >>> model_state_dict = model_tp.state_dict() + >>> checkpoint = { + >>> "model": model_state_dict + >>> } + >>> dist_cp.load_state_dict( + >>> state_dict=checkpoint, + >>> storage_reader=dist_cp.FileSystemReader(checkpoint_file), + >>> planner=dist_cp.DefaultLoadPlanner(), + >>> ) + >>> model.load_state_dict(checkpoint["model_state"]) + >>> + >>> optim_state = dist_cp.load_sharded_optimizer_state_dict( + >>> model_state_dict, + >>> optimizer_key="optimizer", + >>> storage_reader=dist_cp.FileSystemReader("checkpoint"), + >>> ) + >>> + >>> flattened_osd = FSDP.optim_state_dict_to_load( + >>> model, optim, optim_state["optimizer"] + >>> ) + >>> + >>> optim.load_state_dict(flattened_osd) + """ + metadata = storage_reader.read_metadata() + + layout_specs, dp_pg = _get_state_dict_2d_layout(model_state_dict) + dp_pg_device_type = dist.distributed_c10d._get_pg_default_device(dp_pg).type + device_module = _get_device_module(dp_pg_device_type) + + if dp_pg is None: + placements = [] + for i in range(dist.get_world_size()): + device_info = _normalize_device_info( + dp_pg_device_type, i % device_module.device_count() + ) + placements.append(f"rank:{i}/{device_info}") + sharding_spec = ChunkShardingSpec(dim=0, placements=placements) # type: ignore[arg-type] + else: + sharding_spec = _create_colwise_spec(dp_pg) + + # Create a state_dict for optimizer state + state_dict: STATE_DICT_TYPE = {} + + fqn_to_offset: Dict[str, Sequence[int]] = {} + for key, value in metadata.state_dict_metadata.items(): + key_path = metadata.planner_data[key] + if key_path[0] != optimizer_key: + continue + + if isinstance(value, BytesStorageMetadata): + state_dict[key] = "" + continue + + # value: TensorStorageMetadata + if value.size.numel() == 1: + state_dict[key] = _alloc_tensor( + value.properties, value.size, dp_pg_device_type + ) + elif dp_pg is None: + state_dict[key] = _create_chunk_sharded_tensor( + _alloc_tensor(value.properties, value.size, dp_pg_device_type), + rank=dist.get_rank(), + world_size=dist.get_world_size(), + num_devices_per_node=device_module.device_count(), + pg=_get_default_group(), + ) + else: + spec_key = key_path[2] + alloc_size = layout_specs.get(spec_key, (None, value.size))[1] + + properties = ShardTensorProperties( + dtype=value.properties.dtype, + layout=value.properties.layout, + requires_grad=value.properties.requires_grad, + memory_format=value.properties.memory_format, + pin_memory=value.properties.pin_memory, + ) + + st_md = sharding_spec.build_metadata(core.Size(alloc_size), properties) + local_shards = [] + current_rank = dist.get_rank(dp_pg) + for shard_md in st_md.shards_metadata: + if cast(_remote_device, shard_md.placement).rank() != current_rank: + continue + local_shards.append( + Shard( + tensor=_alloc_tensor( + value.properties, shard_md.shard_sizes, dp_pg_device_type + ), + metadata=shard_md, + ) + ) + + st = ShardedTensor._init_from_local_shards_and_global_metadata( + local_shards, st_md, process_group=dp_pg + ) + + if spec_key in layout_specs and layout_specs[spec_key][0] is not None: + fqn_to_offset[key] = cast(Sequence[int], layout_specs[spec_key][0]) + + state_dict[key] = st + + # Whether we unflatten before or after doesn't matter + load_state_dict( + state_dict=state_dict, + storage_reader=storage_reader, + # FIXME the type of planner is wrong in load_state_dict + planner=_ReaderWithOffset(fqn_to_offset) if dp_pg is not None else planner, + ) + + state_dict = unflatten_state_dict(state_dict, metadata.planner_data) + + return state_dict diff --git a/mindnlp/core/distributed/checkpoint/planner.py b/mindnlp/core/distributed/checkpoint/planner.py new file mode 100644 index 000000000..66b0ede13 --- /dev/null +++ b/mindnlp/core/distributed/checkpoint/planner.py @@ -0,0 +1,417 @@ +import abc +import io +import operator +from dataclasses import dataclass +from enum import auto, Enum +from functools import reduce +from typing import Any, List, Optional, Tuple, Union + +from mindnlp import core +from core.distributed.checkpoint.metadata import ( + ChunkStorageMetadata, + Metadata, + MetadataIndex, + STATE_DICT_TYPE, + StorageMeta, + TensorProperties, +) + + +__all__ = [ + "WriteItemType", + "LoadItemType", + "TensorWriteData", + "WriteItem", + "ReadItem", + "SavePlan", + "LoadPlan", + "SavePlanner", + "LoadPlanner", +] + + +class WriteItemType(Enum): + TENSOR = auto() + SHARD = auto() + BYTE_IO = auto() + + +class LoadItemType(Enum): + TENSOR = auto() + BYTE_IO = auto() + + +@dataclass(frozen=True) +class TensorWriteData: + chunk: ChunkStorageMetadata + properties: TensorProperties + size: core.Size + + +@dataclass(frozen=True) +class WriteItem: + """Dataclass which holds information about what needs to be written to storage.""" + + index: MetadataIndex + type: WriteItemType + + # Value present if it's a tensor write + tensor_data: Optional[TensorWriteData] = None + + def tensor_storage_size(self) -> Optional[int]: + """ + Calculates the storage size of the underlying tensor, or None if this is not a tensor write. + + Returns: + Optional[int] storage size, in bytes of underlying tensor if any. + """ + if self.tensor_data is None: + return None + + numels = reduce(operator.mul, self.tensor_data.size, 1) + dtype_size = core._utils._element_size(self.tensor_data.properties.dtype) + return numels * dtype_size + + +@dataclass(frozen=True) +class ReadItem: + # Read Item + type: LoadItemType + + # Index into the state_dict + dest_index: MetadataIndex + # Offsets into destination tensor + dest_offsets: core.Size + + # Index into the checkpoint + storage_index: MetadataIndex + # Offset into the checkpoint data + storage_offsets: core.Size + + # Size of the hypercube to copy + lengths: core.Size + + +@dataclass(frozen=True) +class SavePlan: + items: List[WriteItem] + storage_data: Any = None + planner_data: Any = None + + +@dataclass +class LoadPlan: + items: List[ReadItem] + storage_data: Any = None + planner_data: Any = None + + +class SavePlanner(abc.ABC): + """ + Abstract class defining the protocol used by save_state_dict to plan the save process. + + SavePlanners are stateful objects that can be used to customize the whole save process. + + SavePlanner acts as an access proxy to the state_dict, so any transformation done to it + will be visible to the whole process. + + A planner subclass can expect the following sequence of calls during save_state_dict: + + 1) set_up_planner - called on all ranks. + Signals the start of a checkpoint save. + + 2) create_local_plan - called on all ranks. + Process the state_dict and produces a `SavePlan` that will be sent for global planning. + + 3) create_global_plan - called on the coordinator rank only. + Takes the SavePlan from all ranks and make any global decision. + + 4) finish_plan - called on all ranks. + This gives each rank a chance to adjust to global planning decisions. + + 5) resolve_data - called multiple times on each rank + Lookups a value on the `state_dict` for the storage layer to write. + + Users are recommended to extend DefaultSavePlanner instead of this interface directly as + most changes can be expressed by changes in a single method. + + There are 3 usual patterns of extension: + + Rewriting state_dict. This is the simplest way to extend the save process as it + doesn't requite understanding the intrincacies of how SavePlan works: + + >>> # xdoctest: +SKIP("undefined vars") + >>> class RenamePlanner(DefaultSavePlanner): + >>> def set_up_planner( + >>> self, + >>> state_dict: STATE_DICT_TYPE, + >>> storage_meta: Optional[StorageMeta], + >>> is_coordinator: bool, + >>> ) -> None: + >>> # prefix all keys with `foo_`` + >>> super().set_up_planner({"foo_" + k: v for k, v in state_dict.items()}, storage_meta, is_coordinator) + + Modifying local plan and lookup in tandem. This is useful when fine control of how data is persisted + + >>> # xdoctest: +SKIP("undefined vars") + >>> class FP16Planner(DefaultSavePlanner): + >>> def create_local_plan(self): + >>> plan = super().create_local_plan() + >>> for p in plan: + >>> if p.tensor_data is not None: + >>> p.tensor_data.properties.dtype = core.float16 + >>> return plan + >>> + >>> def resolve_data(self, write_item): + >>> item = super().resolve_data(write_item) + >>> return item if write_item.type == WriteItemType.BYTE_IO else item.to(core.float16) + + Using the global planning step to make central decisions that can't be made individually by each rank + + >>> # xdoctest: +SKIP("undefined vars") + >>> from itertools import zip_longest + >>> from dataclasses import replace + >>> class DDPLoadBalancingPlanner(DefaultSavePlanner): + >>> # This uses the default local plan behavior of having all non-sharded writes in rank 0 + >>> # This sample doesn't handle ShardedTensors + >>> def create_global_plan(self, all_plans): + >>> iters = [iter(all_plans[0].items)] * len(all_plans) + >>> items_per_rank = [ + >>> [item for item in items if item is not None] + >>> for items in zip(*zip_longest(*iters), strict=True) + >>> ] + >>> all_plans = [ + >>> replace(plan, items=items) + >>> for plan, items in zip(all_plans, items_per_rank, strict=True) + >>> ] + >>> return super().create_global_plan(all_plans) + + Finally, some planners need to save additional metadata in the checkpoint, this is + accomplished by having each rank contribute their data items in the local plan and + the global planner aggregate them: + + >>> # xdoctest: +SKIP("undefined vars") + >>> class SaveExtraDataPlanner(DefaultSavePlanner): + >>> def create_local_plan(self) -> SavePlan: + >>> plan = super().create_local_plan() + >>> return replace(plan, planner_data="per-rank-data") + >>> + >>> def create_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]: + >>> global_plan, metadata = super().create_global_plan(all_plans) + >>> merged_data = [p.planner_data for p in global_plan] + >>> metadata = replace(metadata, planner_data=merged_data) + >>> return global_plan, metadata + """ + + @abc.abstractmethod + def set_up_planner( + self, + state_dict: STATE_DICT_TYPE, + storage_meta: Optional[StorageMeta] = None, + is_coordinator: bool = False, + ) -> None: + """ + Initialize this planner to save ``state_dict``. + + Implementations should save those values as they won't be provided lated in the save process. + + This is called on all ranks. + """ + + @abc.abstractmethod + def create_local_plan(self) -> SavePlan: + """ + Compute the save plan for the current rank. + + This will be aggregated and passed to create_global_plan. + Planner specific data can be passed through SavePlan::planner_data. + + This is called on all ranks. + """ + + @abc.abstractmethod + def create_global_plan( + self, all_plans: List[SavePlan] + ) -> Tuple[List[SavePlan], Metadata]: + """ + Compute the global checkpoint plan and return the local plan of each rank. + + This is called on the coordinator rank only. + """ + + @abc.abstractmethod + def finish_plan(self, new_plan: SavePlan) -> SavePlan: + """ + Merge the plan created by `create_local_plan` and the result of `create_global_plan`. + + This is called on all ranks. + """ + + @abc.abstractmethod + def resolve_data(self, write_item: WriteItem) -> Union[core.Tensor, io.BytesIO]: + """ + Transform and prepare ``write_item`` from ``state_dict`` for storage, ensuring idempotency and thread-safety. + + Lookup the object associated with ``write_item`` in ``state_dict`` and apply any + transformation (such as serialization) prior to the storage layer consuming it. + + Called on each rank multiple times, at least once per WriteItem in the final SavePlan. + + This method should be idempotent and thread-save. StorageWriter implementations + are free to call it as frequently as they need. + + Any transformation that allocates memory should be lazily done when his method + is called in order to reduce peak memory required by checkpointing. + + When returning tensors, they can be on any device or format, they can be views too. + It's the storage layer responsibility to figure out how to save them. + """ + + +class LoadPlanner: + """ + Abstract class defining the protocol used by load_state_dict to plan the load process. + + LoadPlanner are stateful objects that can be used to customize the whole load process. + + LoadPlanner acts as an access proxy to the state_dict, so any transformation done to it + will be visible to the whole process. + + A planner subclass can expect the following sequence of calls during load_state_dict: + + 1) set_up_planner - called on all ranks. + Signals the start of loading a checkpoint. + + 2) create_local_plan - called on all ranks. + Process the state_dict and produces a `LoadPlan` that will be sent for global planning. + + 3) create_global_plan - called on the coordinator rank only. + Takes the LoadPlan from all ranks and make any global decision. + + 4) load_bytes - called multiple times on each rank + This is called once per non-tensor value in state_dict. + + 5) resolve_tensor and commit_tensor - called multiple times on each rank + They are called in pair for each Tensor value in state_dict. + + Users are recommended to extend DefaultLoadPlanner instead of this interface directly as + most changes can be expressed by changes in a single method. + + There are two usual patterns of extension: + + Rewriting state_dict. This is the simplest way to extend the load process as it + doesn't requite understanding the intrincacies of how LoadPlan works. We need + to keep a reference to the original state_dict as load happens in place so + we need to be able to perform it in place + + >>> # xdoctest: +SKIP("undefined vars") + >>> class RenamePlanner(DefaultLoadPlanner): + >>> def set_up_planner( + >>> self, + >>> state_dict: STATE_DICT_TYPE, + >>> metadata: Metadata, + >>> is_coordinator: bool, + >>> ) -> None: + >>> self.original_state_dict = state_dict + >>> state_dict = {"foo_" + k: v for k, v in state_dict.items()} + >>> + >>> if self.flatten_sharded_tensors: + >>> state_dict = _flatten_sharded_tensors(state_dict) + >>> + >>> if self.flatten_state_dict: + >>> state_dict, self.mappings = flatten_state_dict(state_dict) + >>> + >>> self.state_dict = state_dict + >>> self.metadata = metadata + >>> self.is_coordinator = is_coordinator + >>> + >>> def load_bytes(self, read_item, value): + >>> # Remove the "foo_" prefix + >>> self.original_state_dict[read_item.dest_index.fqn[4:]] = core.load(value, weights_only=False) + + + Modifying resolve_tensor and commit_tensor to handle load time transformation. + + >>> # xdoctest: +SKIP("undefined vars") + >>> class MetaModelMaterialize(DefaultSavePlanner): + >>> def resolve_tensor(self, read_item): + >>> tensor = super().resolve_tensor(read_item) + >>> return core.empty_like(tensor, device="cpu") + >>> + >>> def commit_tensor(self, read_item, tensor): + >>> self.state_dict[read_item.dest_index.fqn] = tensor + """ + + @abc.abstractmethod + def set_up_planner( + self, + state_dict: STATE_DICT_TYPE, + metadata: Optional[Metadata] = None, + is_coordinator: bool = False, + ) -> None: + """ + Initialize this instance to load data into ``state_dict``. + + . N.B. This is called on every rank. + """ + + @abc.abstractmethod + def create_local_plan(self) -> LoadPlan: + """ + Create a LoadPlan based on state_dict and metadata provided by set_up_planner. + + . N.B. This is called on every rank. + """ + + @abc.abstractmethod + def create_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]: + """ + Compute the global load plan and return plans for each rank. + + . N.B. This is called on the coordinator rank only + """ + + @abc.abstractmethod + def finish_plan(self, central_plan: LoadPlan) -> LoadPlan: + """Accept the plan from coordinator and return final LoadPlan.""" + + @abc.abstractmethod + def load_bytes(self, read_item: ReadItem, value: io.BytesIO) -> None: + """ + Load the item described by ``read_item``and ``value``. + + This method is expected to modify in-place the underlying state_dict. + + The contents of ``value`` are defined by the SavePlanner used to produce + the checkpoint being loaded. + """ + + def resolve_bytes(self, read_item: ReadItem) -> io.BytesIO: + """ + Return the BytesIO to be used by the StorageReader to load `read_item`. + + The BytesIO should alias with one on the underlying state_dict as StorageReader will replace its contents. + """ + raise NotImplementedError("LoadPlanner.resolve_bytes is not implemented") + + @abc.abstractmethod + def resolve_tensor(self, read_item: ReadItem) -> core.Tensor: + """ + Return the tensor described by ``read_item`` to be used by the StorageReader to load `read_item`. + + The tensor should alias with one on the underlying state_dict as StorageReader will replace its contents. + If, for any reason, that's not possible, the planner can use the ``commit_tensor`` method to copy the data + back to the one in state_dict. + """ + + @abc.abstractmethod + def commit_tensor(self, read_item: ReadItem, tensor: core.Tensor) -> None: + """ + Call once the StorageReader finished loading data into ``tensor``. + + The provided tensor is the same one returned by the call to ``resolve_tensor``. + This method is only needed if this LoadPlanner needs to post process ``tensor`` prior to + copying it back to the one in the state_dict. + + The contents of tensor will follow its device synchronization model. + """ diff --git a/mindnlp/core/distributed/checkpoint/planner_helpers.py b/mindnlp/core/distributed/checkpoint/planner_helpers.py new file mode 100644 index 000000000..8065189c6 --- /dev/null +++ b/mindnlp/core/distributed/checkpoint/planner_helpers.py @@ -0,0 +1,389 @@ +# mypy: allow-untyped-defs +import io +from typing import Any, Callable, cast, Dict, List + +from mindnlp import core +from mindnlp import core.distributed as dist +# from core._utils import _get_device_module +from core.distributed._shard.metadata import ShardMetadata +from core.distributed._shard.sharded_tensor import ShardedTensor +# from core.distributed.tensor import DTensor +# from core.distributed.tensor._utils import compute_local_shape_and_global_offset + +from .metadata import ( + BytesStorageMetadata, + ChunkStorageMetadata, + MetadataIndex, + STATE_DICT_TYPE, + STORAGE_TYPES, + TensorProperties, + TensorStorageMetadata, +) +from .planner import ( + LoadItemType, + ReadItem, + SavePlan, + TensorWriteData, + WriteItem, + WriteItemType, +) +from .resharding import ( + _check_shard_metadata_pair_overlap, + _shards_get_overlap_region_wrt_saved_tensor, +) + + +__all__: List[str] = ["create_read_items_for_chunk_list"] + + +def _create_chunk_from_tensor(tensor: core.Tensor) -> ChunkStorageMetadata: + return ChunkStorageMetadata( + offsets=core.Size([0] * len(tensor.size())), sizes=tensor.size() + ) + + +def _chunk_for_shard(shard_md: ShardMetadata) -> ChunkStorageMetadata: + return ChunkStorageMetadata( + offsets=core.Size(shard_md.shard_offsets), + sizes=core.Size(shard_md.shard_sizes), + ) + + +def _sharded_tensor_metadata( + sharded_tensor: ShardedTensor, shard_md: ShardMetadata +) -> TensorWriteData: + shard_properties = sharded_tensor.metadata().tensor_properties + + properties = TensorProperties( + dtype=shard_properties.dtype, + # layout=shard_properties.layout, + requires_grad=shard_properties.requires_grad, + # memory_format=shard_properties.memory_format, + # pin_memory=shard_properties.pin_memory, + ) + + return TensorWriteData( + chunk=_chunk_for_shard(shard_md), + properties=properties, + size=sharded_tensor.metadata().size, + ) + + +# def _create_write_items_for_dtensor(fqn: str, tensor: DTensor) -> WriteItem: +# sizes, offsets = compute_local_shape_and_global_offset( +# tensor.shape, tensor.device_mesh, tensor.placements +# ) +# sizes, offsets = core.Size(sizes), core.Size(offsets) + +# return WriteItem( +# index=MetadataIndex(fqn, offsets), +# type=WriteItemType.SHARD, +# tensor_data=TensorWriteData( +# chunk=ChunkStorageMetadata( +# offsets=offsets, +# sizes=sizes, +# ), +# properties=TensorProperties.create_from_tensor(tensor.to_local()), +# size=tensor.size(), +# ), +# ) + + +def _create_write_item_for_shard( + fqn: str, sharded_tensor: ShardedTensor, shard_md: ShardMetadata +) -> WriteItem: + offsets = core.Size(shard_md.shard_offsets) + return WriteItem( + index=MetadataIndex(fqn, offsets), + type=WriteItemType.SHARD, + tensor_data=_sharded_tensor_metadata(sharded_tensor, shard_md), + ) + + +def _create_write_item_for_tensor(fqn: str, tensor: core.Tensor) -> WriteItem: + offsets = core.Size([0] * len(tensor.size())) + return WriteItem( + index=MetadataIndex(fqn, offsets), + type=WriteItemType.TENSOR, + tensor_data=TensorWriteData( + chunk=ChunkStorageMetadata(offsets=offsets, sizes=tensor.size()), + properties=TensorProperties.create_from_tensor(tensor), + size=tensor.size(), + ), + ) + + +def _create_write_item_for_bytesio(fqn: str, bytes: Any): + return WriteItem( + index=MetadataIndex(fqn), + type=WriteItemType.BYTE_IO, + ) + + +def _create_read_item_for_byteio( + dest_index, dest_offset, storage_index, storage_offset, length +): + return ReadItem( + type=LoadItemType.BYTE_IO, + dest_index=dest_index, + dest_offsets=core.Size((dest_offset,)), + storage_index=storage_index, + storage_offsets=core.Size((storage_offset,)), + lengths=core.Size((length,)), + ) + + +def _create_read_item_for_tensor( + dest_index, dest_offsets, storage_index, storage_offsets, lengths +): + return ReadItem( + type=LoadItemType.TENSOR, + dest_index=dest_index, + dest_offsets=core.Size(dest_offsets), + storage_index=storage_index, + storage_offsets=core.Size(storage_offsets), + lengths=core.Size(lengths), + ) + + +def create_read_items_for_chunk_list( + fqn: str, + checkpoint_md: TensorStorageMetadata, + local_chunks: List[ChunkStorageMetadata], +) -> List[ReadItem]: + """ + Create a list of ``ReadItem`` based on the checkpoint and local chunks. + + This applies the resharding algorithm and computes the reads needed + to satisfy ``local_chunks`` with a checkpoint described by ``checkpoint_md``. + + Args: + fqn (str) : The state_dict FQN to pass to ``ReadItem``. + checkpoint_md (TensorStorageMetadata): metadata for a given tensor + from a checkpoint. + local_chunks (List[ChunkStorageMetadata]): Local chunks that needs to be + loaded. + + Returns: + A list of ``ReadItem`` that will satisfy all input chunks. + """ + read_items = [] + # this is a naive quadratic algo that can be optimized later + for idx, shard in enumerate(local_chunks): + for storage_idx, storage_md in enumerate(checkpoint_md.chunks): + if not _check_shard_metadata_pair_overlap(shard, storage_md): + continue + + storage_offsets = [] + dest_offsets = [] + lengths = [] + for ( + _dim, + offset_for_saved_tensor, + offset_for_current_tensor, + length, + ) in _shards_get_overlap_region_wrt_saved_tensor( + saved_shard=storage_md, current_shard=shard + ): + storage_offsets.append(offset_for_saved_tensor) + dest_offsets.append(offset_for_current_tensor) + lengths.append(length) + + read_items.append( + _create_read_item_for_tensor( + dest_index=MetadataIndex(fqn, shard.offsets, idx), + dest_offsets=dest_offsets, + storage_index=MetadataIndex(fqn, storage_md.offsets, storage_idx), + storage_offsets=storage_offsets, + lengths=lengths, + ) + ) + return read_items + + +def _create_default_metadata_only_plan(state_dict: STATE_DICT_TYPE) -> SavePlan: + requests = [] + for fqn, obj in state_dict.items(): + # if isinstance(obj, DTensor): + # requests.append(_create_write_items_for_dtensor(fqn, obj)) + if isinstance(obj, ShardedTensor): + requests.extend( + _create_write_item_for_shard(fqn, obj, shard_md) + for shard_md in obj.metadata().shards_metadata + ) + elif isinstance(obj, core.Tensor): + requests.append(_create_write_item_for_tensor(fqn, obj)) + else: + requests.append(_create_write_item_for_bytesio(fqn, obj)) + return SavePlan(requests) + + +def _create_write_items(fqn: str, object: Any) -> List[WriteItem]: + if hasattr(object, "__create_write_items__"): + # DTensor implements _Checkpointable + return object.__create_write_items__(fqn, object) + elif isinstance(object, ShardedTensor): + return [ + _create_write_item_for_shard(fqn, object, shard.metadata) + for shard in object.local_shards() + ] + elif isinstance(object, core.Tensor): + return [_create_write_item_for_tensor(fqn, object)] + else: + return [_create_write_item_for_bytesio(fqn, object)] + + +# def _create_chunk_from_dtensor(tensor: DTensor) -> ChunkStorageMetadata: +# sizes, offsets = compute_local_shape_and_global_offset( +# tensor.shape, tensor.device_mesh, tensor.placements +# ) +# sizes, offsets = core.Size(sizes), core.Size(offsets) +# return ChunkStorageMetadata( +# offsets=offsets, +# sizes=sizes, +# ) + + +def _create_chunk_list(tensor: core.Tensor) -> List[ChunkStorageMetadata]: + if hasattr(tensor, "__create_chunk_list__"): + # DTensor implements _Checkpointable + local_chunks = tensor.__create_chunk_list__() # type: ignore[attr-defined] + elif isinstance(tensor, ShardedTensor): + local_chunks = [ + _chunk_for_shard(shard.metadata) for shard in tensor.local_shards() + ] + elif isinstance(tensor, core.Tensor): + local_chunks = [_create_chunk_from_tensor(tensor)] + else: + raise ValueError( + "Unsupported Type, expecting one of [Tensor, DTensor, ShardedTensor] " + f",but got {type(tensor)}" + ) + + return local_chunks + + +def _create_read_items(fqn: str, md: STORAGE_TYPES, obj: Any) -> List[ReadItem]: + if not isinstance(md, BytesStorageMetadata): + try: + local_chunks = _create_chunk_list(obj) + except ValueError as ex: + raise ValueError( + f"Invalid checkpoint metadata for {fqn}, " + + f"expected BytesStorageMetadata but found {type(md)}", + ) from ex + + return create_read_items_for_chunk_list(fqn, md, local_chunks) + else: + return [ + _create_read_item_for_byteio( + dest_index=MetadataIndex(fqn), + dest_offset=0, + storage_index=MetadataIndex(fqn), + storage_offset=0, + length=0, + ) + ] + + +def _init_state_dict(state_dict: Dict[str, Any]) -> Any: + """ + Initializes meta tensor if the meta tensor is DTensor or core.Tensor. + """ + + # def dtensor_func(value: DTensor): + # device = getattr(value, "device", None) + # if device == core.device("meta"): + # device_type = dist.distributed_c10d._get_pg_default_device().type + # device = cast( + # core.device, _get_device_module(device_type).current_device() + # ) + # new_local_tensor = core.empty_like(value.to_local(), device=device) + # # We need to pass shape and stride explicitly, since DTensor might be + # # sharded unevenly. + # dtensor = DTensor.from_local( + # new_local_tensor, + # device_mesh=value.device_mesh, + # placements=value.placements, + # shape=value.size(), + # stride=value.stride(), + # ) + # return dtensor + # else: + # return value + + def sharded_tensor_func(value: Any): + device = getattr(value, "device", None) + if device == core.device("meta"): + raise RuntimeError( + f"Found unsupported type {type(value)} for meta device loading." + ) + else: + return value + + def tensor_func(value: core.Tensor): + device = getattr(value, "device", None) + if device == core.device("meta"): + device_type = dist.distributed_c10d._get_pg_default_device().type + device = cast( + core.device, _get_device_module(device_type).current_device() + ) + tensor = core.empty_like(value, device=device) + return tensor + else: + return value + + _iterate_state_dict( + state_dict, + # dtensor_func, + sharded_tensor_func, + tensor_func, + ) + + +def _iterate_state_dict( + iter_object: Any, + # dtensor_func: Callable, + sharded_tensor_func: Callable, + tensor_func: Callable, +): + """ + Iterate through the state dict, applying the given functions to each tensor type + and update the state dict in place. + + Args: + iter_object (Any): the target state_dict. + sharded_tensor_func (Callable): the function to apply to ShardedTensor + dtensor_func (Callable): the function to apply to DTensor + tensor_func (Callable): the function to apply to Tensor + + # TODO: let state_dict_util._iterate_state_dict() to support in place option + so we don't need to have two versions of _iterate_state_dict. + """ + + # if isinstance(iter_object, DTensor): + # return dtensor_func(iter_object) + # el + if isinstance(iter_object, ShardedTensor): + return sharded_tensor_func(iter_object) + elif isinstance(iter_object, core.Tensor): + return tensor_func(iter_object) + elif ( + isinstance(iter_object, (int, float, str, bytes, io.BytesIO)) + or iter_object is None + ): + return iter_object + elif isinstance(iter_object, dict): + for key, value in iter_object.items(): + iter_object[key] = _iterate_state_dict( + value, sharded_tensor_func, tensor_func + ) + return iter_object + elif isinstance(iter_object, (list, tuple)): + ret = [ + _iterate_state_dict(v, sharded_tensor_func, tensor_func) + for v in iter_object + ] + if isinstance(iter_object, tuple): + ret = tuple(ret) # type: ignore[assignment] + return ret diff --git a/mindnlp/core/distributed/checkpoint/resharding.py b/mindnlp/core/distributed/checkpoint/resharding.py new file mode 100644 index 000000000..955bfd37d --- /dev/null +++ b/mindnlp/core/distributed/checkpoint/resharding.py @@ -0,0 +1,72 @@ +# mypy: allow-untyped-defs +from typing import List, Tuple + +from core.distributed.checkpoint.metadata import ChunkStorageMetadata + + +__all__: List[str] = [] + + +def _check_shard_metadata_pair_overlap( + shard1: ChunkStorageMetadata, shard2: ChunkStorageMetadata +): + """Check if two shards overlap.""" + # For each dim of each shard, check if one shard resides on the other + # end of second shard with respect to that dim. As an example for a 2D + # shard, we would check if one shard is above or on the left of the + # other shard. + ndims = len(shard1.offsets) + for i in range(ndims): + if shard1.offsets[i] >= shard2.offsets[i] + shard2.sizes[i]: + return False + if shard2.offsets[i] >= shard1.offsets[i] + shard1.sizes[i]: + return False + + return True + + +def _shards_get_overlap_region_wrt_saved_tensor( + saved_shard: ChunkStorageMetadata, current_shard: ChunkStorageMetadata +) -> List[Tuple[int, int, int, int]]: + """ + Return the overlapping region between saved_shard and current_shard. + + There returned list has the same number of elements as the tensor's dimension. + For each element, we produce a tuple with the following contents: + (dimension, `saved_shard` offset, `current_shard` offset, length) + + Offsets are relative to each shard. + """ + narrows = [] + for dim, ( + saved_shard_offset, + current_shard_offset, + saved_shard_size, + current_shard_size, + ) in enumerate( + zip( + saved_shard.offsets, + current_shard.offsets, + saved_shard.sizes, + current_shard.sizes, + ) + ): + min_range_end = min( + saved_shard_offset + saved_shard_size, + current_shard_offset + current_shard_size, + ) + + length = min_range_end - max(current_shard_offset, saved_shard_offset) + + if saved_shard_offset > current_shard_offset: + offset_for_saved_tensor = 0 + offset_for_current_tensor = saved_shard_offset - current_shard_offset + else: + offset_for_saved_tensor = current_shard_offset - saved_shard_offset + offset_for_current_tensor = 0 + + narrows.append( + (dim, offset_for_saved_tensor, offset_for_current_tensor, length) + ) + + return narrows diff --git a/mindnlp/core/distributed/checkpoint/staging.py b/mindnlp/core/distributed/checkpoint/staging.py new file mode 100644 index 000000000..810e8c285 --- /dev/null +++ b/mindnlp/core/distributed/checkpoint/staging.py @@ -0,0 +1,117 @@ +from typing import Optional, runtime_checkable +from typing_extensions import Protocol + +from core.distributed._state_dict_utils import ( + _copy_state_dict, + _create_cpu_state_dict, + _offload_state_dict_to_cpu, +) +from core.distributed.checkpoint.metadata import STATE_DICT_TYPE + + +__all__ = ["AsyncStager", "BlockingAsyncStager"] + + +@runtime_checkable +class AsyncStager(Protocol): + """ + This protocol is meant to provide customization and extensibility for dcp.async_save, allowing users + to customize how data is staged previous to executing the usual dcp.save path in parallel. + The expected order of operations (concretely defined in `core.distributed.state_dict_saver.async_save`) + is the following: + + 1. AsyncStager.stage_data(state_dict): + This call gives the AsyncStager the opportunity to 'stage' + the state_dict. The expectation and purpose of staging in this context is to create a "training-safe" + representation of the state dict, meaning that any updates to module data after staging is complete + should not be reflected in the state dict returned from this method. For example, in the default + case a copy of the entire state dict is created on CPU RAM and returned here, allowing users + to continue training without risking changes to data which is being serialized. + + 2. dcp.save is called on the state_dict returned from stage in parallel. This call is responsible + for serializing the state_dict and writing it to storage. + + 3. If AsyncStager.should_synchronize_after_execute is True, this method will be called immediately after + the serialization thread starts and before returning from dcp.async_save. If this is set to False, + the assumption is the user has defined a custom synchronization point for the the purpose of further + optimizing save latency in the training loop (for example, by overlapping staging with the + forward/backward pass), and it is the respondsibility of the user to call `AsyncStager.synchronize_staging` + at the appropriate time. + + """ + + # default to True since the common case is to stage synchronously + _synchronize_after_execute: bool = True + + @property + def should_synchronize_after_execute(self) -> bool: + """ + Whether to synchronize after executing the stage. + """ + + return self._synchronize_after_execute + + def stage(self, state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE: + """ + Returns a "staged" copy of `state_dict`. The expectation of the staged copy is that it is + innoculated from any updates incurred after the stage call is complete. + """ + raise NotImplementedError( + f"{self.__class__.__name__} must implement stage method" + ) + + def synchronize_staging(self) -> None: + """ + In the case `stage` is async in some way, this method should be called to ensure staging + is complete and it is safe to begin modifying the original `state_dict` + """ + + +class BlockingAsyncStager(AsyncStager): + """ + An implementation of AsyncStager which stages the state_dict on CPU RAM and blocks until the copy is complete. + This implementation also provides an option to optimize stage latency using pinned memory. + + N.B. synchronize_staging is a no-op in this case. + + + """ + + # default to True since the common case is to stage synchronously + _synchronize_after_execute: bool = False + + def __init__( + self, + cache_staged_state_dict: bool = False, + type_check: bool = False, + ): + """ + Initializes the BlockingAsyncStager. + + Args: + cache_staged_state_dict: Whether to cache the staged state_dict. This option decreases staging latency + at the cost of increases memory usage. Additionally, if this parameter is set to True, it's the expectation + that the stager is maintained and re-used for multiple dcp.async_save calls. Default to False. + type_check: Whether to perform a type check during cpu_offload. Defaults to False. + + """ + self.cache_staged_state_dict = cache_staged_state_dict + self.type_check = type_check + self.state_dict_cache: Optional[STATE_DICT_TYPE] = None + + def stage(self, state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE: + """ + Returns a copy of `state_dict` on the CPU. + """ + + if not self.cache_staged_state_dict: + return _offload_state_dict_to_cpu(state_dict, type_check=self.type_check) + + if self.state_dict_cache is None: + self.state_dict_cache = _create_cpu_state_dict(state_dict, pin_memory=True) + return _copy_state_dict(state_dict, self.state_dict_cache) + + def synchronize_staging(self) -> None: + """ + No-op function, since staging is blocking. + """ diff --git a/mindnlp/core/distributed/checkpoint/state_dict.py b/mindnlp/core/distributed/checkpoint/state_dict.py new file mode 100644 index 000000000..0ca8eebab --- /dev/null +++ b/mindnlp/core/distributed/checkpoint/state_dict.py @@ -0,0 +1,1422 @@ +# mypy: allow-untyped-defs +import contextlib +import functools +import gc +import warnings +from dataclasses import asdict, dataclass, field +from itertools import chain +from typing import ( + Any, + Callable, + cast, + Dict, + Generator, + Iterable, + List, + no_type_check, + Optional, + Set, + Tuple, + Union, +) + +from mindnlp import core +from mindnlp import core.distributed as dist +from mindnlp import core.nn as nn +from core.distributed._shard.sharded_tensor import ShardedTensor +from core.distributed._state_dict_utils import ( + _broadcast_state_dict, + _distribute_state_dict, + _flatten_state_dict, + _gather_state_dict, + _offload_state_dict_to_cpu, + _unflatten_state_dict, +) +from core.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + _CHECKPOINT_PREFIX, +) +from core.distributed.fsdp import ( + FullOptimStateDictConfig, + FullStateDictConfig, + FullyShardedDataParallel as FSDP, + OptimStateDictConfig, + ShardedOptimStateDictConfig, + ShardedStateDictConfig, + StateDictConfig, + StateDictType, +) +from core.distributed.fsdp._common_utils import ( + _get_module_fsdp_state_if_fully_sharded_module, + FSDP_WRAPPED_MODULE, +) +from core.distributed.tensor import DTensor +from core.nn.modules.module import _IncompatibleKeys +from core.nn.parallel import DistributedDataParallel as DDP +from core.utils._pytree import tree_map_only + + +__all__ = [ + "FQNS_T", + "PrimitiveType", + "ValueType", + "DictValueType", + "ListDictValueType", + "OptimizerStateType", + "StateDictOptions", + "get_model_state_dict", + "get_optimizer_state_dict", + "get_state_dict", + "set_model_state_dict", + "set_optimizer_state_dict", + "set_state_dict", +] + + +_FLAT_PARAM = "_flat_param" +_PG = "param_groups" +_PARAMS = "params" +_STATE = "state" + +FQNS_T = Set[str] +PrimitiveType = Union[DTensor, ShardedTensor, core.Tensor, int, float, str] +ValueType = Union[ + PrimitiveType, List[PrimitiveType], Tuple[PrimitiveType], Dict[str, "ValueType"] +] +DictValueType = Dict[str, ValueType] +ListDictValueType = List[DictValueType] +OptimizerStateType = Dict[str, Union[DictValueType, ListDictValueType]] + + +_patched_state_dict: Set[Callable] = set() + + +@contextlib.contextmanager +def _gc_context(): + is_enabled = gc.isenabled() + gc.disable() + try: + yield + finally: + if is_enabled: + gc.enable() + + +@dataclass +class StateDictOptions: + """ + This dataclass specifies how get_state_dict/set_state_dict will work. + + - ``full_state_dict``: if this is set to True, all the tensors in the + returned state_dict will be gathered. No ShardedTensor and DTensor + will be in the returned state_dict. + + - ``cpu_offload``: offload all the tensors to cpu. To prevent CPU OOM, if + ``full_state_dict`` is also true, then only the rank0 will get the + state_dict and all other ranks will get empty state_dict. + + - ``ignore_frozen_params``: if the value is True, the returned state_dict + won't contain any frozen parameters -- the ``requires_grad`` is False. + The default value is False. + + - ``keep_submodule_prefixes`` (deprecated): when ``submodules`` is not None, this option + indicates whether to keep the submodule prefixes from the state_dict keys. + or example, if the submodule is ``module.pretrain`` and the full FQN of + the parameter is ``pretrain.layer1.weight`` of the param. When this option + is True, the parameter's key in the returned state_dict will be + ``pretrain.layer1.weight``. If the options is False, the key will be + ``layer1.weight``. + Note that if ``keep_submodule_prefixes`` is False, there may be conflicted + FQNs, hence there should be only one submodule in ``submodules``. + + - ``strict``: the ``strict`` option when ``set_state_dict`` calls + model.load_state_dict(). + + - ``broadcast_from_rank0``: when the option is True, rank0 should receive a + full state_dict and will broadcast the tensors in the state_dict/ + optim_state_dict one by one to other ranks. Other ranks will receive + the tensors and shard according to the local shards in the model and + optimizer. ``full_state_dict`` must be set to True when using this option. + This option currently only supports DTensor, not the legacy ShardedTensor. + """ + + full_state_dict: bool = False + cpu_offload: bool = False + ignore_frozen_params: bool = False + keep_submodule_prefixes: bool = True + strict: bool = True + broadcast_from_rank0: bool = False + flatten_optimizer_state_dict: bool = False + + +@dataclass +class _StateDictInfo(StateDictOptions): + fqn_param_mapping: Dict[ + Union[str, core.Tensor], Union[FQNS_T, core.Tensor] + ] = field(default_factory=dict) + shared_params_mapping: Dict[ + Union[str, core.Tensor], Union[FQNS_T, core.Tensor] + ] = field(default_factory=dict) + submodule_prefixes: Set[str] = field(default_factory=set) + handle_model: bool = True + handle_optim: bool = True + fsdp_context: Callable = contextlib.nullcontext + fsdp_modules: List[nn.Module] = field(default_factory=list) + + +@functools.lru_cache(maxsize=None) +def _get_fqns( + model: nn.Module, + name: str, + skip_ddp_prefix: bool = True, + skip_compiler_prefix: bool = True, +) -> FQNS_T: + """ + This API is used to convert the name of a parameter to the FQNs. For FSDP + without `use_orig_params`, the name of FlatParameter can be mapped to + multiple original parameters. As a result, the return type of this function + is `Set[str]`. + + Args: + module (nn.Module): the root model. + name (str): the name + skip_ddp_prefix (bool): whether to skip DDP's `module` prefix + + Returns: + The canonical FQNs based on the model traversal. + """ + + # Remove the checkpoint prefix, if it exists. + name = name.replace(_CHECKPOINT_PREFIX, "") + if "." not in name: + return {name} + + obj_names = name.split(".") + fqn_obj_names = [] + curr_obj = model + for i, curr_obj_name in enumerate(obj_names): + if isinstance(curr_obj, DDP): + assert curr_obj_name == "module" + curr_obj = curr_obj.module + if not skip_ddp_prefix: + fqn_obj_names.append(curr_obj_name) + elif isinstance(curr_obj, FSDP): + if i < len(obj_names) - 1 and obj_names[i + 1] == _FLAT_PARAM: + prefix = ".".join(fqn_obj_names) + flat_param = getattr(curr_obj, _FLAT_PARAM) + if prefix: + prefix = f"{prefix}." + return {f"{prefix}{fqn}" for fqn in flat_param._fqns} + curr_obj = getattr(curr_obj, FSDP_WRAPPED_MODULE) + if curr_obj_name != FSDP_WRAPPED_MODULE: + fqn_obj_names.append(curr_obj_name) + curr_obj = getattr(curr_obj, curr_obj_name) + elif isinstance(curr_obj, core._dynamo.eval_frame.OptimizedModule): + assert curr_obj_name == "_orig_mod" + curr_obj = curr_obj._orig_mod + if not skip_compiler_prefix: + fqn_obj_names.append(curr_obj_name) + else: + fqn_obj_names.append(curr_obj_name) + if curr_obj_name == nn.modules.module._EXTRA_STATE_KEY_SUFFIX: + if i != len(obj_names) - 1: + raise RuntimeError("Expect `_extra_state` to be the last obj name") + else: + curr_obj = getattr(curr_obj, curr_obj_name) + + return {".".join(fqn_obj_names).replace(_CHECKPOINT_PREFIX, "")} + + +class _EXTRA_STATE: + pass + + +def _iterate_valid_model_state(model): + visited_modules: Set[nn.Module] = set() + + def recurse(module: nn.Module, curr_fqn: str) -> Generator: + visited_modules.add(module) + + curr_fqn = f"{curr_fqn}." if curr_fqn else "" + for name, submodule in module.named_children(): + if submodule in visited_modules: + continue + new_fqn = f"{curr_fqn}{name}" + yield from recurse(submodule, new_fqn) + + for name, obj in chain( + module.named_buffers(recurse=False), module.named_parameters(recurse=False) + ): + if name in module._non_persistent_buffers_set: + continue + new_fqn = f"{curr_fqn}{name}" + yield new_fqn, obj + + if ( + getattr(module.__class__, "get_extra_state", nn.Module.get_extra_state) + != nn.Module.get_extra_state + ): + new_fqn = f"{curr_fqn}{nn.modules.module._EXTRA_STATE_KEY_SUFFIX}" + yield new_fqn, _EXTRA_STATE() + + yield from recurse(model, "") + + +def _verify_options( + model: nn.Module, + optims: Tuple[core.optim.Optimizer, ...], + optim_only: bool, + *, + submodules: Optional[Set[nn.Module]] = None, + options: Optional[StateDictOptions] = None, +) -> _StateDictInfo: + """ + Verify the model and options passed by the user and generates _StateDictInfo. + """ + if submodules: + warnings.warn( + "Getting submodules only model/optim state_dict is deprecated and " + "will be removed in 2.5. This feature can be achieved by manually " + "filtering out the state_dict returned from get_state_dict.", + FutureWarning, + ) + if optim_only and not optims: + raise RuntimeError( + "Optimizers are not passed in but optim_only is set to True." + ) + + options = options or StateDictOptions() + + fqn_param_mapping: Dict[ + Union[str, core.Tensor], Union[Set[str], core.Tensor] + ] = {} + shared_params_mapping: Dict[ + Union[str, core.Tensor], Union[Set[str], core.Tensor] + ] = {} + for name, param in _iterate_valid_model_state(model): + if isinstance(param, _EXTRA_STATE): + continue + + fqns = _get_fqns(model, name) + fqn = fqn_param_mapping.get(param, None) + if fqn is not None: + cast(Set[str], fqn_param_mapping[param]).update(fqns) + shared_params_mapping[param] = fqn_param_mapping[param] + else: + # We need to do copy as _get_fqns is lru_cached + fqn_param_mapping[param] = fqns.copy() + for fqn in fqns: + if not isinstance(param, _EXTRA_STATE): + fqn_param_mapping[fqn] = param + + for param_, fqns_ in list(shared_params_mapping.items()): + for fqn in fqns_: + shared_params_mapping[fqn] = cast(core.Tensor, param_) + + submodule_prefixes: Set[str] = set() + if submodules: + submodules = set(submodules) + for name, module in model.named_modules(): + if module not in submodules: + continue + fqns = _get_fqns(model, name) + assert len(fqns) == 1, "Submodule FQN should only have 1 instance" + submodule_prefixes.update(f"{fqn}." for fqn in fqns) + + if options.broadcast_from_rank0 and not options.full_state_dict: + raise ValueError( + "full_state_dict must be True when broadcast_from_rank0 is True." + ) + fsdp_modules = FSDP.fsdp_modules(model) + state_dict_config: StateDictConfig + optim_state_dict_config: OptimStateDictConfig + fsdp_context: Callable + if fsdp_modules: + # FSDP API only work if at least one FSDP instance exists. + if options.full_state_dict: + state_dict_config = FullStateDictConfig( + offload_to_cpu=options.cpu_offload, rank0_only=options.cpu_offload + ) + optim_state_dict_config = FullOptimStateDictConfig( + offload_to_cpu=options.cpu_offload, + rank0_only=(options.cpu_offload or options.broadcast_from_rank0), + ) + state_dict_type = StateDictType.FULL_STATE_DICT + else: + state_dict_config = ShardedStateDictConfig( + offload_to_cpu=options.cpu_offload, + ) + optim_state_dict_config = ShardedOptimStateDictConfig( + offload_to_cpu=options.cpu_offload, + ) + state_dict_type = StateDictType.SHARDED_STATE_DICT + + @contextlib.contextmanager + def fsdp_state_dict_type_without_warning( + module, + state_dict_type, + state_dict_config, + optim_state_dict_config, + ): + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="FSDP.state_dict_type", category=FutureWarning + ) + with FSDP.state_dict_type( + module=module, + state_dict_type=state_dict_type, + state_dict_config=state_dict_config, + optim_state_dict_config=optim_state_dict_config, + ): + yield + + fsdp_context = functools.partial( + fsdp_state_dict_type_without_warning, + module=model, + state_dict_type=state_dict_type, + state_dict_config=state_dict_config, + optim_state_dict_config=optim_state_dict_config, + ) + else: + fsdp_context = contextlib.nullcontext + + return _StateDictInfo( + **asdict(options), + fqn_param_mapping=fqn_param_mapping, + shared_params_mapping=shared_params_mapping, + submodule_prefixes=submodule_prefixes, + fsdp_context=fsdp_context, + fsdp_modules=cast(List[nn.Module], fsdp_modules), + handle_model=not optim_only, + handle_optim=(len(optims) > 0), + ) + + +def _verify_state_dict( + model_state_dict: Dict[str, ValueType], + optim_state_dict: OptimizerStateType, + info: _StateDictInfo, +) -> None: + for module in info.fsdp_modules: + fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module) + assert fsdp_state is not None, "Expected a fsdp_state with a fsdp module." + + # Verify if the model_state_dict and optim_state_dict are valid. This API + # should give the users an explicit error message to debug or report. + if ( + info.handle_model + and not model_state_dict + and not info.submodule_prefixes + and not info.ignore_frozen_params + and not (info.cpu_offload and info.full_state_dict) + and info.strict + and not info.broadcast_from_rank0 + ): + raise RuntimeError( + "The option indicates that model state_dict is required to save " + "or load, but model state_dict is empty." + f"rank = {dist.get_rank()=}." + ) + + if info.handle_optim: + if ( + not optim_state_dict + and not (info.cpu_offload and info.full_state_dict) + and (not info.broadcast_from_rank0) + ): + raise RuntimeError( + "The option indicates that model state_dict is required to save, " + f"or load but optim state_dict is empty. {optim_state_dict}" + ) + + for key in model_state_dict.keys(): + if _FLAT_PARAM in key: + raise RuntimeError( + f"{key} contains {_FLAT_PARAM}. This can happen if the model " + "is not the root module." + ) + + +def _state_dict_fn(obj: Union[nn.Module, core.optim.Optimizer], api: str) -> Callable: + call = getattr(obj, api) + if call in _patched_state_dict: + call = functools.partial(getattr(obj.__class__, api), self=obj) + return call + + +def _maybe_full_or_cpu_state_dict( + state_dict: Dict[str, Any], info: _StateDictInfo +) -> Dict[str, Any]: + if info.full_state_dict: + ranks_only = ( + () + if (not info.cpu_offload or not core.distributed.is_initialized()) + else (0,) + ) + return _gather_state_dict( + state_dict, cpu_offload=info.cpu_offload, ranks_only=ranks_only + ) + elif info.cpu_offload: + return _offload_state_dict_to_cpu(state_dict) + else: + return state_dict + + +@core.no_grad() +def _get_model_state_dict( + model: nn.Module, info: _StateDictInfo +) -> Dict[str, ValueType]: + if not info.handle_model: + return {} + + with info.fsdp_context(): + state_dict = _state_dict_fn(model, "state_dict")() + + for key in list(state_dict.keys()): + fqns = _get_fqns(model, key) + assert len(fqns) == 1, (key, fqns) + fqn = next(iter(fqns)) + if fqn != key: + # As we only support FSDP, DDP, and TP, the only cases are + # wrapper-based DDP and compiler. Verify if the assumption + # is correct. + def verify(key, fqn) -> bool: + if len(fqn) >= len(key): + return False + fqn_split = fqn.split(".") + key_split = key.split(".") + fqn_idx = 0 + for key_idx, key_name in enumerate(key_split): + if key_name == fqn_split[fqn_idx]: + fqn_idx += 1 + if fqn_idx == len(fqn_split): + return key_idx == len(key_split) - 1 + elif key_name in ("module", "_orig_mod"): + continue + else: + return False + return True + + if not verify(key, fqn): + raise RuntimeError(f"An unexpected key, {key}, exists. FQN is {fqn}") + state_dict[fqn] = state_dict.pop(key) + + if info.submodule_prefixes: + new_state_dict: Dict[str, ValueType] = {} + # TODO: make this faster. + for fqn in state_dict.keys(): + for prefix in info.submodule_prefixes: + if not fqn.startswith(prefix): + continue + if info.keep_submodule_prefixes: + new_state_dict[fqn] = state_dict[fqn] + else: + new_fqn = fqn[len(prefix) :] + new_state_dict[new_fqn] = state_dict[fqn] + state_dict = new_state_dict + + if info.ignore_frozen_params: + for key, param in model.named_parameters(): + if param.requires_grad: + continue + fqns = _get_fqns(model, key) + for fqn in fqns: + state_dict.pop(fqn) + + for key, p in list(state_dict.items()): + if core.is_tensor(p) and p.is_meta: + state_dict.pop(key) + + return _maybe_full_or_cpu_state_dict(state_dict, info) + + +@core.no_grad() +def _load_model_state_dict( + model: nn.Module, + state_dict: Dict[str, ValueType], + info: _StateDictInfo, +) -> _IncompatibleKeys: + if not info.handle_model or (not state_dict and not info.broadcast_from_rank0): + return _IncompatibleKeys({}, {}) + + local_state_dict = {} + for key, value in _iterate_valid_model_state(model): + fqns = _get_fqns(model, key) + fqns_with_prefix = _get_fqns( + model, key, skip_ddp_prefix=False, skip_compiler_prefix=False + ) + + for fqn, fqn_with_prefix in zip(fqns, fqns_with_prefix): + if ( + not info.broadcast_from_rank0 or dist.get_rank() == 0 + ) and fqn != fqn_with_prefix: + state_dict[fqn_with_prefix] = state_dict.pop(fqn) + local_state_dict[fqn_with_prefix] = value + + assign = False + if info.broadcast_from_rank0 or info.full_state_dict: + device = None + for key, value in local_state_dict.items(): + if core.is_tensor(value) and value.dim() > 0: + if device is None: + device = value.device + else: + assert device == value.device + assert device is not None + if device == core.device("meta"): + device = dist.distributed_c10d._get_pg_default_device() + assign = True + if info.broadcast_from_rank0: + _broadcast_state_dict( + state_dict, local_state_dict, device=device, strict=info.strict + ) + elif info.full_state_dict: + _distribute_state_dict(state_dict, local_state_dict, device=device) + for fqn, local_state in local_state_dict.items(): + state_dict[fqn] = local_state + + with info.fsdp_context(): + return cast( + _IncompatibleKeys, + _state_dict_fn(model, "load_state_dict")( + state_dict=state_dict, strict=info.strict, assign=assign + ), + ) + + +def _init_optim_state(optim: core.optim.Optimizer) -> None: + """ + Initialize optim states by calling the step() with zero grads. + """ + if optim.state: + # The optimizer state is initialized. + return + + # There are some stateless optimizers like SGD. These optimizer will + # not return in the above condition. So if gradients exist, we should also + # return. If gradients do not exist, the following initialization should + # not disturb SGD because the gradients and lr are both zero. + for param_group in optim.param_groups: + for param in param_group[_PARAMS]: + if param.grad is not None: + return + + for param_group in optim.param_groups: + for param in param_group[_PARAMS]: + if param.requires_grad: + param.grad = core.zeros_like(param) + + # Some optimizers will update parameters regardless of grads due to lr, so + # make lr to zero when calling `step()`. + lrs = [] + for param_group in optim.param_groups: + if "lr" in param_group: + lrs.append(param_group["lr"]) + param_group["lr"] = 0.0 + optim.step(closure=None) + # Whether to recover the "lr" should not matter too much as we will + # restore checkpointing later. + for param_group in optim.param_groups: + if "lr" in param_group: + param_group["lr"] = lrs.pop(0) + optim.zero_grad(set_to_none=True) + + +def _flatten_optim_state_dict(state_dict: OptimizerStateType) -> Dict[str, ValueType]: + """ + This API flattens the optimizer state_dict to support optimizer resharding for + MPMD, e.g., pipeline parallelism. + + Without the API, the original optimizer state_dict looks like: + { + "state": { + "layer1.weight": { + "step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor + }, + "layer2.weight": { + "step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor + }, + }, + "param_group": [ + { + "lr": 0.0, + "betas": (0.9, 0.95), ..., + "params": ["layer1.weight", "layer2.weight"] + } + ] + } + + With this API, the optimizer state_dict looks like: + { + "state.layer1.weight.step": 10, + "state.layer2.weight.step": 10, + "state.layer1.weight.exp_avg": SomeTensor, + "state.layer2.weight.exp_avg": SomeTensor, + "state.layer1.weight.exp_avg_sq": SomeTensor, + "state.layer2.weight.exp_avg_sq": SomeTensor, + "param_group.layer1.weight.lr" : 0.1, + "param_group.layer2.weight.lr" : 0.1, + "param_group.layer1.weight.betas" : (0.9, 0.95), + "param_group.layer2.weight.betas" : (0.9, 0.95), + } + + Note that if any of the value is a container, like the betas in the example, + this API won't flattent it. + """ + + def _raise_if_type_not_supported(v): + if not isinstance(v, (core.Tensor, int, float)): + raise NotImplementedError( + "Flattening optimizer state_dict only supports " + "tensor, int, float states now. " + f"Type is {type(v)}." + ) + + ret: Dict[str, ValueType] = {} + for fqn, state in cast(DictValueType, state_dict[_STATE]).items(): + for k, v in cast(DictValueType, state).items(): + _raise_if_type_not_supported(v) + ret[f"{_STATE}.{fqn}.{k}"] = v + + for param_group in cast(ListDictValueType, state_dict[_PG]): + fqns = param_group.pop(_PARAMS) + for fqn in cast(List[str], fqns): + for k, v in param_group.items(): + ret[f"{_PG}.{fqn}.{k}"] = v + return ret + + +def _unflatten_optim_state_dict( + optim: core.optim.Optimizer, + state_dict: Dict[str, ValueType], + info: _StateDictInfo, +) -> OptimizerStateType: + """ + This API unflattens the state_dict generated by _flatten_optim_state_dict(). + See the docstring of _flatten_optim_state_dict() for more detail. + """ + state: DictValueType = {} + pg_state: ListDictValueType = [] + return_osd: OptimizerStateType = {_STATE: state, _PG: pg_state} + + for param_group in optim.param_groups: + pg_state.append({_PARAMS: []}) + for param in param_group[_PARAMS]: + for fqn in info.fqn_param_mapping[param]: + params = pg_state[-1][_PARAMS] + assert isinstance(params, list) # typing + params.append(fqn) + if not param.requires_grad: + continue + state[fqn] = {} + for state_name in optim.state[param].keys(): + cast(DictValueType, state[fqn])[state_name] = state_dict[ + f"{_STATE}.{fqn}.{state_name}" + ] + + first_param_fqn = cast(List[str], pg_state[-1][_PARAMS])[0] + for k in param_group.keys(): + if k == _PARAMS: + continue + value = state_dict[f"{_PG}.{first_param_fqn}.{k}"] + if k not in pg_state[-1]: + pg_state[-1][k] = value + elif pg_state[-1][k] != value: + raise RuntimeError( + "All the parameters in the same parameter group should have " + f"the same saved param_group value. But {first_param_fqn}.{k} " + f"is {value} while other(s) is {pg_state[-1][k]}." + ) + + return return_osd + + +@core.no_grad() +def _get_optim_state_dict( + model: nn.Module, + optimizers: Tuple[core.optim.Optimizer, ...], + info: _StateDictInfo, +) -> OptimizerStateType: + if not info.handle_optim: + return {} + + optim_state_dict: OptimizerStateType = {_STATE: {}, _PG: []} + for optim in optimizers: + _init_optim_state(optim) + osd = _state_dict_fn(optim, "state_dict")() + if info.fsdp_modules: + with info.fsdp_context(): + osd = FSDP.optim_state_dict(model, optim, osd) + + # We need to specially handle FlatParameter FSDP as + # FlatParameter FSDP converts the FQNs. + # There are no easy ways to do this conversion systematically. + # We can only use a string replacment without correctness check. + if not osd: + continue + for k in list(osd[_STATE].keys()): + if "_orig_mod" in k: + osd[_STATE][k.replace("_orig_mod.", "")] = osd[_STATE].pop(k) + for g in osd[_PG]: + params = [k.replace("_orig_mod.", "") for k in g[_PARAMS]] + g[_PARAMS] = params + else: + params = list(chain.from_iterable(g[_PARAMS] for g in optim.param_groups)) + param_pid_mapping = dict(zip(params, range(len(params)))) + fqn_pid_mapping = {} + for key, param in model.named_parameters(): + fqns = _get_fqns(model, key) + assert len(fqns) == 1 + fqn = next(iter(fqns)) + if param not in param_pid_mapping: + continue + pid = param_pid_mapping[param] + fqn_pid_mapping[fqn] = pid + fqn_pid_mapping[pid] = fqn + + for key in list(osd[_STATE].keys()): + fqn = fqn_pid_mapping[key] + osd[_STATE][fqn] = osd[_STATE].pop(key) + + for group in osd[_PG]: + group[_PARAMS] = [fqn_pid_mapping[pid] for pid in group[_PARAMS]] + + if not osd: + continue + + cast(DictValueType, optim_state_dict[_STATE]).update(osd[_STATE]) + cast(ListDictValueType, optim_state_dict[_PG]).extend(osd[_PG]) + + if info.flatten_optimizer_state_dict: + optim_state_dict = cast( + OptimizerStateType, _flatten_optim_state_dict(optim_state_dict) + ) + + return _maybe_full_or_cpu_state_dict(optim_state_dict, info) + + +def _split_optim_state_dict( + model: nn.Module, + optim: core.optim.Optimizer, + optim_state_dict: OptimizerStateType, + info: _StateDictInfo, +) -> OptimizerStateType: + """ + Extract the corresponding optim state_dict from ``optim_state_dict`` for + ``optim`` and return the result optim state_dict. + + Args: + model (nn.Module): the root model. + optim (core.optim.Optimizer): the optimizer. + optim_state_dict (Dict[str, ValueType]): the superset optim state_dict that + contains the optim state_dict of ``optim``. + info (_StateDictInfo): state dict information. + + Returns: + The optim state_dict of ``optim``. + """ + + state: DictValueType = {} + pg_state: ListDictValueType = [] + return_osd: OptimizerStateType = {_STATE: state, _PG: pg_state} + pg_mapping: Dict[int, int] = {} + + if all( + isinstance(k, int) for k in cast(DictValueType, optim_state_dict[_STATE]).keys() + ): + return optim_state_dict + + for param_group in optim.param_groups: + pg_state.append({_PARAMS: []}) + for param in param_group[_PARAMS]: + for fqn in info.fqn_param_mapping[param]: + if fqn in info.shared_params_mapping: + in_params = False + for loaded_param_group in cast( + ListDictValueType, optim_state_dict[_PG] + ): + if fqn in cast(List[str], loaded_param_group[_PARAMS]): + in_params = True + break + else: + in_params = True + if not in_params: + continue + + params = pg_state[-1][_PARAMS] + assert isinstance(params, list) + params.append(fqn) + if param.requires_grad: + state[fqn] = cast(DictValueType, optim_state_dict[_STATE])[fqn] + for loaded_param_group in cast( + ListDictValueType, optim_state_dict[_PG] + ): + if fqn in cast(List[str], loaded_param_group[_PARAMS]): + pg_mapping[id(loaded_param_group)] = len(return_osd[_PG]) - 1 + + for param_group in cast(ListDictValueType, optim_state_dict[_PG]): + idx = pg_mapping.get(id(param_group), -1) + if idx == -1: + continue + for key, value in param_group.items(): + if key == _PARAMS: + continue + # TODO: check if value is the same if exists. + pg_state[idx][key] = value + + return return_osd + + +@core.no_grad() +def _load_optim_state_dict( + model: nn.Module, + optimizers: Tuple[core.optim.Optimizer, ...], + state_dict: OptimizerStateType, + info: _StateDictInfo, +) -> None: + if not info.handle_optim: + return + + for optim in optimizers: + _init_optim_state(optim) + if state_dict: + if _STATE in state_dict: + optim_state_dict = _split_optim_state_dict( + model, optim, state_dict, info + ) + else: + optim_state_dict = _unflatten_optim_state_dict( + optim, cast(Dict[str, ValueType], state_dict), info + ) + else: + optim_state_dict = {} + if info.fsdp_modules: + # We need to specially handle FlatParameter FSDP as + # FlatParameter FSDP converts the FQNs. + for original_fqn, _ in model.named_parameters(): + fqns = _get_fqns(model, original_fqn) + fqns_with_compiler = _get_fqns( + model, original_fqn, skip_compiler_prefix=False + ) + if fqns == fqns_with_compiler: + continue + + assert len(fqns) == 1 + fqn = fqns.pop() + fqn_with_compiler = fqns_with_compiler.pop() + for g in optim_state_dict[_PG]: + val = cast(Dict[str, Any], g) + params = [ + key.replace(fqn, fqn_with_compiler) for key in val[_PARAMS] + ] + val[_PARAMS] = params + osd_state = cast(DictValueType, optim_state_dict[_STATE]) + for k in list(osd_state.keys()): + if fqn in k: + osd_state[k.replace(fqn, fqn_with_compiler)] = osd_state.pop(k) + + with info.fsdp_context(): + optim_state_dict = FSDP.optim_state_dict_to_load( + model, optim, optim_state_dict + ) + elif info.full_state_dict: + info.full_state_dict = False + local_state_dict = _get_optim_state_dict(model, (optim,), info) + info.full_state_dict = True + device = None + + def _device(t): + if t.dim() > 0: + nonlocal device + if device is None: + device = t.device + elif device != t.device: + raise ValueError("Device mismatch") + return t + + _ = tree_map_only(core.Tensor, _device, local_state_dict) + assert device is not None + flatten_osd, osd_mapping = _flatten_state_dict(optim_state_dict) + flatten_local_osd, local_osd_mapping = _flatten_state_dict(local_state_dict) + if info.broadcast_from_rank0: + _broadcast_state_dict(flatten_osd, flatten_local_osd, device=device) + else: + _distribute_state_dict(flatten_osd, flatten_local_osd, device=device) + # The modifications listed seek to address the problem where optim might possess + # dissimilar parameters in comparison to optim_state_dict. This is achieved by + # incorporating differential parameters within local, which may result in optim + # having additional parameters ultimately. + for optim_key in flatten_osd.keys(): + if optim_key not in flatten_local_osd: + assert optim_key in osd_mapping + flatten_local_osd[optim_key] = flatten_osd[optim_key] + local_osd_mapping[optim_key] = osd_mapping[optim_key] + optim_state_dict = _unflatten_state_dict( + flatten_local_osd, local_osd_mapping + ) + + # Note that we do not have to convert the FQN back to param id here if + # order in optim.param_groups[idx][_PARAMS] is the same as the one in + # optim_state_dict[_PG][idx][_PARAMS]. + _state_dict_fn(optim, "load_state_dict")(state_dict=optim_state_dict) + + +def get_model_state_dict( + model: nn.Module, + *, + submodules: Optional[Set[nn.Module]] = None, + options: Optional[StateDictOptions] = None, +) -> Dict[str, ValueType]: + """ + Return the model state_dict of ``model``. + + See ``get_state_dict`` for the detail usage. + + Args: + model (nn.Module): the nn.Module to the model. + submodules (deprecated): Optional[Set[nn.Module]]: only return the model parameters + that belong to the submodules. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be returned. See + `StateDictOptions` for the details. + + Returns: + The state_dict for ``model``. + + :rtype: typing.Dict[str, ValueType] + """ + with _gc_context(): + info = _verify_options( + model, + (), + optim_only=False, + submodules=submodules, + options=options, + ) + model_state_dict = _get_model_state_dict(model, info) + _verify_state_dict(model_state_dict, {}, info) + return model_state_dict + + +def get_optimizer_state_dict( + model: nn.Module, + optimizers: Union[core.optim.Optimizer, Iterable[core.optim.Optimizer]], + *, + submodules: Optional[Set[nn.Module]] = None, + options: Optional[StateDictOptions] = None, +) -> OptimizerStateType: + """ + Return the combined state_dict for optimizers. + + See ``get_state_dict`` for the detail usage. + + Args: + model (nn.Module): the nn.Module to the model. + optimizers (Union[None, Optimizer, Iterable[Optimizer]]): + The optimizers that are used to optimize ``model``. + submodules (deprecated): Optional[Set[nn.Module]]: only return the model parameters + that belong to the submodules. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be returned. See + `StateDictOptions` for the details. + + Returns: + The state_dict for ``optimizers``. + + :rtype: OptimizerStateType + """ + with _gc_context(): + optimizers = ( + (optimizers,) + if isinstance(optimizers, core.optim.Optimizer) + else tuple(optimizers) + ) + info = _verify_options( + model, + optimizers, + optim_only=True, + submodules=submodules, + options=options, + ) + optim_state_dict = _get_optim_state_dict(model, optimizers, info) + _verify_state_dict({}, optim_state_dict, info) + return optim_state_dict + + +def get_state_dict( + model: nn.Module, + optimizers: Union[core.optim.Optimizer, Iterable[core.optim.Optimizer]], + *, + submodules: Optional[Set[nn.Module]] = None, + options: Optional[StateDictOptions] = None, +) -> Tuple[Dict[str, ValueType], OptimizerStateType]: + """ + Return the model state_dict and optimizers state_dict. + + ``get_state_dict`` can process any module that is parallelized by PyTorch + FSDP/fully_shard, DDP/replicate, tensor_parallel/parallelize_module, and any + combination of these parallelisms. The main functions of ``get_state_dict`` + are: 1.) returning a model and optimizer state_dict that can be resharded + with a different number of trainers and/or different parallelisms. + 2.) hiding the parallelism-specific state_dict APIs. Users don't have to call + these APIs. + 3.) sanity checking the result state_dict. + + The keys of the result state dictionary are the canonical FQNs (Fully + Qualified Names). A canonical FQN refers to the FQN based on a parameter's + position in an nn.Module hierarchy. More specifically, a canonical FQN to a + parameter is the FQN returned by ``module.named_parameters()`` or + ``module.named_buffers()`` when the module is not distributed by any + parallelisms. Since the optimizer internally uses parameter IDs to represent + a parameter, there will be a conversion from the parameter IDs to the + canonical FQNs when calling this API. + + ``get_state_dict`` can also process a module that is not parallelized. In + such a case, ``get_state_dict`` only performs one function -- converting the + optimizer parameter IDs to the canonical FQNs. + + Example: + >>> # xdoctest: +SKIP + >>> from mindnlp import core + >>> from core.distributed.fsdp import FullyShardedDataParallel as FSDP + >>> from core.nn.parallel import DistributedDataParallel as DDP + >>> from core.distributed.checkpoint.state_dict import get_state_dict + + >>> fsdp_model = FSDP(copy.deepcopy(model)) + >>> fsdp_optim = core.optim.Adam(model.parameters(), lr=1e-3) + >>> ddp_model = DDP(copy.deepcopy(model)) + >>> ddp_optim = core.optim.Adam(model.parameters(), lr=1e-3) + + + >>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim) + >>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(fsdp_model, fsdp_optim) + + >>> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(), + >>> # the asserts will fail. + >>> assert ddp_state_dict == fsdp_state_dict + >>> assert ddp_optim_state == fsdp_optim_state_dict + + + Args: + model (nn.Module): the nn.Module to the model. + optimizers (Union[None, Optimizer, Iterable[Optimizer]]): + The optimizers that are used to optimize ``model``. + submodules (deprecated): Optional[Set[nn.Module]]: only return the model parameters + that belong to the submodules. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be returned. See + `StateDictOptions` for the details. + + Returns: + ``Tuple`` that contain model state_dict and optimizer state_dict. + + :rtype: typing.Tuple[typing.Dict[str, ValueType], OptimizerStateType] + """ + + with _gc_context(): + optimizers = ( + (optimizers,) + if isinstance(optimizers, core.optim.Optimizer) + else tuple(optimizers) + ) + info = _verify_options( + model, + optimizers, + optim_only=False, + submodules=submodules, + options=options, + ) + model_state_dict = _get_model_state_dict(model, info) + optim_state_dict = _get_optim_state_dict(model, optimizers, info) + _verify_state_dict(model_state_dict, optim_state_dict, info) + return model_state_dict, optim_state_dict + + +def _unflatten_model_state_dict( + model: nn.Module, + state_dict: Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]], +) -> Dict[str, ValueType]: + if not state_dict: + return {} + + if isinstance(next(iter(state_dict.keys())), nn.Module): + warnings.warn( + "Passing model_state_dict as a ``Dict[nn.Module, Dict[str, Any]]``" + "is deprecated and will be removed in 2.5. If you need this " + "feature, please preprocessing the model_state_dict to achieve the " + "same functionality.", + FutureWarning, + ) + cast_state_dict = cast(Dict[nn.Module, Dict[str, ValueType]], state_dict) + new_state_dict: Dict[str, ValueType] = {} + for submodule, sub_state_dict in cast_state_dict.items(): + for name, m in model.named_modules(): + if m != submodule: + continue + + fqns = _get_fqns(model, name) + assert len(fqns) == 1, "FQNs for a submodule should only have 1 element" + prefix = f"{next(iter(fqns))}." + new_state_dict.update( + {prefix + subfqn: value for subfqn, value in sub_state_dict.items()} + ) + return new_state_dict + else: + return cast(Dict[str, ValueType], state_dict) + + +def set_model_state_dict( + model: nn.Module, + model_state_dict: Dict[str, ValueType], + *, + options: Optional[StateDictOptions] = None, +) -> _IncompatibleKeys: + """Load the model state_dict. + + The counterpart of ``get_model_state_dict`` to set the state_dict to the + model. See ``set_state_dict`` for the detail usage. + + Args: + model (nn.Module): the nn.Module to the model. + model_state_dict: (Dict[str, ValueType]): + the model state_dict to load. If the key of the ``model_state_dict`` + is nn.Module, the key is a submodule of ``model`` and the value should + be the state_dict of the submodule. When loading the state_dict, + the prefix of the submodule will be append to the state_dict. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be loaded. See + `StateDictOptions` for the details. + + Returns: + ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: + * **missing_keys** is a list of str containing the missing keys + * **unexpected_keys** is a list of str containing the unexpected keys + + :type model_state_dict: typing.Dict[str, ValueType] + """ + model_state_dict: Dict[str, ValueType] = _unflatten_model_state_dict( + model, model_state_dict + ) + with _gc_context(): + info = _verify_options(model, (), optim_only=False, options=options) + + _verify_state_dict(model_state_dict, {}, info) + return _load_model_state_dict(model, model_state_dict, info) + + +def set_optimizer_state_dict( + model: nn.Module, + optimizers: Union[core.optim.Optimizer, Iterable[core.optim.Optimizer]], + optim_state_dict: OptimizerStateType, + *, + options: Optional[StateDictOptions] = None, +) -> None: + """Load the optimizers state_dict. + + The counterpart of ``get_optimizer_state_dict`` to set the state_dict to the + optimizers. See ``set_state_dict`` for the detail usage. + + Args: + model (nn.Module): the nn.Module to the model. + optimizers (Union[Optimizer, Iterable[Optimizer]]): + The optimizers that are used to optimize ``model``. + optim_state_dict: OptimizerStateType: + the optimizer state_dict to load. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be loaded. See + `StateDictOptions` for the details. + + Returns: + None + + :type optim_state_dict: typing.OptimizerStateType + """ + with _gc_context(): + optimizers = ( + (optimizers,) + if isinstance(optimizers, core.optim.Optimizer) + else tuple(optimizers) + ) + info = _verify_options(model, optimizers, optim_only=True, options=options) + + _verify_state_dict({}, optim_state_dict, info) + _load_optim_state_dict(model, optimizers, optim_state_dict, info) + + +def set_state_dict( + model: nn.Module, + optimizers: Union[core.optim.Optimizer, Iterable[core.optim.Optimizer]], + *, + model_state_dict: Dict[str, ValueType], + optim_state_dict: OptimizerStateType, + options: Optional[StateDictOptions] = None, +) -> _IncompatibleKeys: + """Load the model state_dict and optimizers state_dict. + + The counterpart of ``get_state_dict`` to set the state_dict to the model and + optimizers. The given ``model_state_dict`` and ``optim_state_dict`` do not + have to be returned by ``get_state_dict`` but must meet the following + requirements: 1) all FQNs are canonical FQNs as defined in ``get_state_dict``, + 2) if a tensor is sharded, it must be either a ShardedTensor or DTensor, + 3) optimizer state_dict cannot contain the parameter IDs; the keys should be + the canonical FQNs. + + Args: + model (nn.Module): the nn.Module to the model. + optimizers (Union[Optimizer, Iterable[Optimizer]]): + The optimizers that are used to optimize ``model``. + model_state_dict: (Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]]): + the model state_dict to load. If the key of the ``model_state_dict`` + is nn.Module, the key is a submodule of ``model`` and the value should + be the state_dict of the submodule. When loading the state_dict, + the prefix of the submodule will be append to the state_dict. + optim_state_dict: OptimizerStateType: + the optimizer state_dict to load. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be loaded. See + `StateDictOptions` for the details. + + Returns: + ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: + * **missing_keys** is a list of str containing the missing keys of the model state_dict. + * **unexpected_keys** is a list of str containing the unexpected keys of the model state_dict. + + :type model_state_dict: typing.Dict[str, ValueType] + :type optim_state_dict: typing.OptimizerStateType + """ + + model_state_dict: Dict[str, ValueType] = _unflatten_model_state_dict( + model, model_state_dict + ) + with _gc_context(): + optimizers = ( + (optimizers,) + if isinstance(optimizers, core.optim.Optimizer) + else tuple(optimizers) + ) + info = _verify_options( + model, optimizers, optim_only=not model_state_dict, options=options + ) + + _verify_state_dict(model_state_dict, optim_state_dict, info) + _load_optim_state_dict(model, optimizers, optim_state_dict, info) + return _load_model_state_dict(model, model_state_dict, info) + + +# TODO: correct the state_dict function signature. +# TODO: this API is not yet fully tested. Make it private +@no_type_check +def _patch_model_state_dict( + model: nn.Module, + *, + options: Optional[StateDictOptions] = None, +) -> None: + """Patch the ``state_dict`` and ``load_state_dict`` attributes of ``model``. + + Patch the ``state_dict`` and ``load_state_dict`` attributes of ``model`` to + be a partial function to call ``get_state_dict`` and ``set_state_dict``. + + Example: + from core.distributed.fsdp import FullyShardedDataParallel as FSDP + from core.distributed.checkpoint.state_dict import patch_model_state_dict + + model = fsdp(model) + patch_model_state_dict(model) + + Args: + model (nn.Module): the nn.Module to the model. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be loaded. See + `StateDictOptions` for the details. + Returns: + None + """ + + _state_dict_call = functools.partial( + get_model_state_dict, + model=model, + options=options, + ) + + def state_dict_call(): + return _state_dict_call() + + model.state_dict = state_dict_call + + _load_state_dict_call = functools.partial( + set_model_state_dict, + model=model, + options=options, + ) + + def load_state_dict_call(state_dict: Dict[str, Any]): + _load_state_dict_call(model_state_dict=state_dict) + + model.load_state_dict = load_state_dict_call + + _patched_state_dict.add(state_dict_call) + _patched_state_dict.add(load_state_dict_call) + + +# TODO: correct the load_state_dict function signature. +# TODO: this API is not yet fully tested. Make it private +@no_type_check +def _patch_optimizer_state_dict( + model: nn.Module, + *, + optimizers: Tuple[core.optim.Optimizer, ...], + options: Optional[StateDictOptions] = None, +) -> None: + """Patch the ``state_dict`` and ``load_state_dict`` attributes of ``optimizers``. + + Patch the ``state_dict`` and ``load_state_dict`` attributes of ``optimizers`` to + be a partial function to call ``get_state_dict`` and ``set_state_dict``. + + Note that if there are multiple optimizers, all of the optimizers will be patched. + So users only need to call one of the state_dict() to get the full result. + + Example: + from core.distributed.fsdp import FullyShardedDataParallel as FSDP + from core.distributed.checkpoint.state_dict import patch_model_state_dict + + model = fsdp(model) + patch_model_state_dict(model) + + Args: + model (nn.Module): the nn.Module to the model. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be loaded. See + `StateDictOptions` for the details. + Returns: + None + """ + + _state_dict_call = functools.partial( + get_optimizer_state_dict, + model=model, + optimizers=optimizers, + options=options, + ) + + def state_dict_call(): + return _state_dict_call() + + _load_state_dict_call = functools.partial( + set_optimizer_state_dict, + model=model, + optimizers=optimizers, + options=options, + ) + + def load_state_dict_call(state_dict: Dict[str, Any]): + _load_state_dict_call(optim_state_dict=state_dict) + + _patched_state_dict.add(state_dict_call) + _patched_state_dict.add(load_state_dict_call) + optimizers = ( + (optimizers,) + if isinstance(optimizers, core.optim.Optimizer) + else tuple(optimizers) + ) + for optim in optimizers: + optim.state_dict = state_dict_call + optim.load_state_dict = load_state_dict_call diff --git a/mindnlp/core/distributed/checkpoint/state_dict_loader.py b/mindnlp/core/distributed/checkpoint/state_dict_loader.py new file mode 100644 index 000000000..ce9372a93 --- /dev/null +++ b/mindnlp/core/distributed/checkpoint/state_dict_loader.py @@ -0,0 +1,323 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import os +import warnings +from typing import Any, cast, Dict, Optional, Set, Union +from typing_extensions import deprecated + +from mindnlp import core +from mindnlp import core.distributed as dist +from core.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner +from core.distributed.checkpoint.logger import _dcp_method_logger +from core.distributed.checkpoint.stateful import Stateful + +from ._storage_utils import _storage_setup +from .default_planner import DefaultLoadPlanner +from .planner import LoadPlan, LoadPlanner +from .storage import StorageReader +from .utils import _all_gather_keys, _api_bc_check, _DistWrapper, _profile + + +__all__ = ["load_state_dict", "load"] + + +@deprecated( + "`load_state_dict` is deprecated and will be removed in future versions. " + "Please use `load` instead.", + category=FutureWarning, +) +def load_state_dict( + state_dict: Dict[str, Any], + storage_reader: StorageReader, + process_group: Optional[dist.ProcessGroup] = None, + coordinator_rank: int = 0, + no_dist: bool = False, + planner: Optional[LoadPlanner] = None, +) -> None: + """This method is deprecated. Please switch to 'load'.""" + storage_reader.reset() + with _profile(): + # TODO: test returning `load` here instead. + return _load_state_dict( + state_dict, + storage_reader, + process_group, + coordinator_rank, + no_dist, + planner, + ) + + +@_dcp_method_logger(log_exceptions=True) +@_api_bc_check +def load( + state_dict: Dict[str, Any], + *, + checkpoint_id: Union[str, os.PathLike, None] = None, + storage_reader: Optional[StorageReader] = None, + planner: Optional[LoadPlanner] = None, + process_group: Optional[dist.ProcessGroup] = None, +) -> None: + """ + Load a distributed ``state_dict`` in SPMD style. + + Each rank will try to read the least amount of data necessary + to fullfill the requested `state_dict`. When loading :class:`ShardedTensor` + or :class:`DTensor` instances, each rank only reads data for their local shards. + + For each ``Stateful`` object (having both a ``state_dict`` and a ``load_state_dict``), + load will first call ``state_dict`` before attempting deserialization, followed by + ``load_state_dict`` once the deserialization is complete. + For each non-``Stateful`` object, load will deserailize the object, and then replace + it in the ``state_dict`` with the deserialized object. + + .. warning:: + All tensors in ``state_dict`` must be allocated on their + destination device *prior to* calling this function. + + All non-tensor data is loaded using `core.load()` and modified in place + on state_dict. + + .. warning:: + Users must call `load_state_dict` on the root module to ensure load + pos-processing and non-tensor data properly propagates. + + .. note: + If no process group is initialized, this function will assume the intent + is to load a checkpoint into the local process. This can be useful in the + case of local inference, and when using regular Tensors (as opposed to DTensor + or ShardedTensor) + + .. note: + Rank 0 is assumed to be the coordinator rank. + + Args: + state_dict (Dict[str, Any]): The state_dict to save. + checkpoint_id (Union[str, os.PathLike, None]): + The ID of this checkpoint instance. The meaning of the checkpoint_id + depends on the storage. It can be a path to a folder or to a file. + It can also be a key if the storage is a key-value store. + (Default: ``None``) + storage_reader (Optional[StorageReader]): + Instance of StorageWriter used to perform reads. If this is not + specified, DCP will automatically infer the reader based on the + checkpoint_id. If checkpoint_id is also None, an exception will + be raised. (Default: ``None``) + planner (Optional[LoadPlanner]): + Instance of LoadPlanner. If this is not specificed, the default + planner will be used. (Default: ``None``) + process_group (Optional[ProcessGroup]): + ProcessGroup to be used for cross-rank synchronization. + (Default: ``None``) + + Returns: + None. + + Examples + >>> # xdoctest: +SKIP + >>> my_model = MyModule() + >>> optimizer = Adagrad(my_model.parameters()) + >>> model_state_dict = my_model.state_dict() + >>> fs_storage_reader = core.distributed.checkpoint.FileSystemReader("/checkpoint/1") + + >>> core.distributed.checkpoint.load_state_dict( + >>> state_dict=model_state_dict, + >>> storage_reader=fs_storage_reader, + >>> ) + + >>> # module.load_state_dict() function might have customized steps + >>> # to flush the state_dict, must call it to + >>> # ensure correct behavior. + >>> my_model.load_state_dict(model_state_dict) + + .. note:: + load_state_dict uses collectives to coordinate reads across ranks. + For NCCL-based process groups, internal tensor representations of + objects must be moved to the GPU device before communication takes place. + In this case, the device used is given by ``core.cuda.current_device()`` + and it is the user's responsibility to ensure that this is set so that each + rank has an individual GPU, via ``core.cuda.set_device()``. + """ + + no_dist = not (dist.is_available() and dist.is_initialized()) + if no_dist: + warnings.warn( + "core.distributed is unavailable or uninitialized, assuming the intent is to load in a single process." + ) + + with _profile(): + storage_reader = cast( + StorageReader, _storage_setup(storage_reader, checkpoint_id, reader=True) + ) + + if no_dist: + keys = list(state_dict.keys()) + else: + keys = _all_gather_keys(state_dict, process_group) + if keys != sorted(state_dict.keys()): + warnings.warn( + "Detected mismatched keys in state dict after all gather!" + " This behavior is unsupported and may cause errors may cause errors." + ) + + statetful_sd = {} + for key in keys: + if key not in state_dict: + continue + elem = state_dict[key] + statetful_sd[key] = ( + elem.state_dict() if isinstance(elem, Stateful) else elem + ) + + _load_state_dict( + state_dict=statetful_sd, + storage_reader=storage_reader, + process_group=process_group, + no_dist=no_dist, + planner=planner, + ) + for key in keys: + if key not in state_dict: + continue + elem = state_dict[key] + if isinstance(elem, Stateful): + # If the state_dict is a Stateful object, + # DCP does an in-place load in the original state dict. + elem.load_state_dict(statetful_sd[key]) + else: + # Otherwise, replace the state_dict with the loaded state_dict. + state_dict[key] = statetful_sd[key] + + +def _load_state_dict( + state_dict: Dict[str, Any], + storage_reader: StorageReader, + process_group: Optional[dist.ProcessGroup] = None, + coordinator_rank: int = 0, + no_dist: bool = False, + planner: Optional[LoadPlanner] = None, +) -> None: + # core._C._log_api_usage_once("core.distributed.checkpoint.load_state_dict") + + distW = _DistWrapper(process_group, not no_dist, coordinator_rank) + if planner is None: + planner = DefaultLoadPlanner() + + ckpt_kwargs = {} + if (ckpt_id := getattr(storage_reader, "checkpoint_id", None)) is not None: + ckpt_kwargs["checkpoint_id"] = ckpt_id + ckpt_kwargs["process_group"] = distW.group + + @_dcp_method_logger(**ckpt_kwargs) + def local_step(): + assert planner is not None + metadata = storage_reader.read_metadata() + planner.set_up_planner(state_dict, metadata, distW.is_coordinator) + storage_reader.set_up_storage_reader(metadata, distW.is_coordinator) + + local_plan = planner.create_local_plan() + local_plan = storage_reader.prepare_local_plan(local_plan) + return local_plan + + @_dcp_method_logger(**ckpt_kwargs) + def global_step(all_local_plans): + assert planner is not None + all_local_plans = planner.create_global_plan(all_local_plans) + all_local_plans = storage_reader.prepare_global_plan(all_local_plans) + return all_local_plans + + central_plan: LoadPlan = distW.reduce_scatter("plan", local_step, global_step) + @_dcp_method_logger(**ckpt_kwargs) + def read_data(): + assert planner is not None + final_local_plan = planner.finish_plan(central_plan) + # all_reads = storage_reader.read_data(final_local_plan, planner) + storage_reader.read_data(final_local_plan, planner) + + # all_reads.wait() + return None + + _ = distW.all_gather("read", read_data) + + +def _load_state_dict_from_keys( + keys: Optional[Union[Set[str], str]] = None, + *, + checkpoint_id: Union[str, os.PathLike, None] = None, + storage_reader: Optional[StorageReader] = None, + process_group: Optional[dist.ProcessGroup] = None, +) -> Dict[str, Any]: + """ + Load only the specified keys from the checkpoint, if no keys are specified, the entire + checkpoint will be loaded. Note, this method completely loads the checkpoint into the + current process and is not distributed. + + .. warning:: + + + .. warning:: + + All non-tensor data is loaded using `core.load()` + + .. note: + As opposed to the usual pattern, this function does not take a state dict as input + and does not load inplace. Instead, a new state dict is directly initialized and read + from file. + + .. note: + If no process group is initialized, this function will assume the intent + is to load a checkpoint into the local process. This can be useful in the + case of local inference, and when using regular Tensors (as opposed to DTensor + or ShardedTensor) + + .. note: + Rank 0 is assumed to be the coordinator rank. + + Args: + keys (Optional[Union[Set[str], str]]): + Loads any key specified in this set. If no keys are specified, the entire checkpoint + is loaded. + checkpoint_id (Union[str, os.PathLike, None]): + The ID of this checkpoint instance. The meaning of the checkpoint_id + depends on the storage. It can be a path to a folder or to a file. + It can also be a key if the storage is a key-value store. + (Default: ``None``) + storage_reader (Optional[StorageReader]): + Instance of StorageWriter used to perform reads. If this is not + specified, DCP will automatically infer the reader based on the + checkpoint_id. If checkpoint_id is also None, an exception will + be raised. (Default: ``None``) + process_group (Optional[ProcessGroup]): + ProcessGroup to be used for cross-rank synchronization. + (Default: ``None``) + + Returns: + State dict from specified keys + """ + core._C._log_api_usage_once( + "core.distributed.checkpoint._load_state_dict_from_keys" + ) + + no_dist = not (dist.is_available() and dist.is_initialized()) + if no_dist: + warnings.warn( + "core.distributed is unavailable or uninitialized, assuming the intent is to load in a single process." + ) + + storage_reader = cast( + StorageReader, _storage_setup(storage_reader, checkpoint_id, reader=True) + ) + + if isinstance(keys, str): + keys = {keys} + + sd: Dict[str, Any] = {} + _load_state_dict( + state_dict=sd, + storage_reader=storage_reader, + process_group=process_group, + no_dist=no_dist, + planner=_EmptyStateDictLoadPlanner(keys=keys or set()), + ) + + return sd diff --git a/mindnlp/core/distributed/checkpoint/state_dict_saver.py b/mindnlp/core/distributed/checkpoint/state_dict_saver.py new file mode 100644 index 000000000..67dd9d251 --- /dev/null +++ b/mindnlp/core/distributed/checkpoint/state_dict_saver.py @@ -0,0 +1,334 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import inspect +import os +import warnings +from concurrent.futures import Future, ThreadPoolExecutor +from typing import cast, Optional, Union +from typing_extensions import deprecated + +from mindnlp import core +from mindnlp import core.distributed as dist +from core.distributed._state_dict_utils import _offload_state_dict_to_cpu +from core.distributed.checkpoint._storage_utils import _storage_setup +from core.distributed.checkpoint.default_planner import DefaultSavePlanner +from core.distributed.checkpoint.logger import _dcp_method_logger +from core.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE +from core.distributed.checkpoint.planner import SavePlan, SavePlanner +from core.distributed.checkpoint.staging import AsyncStager +from core.distributed.checkpoint.stateful import Stateful +from core.distributed.checkpoint.storage import StorageWriter +from core.distributed.distributed_c10d import _get_default_group + +from .utils import _api_bc_check, _DistWrapper, _profile + + +__all__ = ["save_state_dict", "save", "async_save"] + + +@deprecated( + "`save_state_dict` is deprecated and will be removed in future versions." + "Please use `save` instead.", + category=FutureWarning, +) +def save_state_dict( + state_dict: STATE_DICT_TYPE, + storage_writer: StorageWriter, + process_group: Optional[dist.ProcessGroup] = None, + coordinator_rank: int = 0, + no_dist: bool = False, + planner: Optional[SavePlanner] = None, +) -> Metadata: + """This method is deprecated. Please switch to 'save'.""" + storage_writer.reset() + + # TODO: test returning `save` here instead. + with _profile(): + return _save_state_dict( + state_dict, + storage_writer, + process_group, + coordinator_rank, + no_dist, + planner, + ) + + +@_dcp_method_logger(log_exceptions=True) # type: ignore[arg-type] +@_api_bc_check +def save( + state_dict: STATE_DICT_TYPE, + *, + checkpoint_id: Union[str, os.PathLike, None] = None, + storage_writer: Optional[StorageWriter] = None, + planner: Optional[SavePlanner] = None, + process_group: Optional[dist.ProcessGroup] = None, +) -> Metadata: + """ + Save a distributed model in SPMD style. + + This function is different from ``core.save()`` as it handles + ``ShardedTensor`` , and ``DTensor`` by having each rank only save their local shards. + + For each ``Stateful`` object (having both a ``state_dict`` and a ``load_state_dict``), + save will call ``state_dict`` before serialization. + + .. warning:: + There is no guarantees of Backwards Compatibility across PyTorch versions + for saved state_dicts. + + .. warning:: + If using the `process_group` argument, make sure that only its ranks + call `save_state_dict` and that all data in state_dict belong to it. + + .. note:: + When saving checkpoint for FSDP's `ShardingStrategy.HYBRID_SHARD`, only one of + the shard_group should be calling `save_state_dict` and the corresponding process + group needs to be passed in. + + .. note:: + If no process group is available, this function assumes the intention is to save the + state_dict in the local process. + + .. note: + Rank 0 is assumed to be the coordinator rank. + + + Args: + state_dict (Dict[str, Any]): The state_dict to save. + checkpoint_id (Union[str, os.PathLike, None]): + The ID of this checkpoint instance. The meaning of the checkpoint_id + depends on the storage. It can be a path to a folder or to a file. + It can also be a key if the storage is a key-value store. + (Default: ``None``) + storage_writer (Optional[StorageWriter]): + Instance of StorageWriter used to perform writes. If this is not + specified, DCP will automatically infer the writer based on the + checkpoint_id. If checkpoint_id is also None, an exception will + be raised. (Default: ``None``) + planner (Optional[SavePlanner]): + Instance of SavePlanner. If this is not specificed, the default + planner will be used. (Default: ``None``) + process_group (Optional[ProcessGroup]): + ProcessGroup to be used for cross-rank synchronization. + (Default: ``None``) + + Returns: + Metadata: Metadata object for the saved checkpoint. + + Example: + >>> # xdoctest: +SKIP + >>> my_model = MyModule() + + >>> state_dict = {"model": my_model} + + >>> fs_storage_writer = core.distributed.checkpoint.FileSystemWriter("/checkpoint/1") + >>> core.distributed.checkpoint.save( + >>> state_dict=state_dict, + >>> storage_writer=fs_storage_writer, + >>> ) + + .. note:: + save_state_dict uses collectives to coordinate writes across ranks. + For NCCL-based process groups, internal tensor representations of + objects must be moved to the GPU device before communication takes place. + In this case, the device used is given by ``core.cuda.current_device()`` + and it is the user's responsibility to ensure that this is set so that + each rank has an individual GPU, via ``core.cuda.set_device()``. + """ + core._C._log_api_usage_once("core.distributed.checkpoint.save") + + no_dist = not (dist.is_available() and dist.is_initialized()) + if no_dist: + warnings.warn( + "core.distributed is unavailable or uninitialized, assuming the intent is to save in a single process." + ) + + with _profile(): + storage_writer = cast( + StorageWriter, _storage_setup(storage_writer, checkpoint_id, reader=False) + ) + + return _save_state_dict( + state_dict=_stateful_to_state_dict(state_dict), + storage_writer=storage_writer, + process_group=process_group, + no_dist=no_dist, + planner=planner, + ) + + +@_dcp_method_logger(log_exceptions=True) +def async_save( + state_dict: STATE_DICT_TYPE, + *, + checkpoint_id: Union[str, os.PathLike, None] = None, + storage_writer: Optional[StorageWriter] = None, + planner: Optional[SavePlanner] = None, + process_group: Optional[dist.ProcessGroup] = None, +) -> Future: + """Asynchronous version of ``save``. This code first de-stages the state_dict on to the + staging storage (defaults to CPU memory), and then calls the `save` in a separate thread. + + .. warning:: + This feature is experimental and subject to change. + + Args: + state_dict (Dict[str, Any]): The state_dict to save. + checkpoint_id (Union[str, os.PathLike, None]): + The ID of this checkpoint instance. The meaning of the checkpoint_id + depends on the storage. It can be a path to a folder or to a file. + It can also be a key if the storage is a key-value store. + (Default: ``None``) + storage_writer (Optional[StorageWriter]): + Instance of StorageWriter used to perform 'stage' and 'save'. If + this is not specified, DCP will automatically infer the writer based on the + checkpoint_id. If checkpoint_id is also None, an exception will + be raised. (Default: ``None``) + planner (Optional[SavePlanner]): + Instance of SavePlanner. If this is not specificed, the default + planner will be used. (Default: ``None``) + process_group (Optional[ProcessGroup]): + ProcessGroup to be used for cross-rank synchronization. + (Default: ``None``) + + Returns: + Future: A future holding the resultant Metadata object from `save`. + + Example: + >>> # xdoctest: +SKIP + >>> my_model = MyModule() + + >>> state_dict = {"model": my_model} + + >>> fs_storage_writer = core.distributed.checkpoint.FileSystemWriter("/checkpoint/1") + >>> checkpoint_future = core.distributed.checkpoint.async_save( + >>> state_dict=state_dict, + >>> storage_writer=fs_storage_writer, + >>> ) + >>> + >>> # ... do some work ... + >>> + >>> checkpoint_future.result() + + """ + core._C._log_api_usage_once("core.distributed.checkpoint.async_save") + + if dist.is_available() and dist.is_initialized(): + pg = process_group or _get_default_group() + assert ( + core.device("cpu") in pg._device_types # type: ignore[attr-defined] + ), "A CPU backend must be enabled for async save; try initializing process group with 'cpu:gloo,cuda:nccl'" + + storage_writer = cast( + StorageWriter, _storage_setup(storage_writer, checkpoint_id, reader=False) + ) + + state_dict = _stateful_to_state_dict(state_dict) + if isinstance(storage_writer, AsyncStager): + staged_state_dict = storage_writer.stage(state_dict) + else: # provides bwc for storage_writers not implementing AsyncStager + staged_state_dict = _offload_state_dict_to_cpu(state_dict, type_check=False) + + executor = ThreadPoolExecutor(max_workers=1) + f: Future = executor.submit( + save, + staged_state_dict, + checkpoint_id=checkpoint_id, + storage_writer=storage_writer, + planner=planner, + process_group=process_group, + ) + f.add_done_callback(lambda f: executor.shutdown(wait=False)) + + if ( + isinstance(storage_writer, AsyncStager) + and storage_writer.should_synchronize_after_execute + ): + storage_writer.synchronize_staging() + + return f + + +def _stateful_to_state_dict(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE: + """Creates a shallow copy of `state_dict` where `state_dict` is called for each Stateful object.""" + stateful_state_dict = {} + for key, elem in state_dict.items(): + stateful_state_dict[key] = ( + elem.state_dict() if isinstance(elem, Stateful) else elem + ) + return stateful_state_dict + + +def _save_state_dict( + state_dict: STATE_DICT_TYPE, + storage_writer: StorageWriter, + process_group: Optional[dist.ProcessGroup] = None, + coordinator_rank: int = 0, + no_dist: bool = False, + planner: Optional[SavePlanner] = None, +) -> Metadata: + core._C._log_api_usage_once("core.distributed.checkpoint.save_state_dict") + + distW = _DistWrapper(process_group, not no_dist, coordinator_rank) + if planner is None: + planner = DefaultSavePlanner() + assert planner is not None + + global_metadata = None + + ckpt_kwargs = {} + if (ckpt_id := getattr(storage_writer, "checkpoint_id", None)) is not None: + ckpt_kwargs["checkpoint_id"] = ckpt_id + ckpt_kwargs["process_group"] = distW.group + + @_dcp_method_logger(**ckpt_kwargs) + def local_step(): + assert planner is not None + storage_meta = storage_writer.storage_meta() + if "storage_meta" not in inspect.signature(planner.set_up_planner).parameters: + warnings.warn( + "The function definition for SavePlanner.set_up_planner has been updated" + " to include the storage_meta argument. Please update your implementation" + " to include this parameter." + ) + planner.set_up_planner(state_dict, distW.is_coordinator) # type: ignore[call-arg, arg-type] + else: + planner.set_up_planner( + state_dict=state_dict, + storage_meta=storage_meta, + is_coordinator=distW.is_coordinator, + ) + storage_writer.set_up_storage_writer(distW.is_coordinator) + + local_plan = planner.create_local_plan() + local_plan = storage_writer.prepare_local_plan(local_plan) + return local_plan + + @_dcp_method_logger(**ckpt_kwargs) + def global_step(all_local_plans): + nonlocal global_metadata + + assert planner is not None + all_local_plans, global_metadata = planner.create_global_plan(all_local_plans) + all_local_plans = storage_writer.prepare_global_plan(all_local_plans) + return all_local_plans + + central_plan: SavePlan = distW.reduce_scatter("plan", local_step, global_step) + + @_dcp_method_logger(**ckpt_kwargs) + def write_data(): + assert planner is not None + final_local_plan = planner.finish_plan(central_plan) + all_writes = storage_writer.write_data(final_local_plan, planner) + + all_writes.wait() + return all_writes.value() + + @_dcp_method_logger(**ckpt_kwargs) + def finish_checkpoint(all_results): + assert global_metadata is not None + storage_writer.finish(metadata=global_metadata, results=all_results) + return global_metadata + + return distW.all_reduce("write", write_data, finish_checkpoint) diff --git a/mindnlp/core/distributed/checkpoint/stateful.py b/mindnlp/core/distributed/checkpoint/stateful.py new file mode 100644 index 000000000..d36c419ec --- /dev/null +++ b/mindnlp/core/distributed/checkpoint/stateful.py @@ -0,0 +1,42 @@ +from typing import Any, Dict, runtime_checkable, TypeVar +from typing_extensions import Protocol + + +__all__ = ["Stateful", "StatefulT"] + + +@runtime_checkable +class Stateful(Protocol): + """ + Stateful protocol for objects that can be checkpointed and restored. + """ + + def state_dict(self) -> Dict[str, Any]: + """ + Objects should return their state_dict representation as a dictionary. + The output of this function will be checkpointed, and later restored in + `load_state_dict()`. + + .. warning:: + Because of the inplace nature of restoring a checkpoint, this function + is also called during `core.distributed.checkpoint.load`. + + + Returns: + Dict: The objects state dict + """ + + ... + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + """ + Restore the object's state from the provided state_dict. + + Args: + state_dict: The state dict to restore from + """ + + ... + + +StatefulT = TypeVar("StatefulT", bound=Stateful) diff --git a/mindnlp/core/distributed/checkpoint/storage.py b/mindnlp/core/distributed/checkpoint/storage.py new file mode 100644 index 000000000..0b77dbdef --- /dev/null +++ b/mindnlp/core/distributed/checkpoint/storage.py @@ -0,0 +1,284 @@ +import abc +import os +from dataclasses import dataclass +from typing import Any, List, Optional, Union + +from core.distributed.checkpoint.metadata import Metadata, MetadataIndex, StorageMeta +from core.distributed.checkpoint.planner import ( + LoadPlan, + LoadPlanner, + SavePlan, + SavePlanner, +) +# from core.futures import Future + + +__all__ = ["WriteResult", "StorageWriter", "StorageReader"] + + +@dataclass(frozen=True) +class WriteResult: + index: MetadataIndex + + size_in_bytes: int + storage_data: Any + + +class StorageWriter(abc.ABC): + """ + Interface used by ``save_state_dict`` to write to storage. + + One StorageWriter instance acts as both the coordinator and the follower + in a distributed checkpoint. As part of initialization, each instance + is told its role. + + A subclass should expect the following sequence of calls. + + 0) (all ranks) set checkpoint_id if users pass a valid checkpoint_id. + 1) (all ranks) set_up_storage_writer() + 2) (all ranks) prepare_local_plan() + 3) (coordinator) prepare_global_plan() + 4) (all ranks) write_data() + 5) (coordinator) finish() + """ + + @abc.abstractmethod + def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None: + """ + Calls to indicates a brand new checkpoint write is going to happen. + A checkpoint_id may be present if users set the checkpoint_id for + this checkpoint write. The meaning of the checkpiont_id is + storage-dependent. It can be a path to a folder/file or a key for + a key-value storage. + + Args: + checkpoint_id (Union[str, os.PathLike, None]): + The ID of this checkpoint instance. The meaning of the checkpoint_id + depends on the storage. It can be a path to a folder or to a file. + It can also be a key if the storage is a key-value store. + (Default: ``None``) + """ + ... + + @abc.abstractmethod + def set_up_storage_writer(self, is_coordinator: bool) -> None: + """ + Initialize this instance. + + Args: + is_coordinator (bool): Whether this instance is responsible for coordinating + the checkpoint. + """ + + @abc.abstractmethod + def prepare_local_plan(self, plan: SavePlan) -> SavePlan: + """ + Perform storage-specific local planning. + + While this method can produce a completely different plan, the recommended + way is to store storage specific data in SavePlan::storage_data. + + Args: + plan (SavePlan): The local plan from the ``SavePlanner`` in use. + + Returns: + A transformed ``SavePlan`` after storage local planning + """ + + @abc.abstractmethod + def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]: + """ + Perform centralized planning of storage. + + This method is only called on the coordinator instance. + + While this method can produce a completely different plan, the preferred + way is to store storage specific data in SavePlan::storage_data. + + Args: + plans: A list of ``SavePlan`` instances, one for each rank. + + Returns: + A list of transformed ``SavePlan`` after storage global planning + """ + + @abc.abstractmethod + def write_data( + self, plan: SavePlan, planner: SavePlanner + ): + """ + Write all items from ``plan`` using ``planner`` to resolve the data. + + A subclass should call ``SavePlanner::resolve_data`` on each item + from the plan to get access to the underlying object to write. + + Subclasses should lazily call `resolve_data` as it can allocate memory. + In case of tensors, make following assumptions: + + - They might be on any device, including not matching the one on ``WriteItem::tensor_data`` + - They might be views or not contiguous. Only the projection needs to be saved. + + Args: + plan (SavePlan): The save plan to execute. + planner (SavePlanner): Planner object to be used to resolve items to data. + + Returns: + A future that completes to a list of WriteResult + """ + + @abc.abstractmethod + def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None: + """ + Write the metadata and marks the current checkpoint as successful. + + The actual format/schema used for serializing `metadata` is an + implementation detail. The only requirement is that it's recoverable + in to the same object graph. + + Args: + metadata (Metadata): metadata for the new checkpoint + results: A list of WriteResults from all ranks. + + Returns: + None + """ + + @classmethod + @abc.abstractmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + """ + Check if the given checkpoint_id is supported by the stroage. This allow + us to enable automatic storage selection. + """ + ... + + def storage_meta(self) -> Optional[StorageMeta]: + """ + Return the storage-specific metadata. This is used to store additional information + in a checkpoint that can be useful for providing request-level observability. StorageMeta + is passed to the ``SavePlanner`` during save calls. Returns None by default. + + TODO: provide an example + """ + return None + + +class StorageReader(abc.ABC): + """ + Interface used by ``load_state_dict`` to read from storage. + + One StorageReader instance acts as both the coordinator and the follower + in a distributed checkpoint. As part of initialization, each instance + is told its role. + + A subclass should expected the following sequence of calls by ``load_state_dict``: + + 0) (all ranks) set checkpoint_id if users pass a valid checkpoint_id. + 1) (all ranks) read_metadata() + 2) (all ranks) set_up_storage_reader() + 3) (all ranks) prepare_local_plan() + 4) (coordinator) prepare_global_plan() + 5) (all ranks) read_data() + """ + + @abc.abstractmethod + def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None: + """ + Calls to indicates a brand new checkpoint read is going to happen. + A checkpoint_id may be present if users set the checkpoint_id for + this checkpoint read. The meaning of the checkpiont_id is + storage-dependent. It can be a path to a folder/file or a key for + a key-value storage. + + Args: + checkpoint_id (Union[str, os.PathLike, None]): + The ID of this checkpoint instance. The meaning of the checkpoint_id + depends on the storage. It can be a path to a folder or to a file. + It can also be a key if the storage is more like a key-value store. + (Default: ``None``) + """ + ... + + @abc.abstractmethod + def read_metadata(self) -> Metadata: + """ + Read the checkpoint metadata. + + Returns: + The metadata object associated with the checkpoint being loaded. + + """ + + @abc.abstractmethod + def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None: + """ + Initialize this instance. + + Args: + metadata (Metadata): The metadata schema to use. + is_coordinator (bool): Whether this instance is responsible for coordinating + the checkpoint. + """ + + @abc.abstractmethod + def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan: + """ + Perform storage-specific local planning. + + While this method can produce a completely different plan, the recommended + way is to store storage specific data in LoadPlan::storage_data. + + Args: + plan (LoadPlan): The local plan from the ``LoadPlan`` in use. + + Returns: + A transformed ``LoadPlan`` after storage local planning + """ + + @abc.abstractmethod + def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]: + """ + Perform centralized planning of storage loading. + + This method is only called on the coordinator instance. + + While this method can produce a completely different plan, the preferred + way is to store storage specific data in LoadPlan::storage_data. + + Args: + plans: A list of ``LoadPlan`` instances, one for each rank. + + Returns: + A list of transformed ``LoadPlan`` after storage global planning + """ + + @abc.abstractmethod + def read_data(self, plan: LoadPlan, planner: LoadPlanner): + """ + Read all items from ``plan`` using ``planner`` to resolve the data. + + A subclass should call ``LoadPlanner::load_bytes`` to deserialize a BytesIO + object into the right place. + + A subclass should call ``LoadPlanner::resolve_tensor`` to get access to the + tensors that in should load data into. + + It's the StorageLayer responsibility to properly schedule any cross device copies + required. + + Args: + plan (LoadPlan): The local plan to execute on + planner (LoadPlanner): The planner object to use to resolve items. + + Returns: + A future that completes once all reads are finished. + """ + + @classmethod + @abc.abstractmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + """ + Check if the given checkpoint_id is supported by the stroage. This allow + us to enable automatic storage selection. + """ + ... diff --git a/mindnlp/core/distributed/checkpoint/utils.py b/mindnlp/core/distributed/checkpoint/utils.py new file mode 100644 index 000000000..b793e5dd2 --- /dev/null +++ b/mindnlp/core/distributed/checkpoint/utils.py @@ -0,0 +1,431 @@ +# mypy: allow-untyped-defs +import cProfile +import inspect +import io +import itertools +import os +import warnings +from contextlib import contextmanager +from functools import wraps +from pstats import Stats +from typing import Any, Callable, cast, Dict, List, Optional, Sequence, TypeVar, Union + +from mindnlp import core +from mindnlp import core.distributed as dist +from core.distributed._shard.sharded_tensor import ShardedTensor +from core.distributed._shard.sharded_tensor.shard import Shard + +from .api import ( + _is_wrapped_exception, + _wrap_exception, + CheckpointException, + WRAPPED_EXCEPTION, +) +from .metadata import MetadataIndex, STATE_DICT_TYPE + + +__all__ = ["find_tensor_shard", "find_state_dict_object"] + +T = TypeVar("T") +R = TypeVar("R") + + +def _get_failure_dict( + results: List[Union[T, WRAPPED_EXCEPTION]] +) -> Dict[int, WRAPPED_EXCEPTION]: + return cast( + Dict[int, WRAPPED_EXCEPTION], + {i: err for i, err in enumerate(results) if _is_wrapped_exception(err)}, + ) + + +def _all_gather_keys( + local_dict: Dict[Any, Any], group: Optional[dist.ProcessGroup] = None +) -> List[Any]: + """Gathers all keys, and returns them sorted.""" + keys = list(local_dict.keys()) + gathered_keys: List[List[Any]] = [None] * dist.get_world_size(group) # type: ignore[list-item] + + dist.all_gather_object(gathered_keys, keys, group=group) + return sorted(set(itertools.chain.from_iterable(gathered_keys))) + + +class _DistWrapper: + """ + This is a wrapper around PG that provides a series of features around object collectives. + + It works without distributed initialized, where most collectives turns into nops. + + All variants that take functions are exception robust, meaning that if one or more + ranks raise errors, all ranks will observe those. + """ + + def __init__( + self, + group: Optional[dist.ProcessGroup], + use_dist: bool, + coordinator_rank: int, + ): + self.group = group + self.use_dist = use_dist + self.coordinator_rank = coordinator_rank + if self.use_dist: + self.rank = dist.get_rank(group) + self.is_coordinator = self.rank == coordinator_rank + else: + self.rank = 0 + self.is_coordinator = True + + def get_rank(self) -> int: + return self.rank + + def get_world_size(self) -> int: + if self.use_dist: + return dist.get_world_size(self.group) + return 1 + + def broadcast_object(self, object: Optional[T]) -> T: + """Implement functionality similar to c10d::broadcast_object_list but without distributed enabled.""" + object_list = [object] + if self.use_dist: + dist.broadcast_object_list( + object_list=object_list, + group=self.group, + src=self.coordinator_rank, + ) + return cast(T, object_list[0]) + + def gather_object(self, object: T) -> Optional[List[T]]: + """Implement functionality similar to c10d::gather_object but without distributed enabled.""" + if self.use_dist: + gather_objs = ( + cast(List[T], [None] * dist.get_world_size(self.group)) + if self.is_coordinator + else None + ) + + dist.gather_object( + obj=object, + object_gather_list=gather_objs if self.is_coordinator else None, + dst=self.coordinator_rank, + group=self.group, + ) + result = gather_objs + else: + result = [object] + return result + + def all_gather_object(self, object: T) -> List[T]: + """Implement functionality similar to c10d::all_gather_object but without distributed enabled.""" + if self.use_dist: + gather_objs = cast(List[T], [None] * dist.get_world_size(self.group)) + + dist.all_gather_object( + object_list=gather_objs, obj=object, group=self.group + ) + else: + gather_objs = [object] + return gather_objs + + def scatter_object(self, object_list: Optional[List[T]]) -> T: + """Implement functionality similar to c10d::scatter_object but without distributed enabled.""" + if self.use_dist: + gather_result = cast(List[T], [None]) + dist.scatter_object_list( + scatter_object_output_list=gather_result, + scatter_object_input_list=object_list if self.is_coordinator else None, + src=self.coordinator_rank, + group=self.group, + ) + + local_reply = gather_result[0] + else: + assert object_list is not None + local_reply = object_list[0] + return local_reply + + def reduce_scatter( + self, + step: str, + map_fun: Callable[[], T], + reduce_fun: Callable[[List[T]], List[R]], + ) -> R: + """ + Compute a value on each rank, then do centralized reduce on a single rank, followed by a scatter. + + This method operates in the following way: + Run ``map_fun`` on all ranks + Gather results on rank 0 + Call ``reduce_fun`` on all those values + Scatter to each rank part of the result. + """ + local_data: Union[WRAPPED_EXCEPTION, T] + try: + local_data = map_fun() + except BaseException as e: + local_data = _wrap_exception(e) + + all_data = self.gather_object(local_data) + all_results: Optional[List[Union[R, CheckpointException]]] = None + if self.is_coordinator: + assert all_data is not None + node_failures = _get_failure_dict(all_data) + + if len(node_failures) == 0: + try: + # N.B. why can't mypy cast List[R] to List[Union[R, WRAPPED_EXCEPTION]]? + all_results = cast( + List[Union[R, CheckpointException]], + reduce_fun(cast(List[T], all_data)), + ) + except BaseException as e: + node_failures[self.rank] = _wrap_exception(e) + + if len(node_failures) > 0: + all_results = [ + CheckpointException(step, node_failures) + ] * self.get_world_size() + + result = self.scatter_object(all_results) + if isinstance(result, CheckpointException): + raise result + return result + + def all_reduce( + self, + step: str, + map_fun: Callable[[], T], + reduce_fun: Callable[[List[T]], R], + ) -> R: + """ + Compute a value on each rank, then do centralized reduce on a single rank, followed by a broadcast. + + This method operates in the following way: + Run ``map_fun`` on all ranks + Gather results on rank 0 + Call ``reduce_fun`` on all those values + Broadcast the reduced value to all ranks. + """ + local_data: Union[T, WRAPPED_EXCEPTION] + try: + local_data = map_fun() + except BaseException as e: + local_data = _wrap_exception(e) + + all_data = self.gather_object(local_data) + result: Optional[Union[R, CheckpointException]] = None + if self.is_coordinator: + assert all_data is not None + node_failures = _get_failure_dict(all_data) + if len(node_failures) == 0: + try: + result = reduce_fun(cast(List[T], all_data)) + except BaseException as e: + node_failures[self.rank] = _wrap_exception(e) + + if len(node_failures) > 0: + result = CheckpointException(step, node_failures) + + final_result = self.broadcast_object(result) + if isinstance(final_result, CheckpointException): + raise final_result + return cast(R, final_result) + + def all_gather( + self, + step: str, + map_fun: Callable[[], T], + ) -> List[T]: + """ + Compute a value on each rank, then all_gather them. + + This method operates in the following way: + Run ``map_cp`` on all ranks + all_gather the values to all ranks + """ + result: Union[T, WRAPPED_EXCEPTION] + try: + result = map_fun() + except BaseException as e: + result = _wrap_exception(e) + + all_results = self.all_gather_object(result) + + node_failures = _get_failure_dict(all_results) + if len(node_failures) > 0: + raise CheckpointException(step, node_failures) + return cast(List[T], all_results) + + def broadcast( + self, + step: str, + map_fun: Callable[[], T], + ) -> T: + """ + Compute a value on rank 0 and broadcast it. + + This method operates in the following way: + Run ``map_cp`` on rank 0 + broadcast the value + """ + result: Optional[Union[T, CheckpointException]] = None + if self.is_coordinator: + try: + result = map_fun() + except BaseException as e: + result = CheckpointException(step, {self.rank: _wrap_exception(e)}) + final_result = self.broadcast_object(result) + if isinstance(final_result, CheckpointException): + raise final_result + return cast(T, final_result) + + +def _find_shard(tensor: ShardedTensor, index: MetadataIndex) -> Shard: + if index.offset is None: + raise ValueError( + f"Cannot lookup {index.fqn} since its a ShardedTensor and no offset was provided" + ) + + shards = tensor.local_shards() + # index fast path + if index.index is not None: + if ( + len(shards) > index.index + and core.Size(shards[index.index].metadata.shard_offsets) == index.offset + ): + return shards[index.index] + + for shard in shards: + if core.Size(shard.metadata.shard_offsets) == index.offset: + return shard + raise ValueError(f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'") + + +def find_tensor_shard(tensor: core.Tensor, index: MetadataIndex) -> core.Tensor: + if hasattr(tensor, "__get_tensor_shard__"): + # DTensor implements _Checkpointable + return tensor.__get_tensor_shard__(index) # type: ignore[attr-defined] + if isinstance(tensor, ShardedTensor): + return _find_shard(tensor, index).tensor + if index.offset is not None: + # special case looking up a tensor by origin + if index.offset == core.Size([0] * len(tensor.size())): + return tensor + raise ValueError( + f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset: '{index.offset}'" + ) + return tensor + + +def find_state_dict_object(state_dict: STATE_DICT_TYPE, index: MetadataIndex) -> Any: + if index.fqn not in state_dict: + raise ValueError(f"Could not find FQN: '{index.fqn}'") + obj = state_dict[index.fqn] + + if isinstance(obj, core.Tensor): + return find_tensor_shard(obj, index) + elif index.offset is not None: + raise ValueError( + f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset: '{index.offset}'" + ) + return obj + + +def _element_wise_add(a: Sequence[int], b: Sequence[int]) -> List[int]: + return [i_a + i_b for i_a, i_b in zip(a, b)] + + +def _element_wise_sub(a: Sequence[int], b: Sequence[int]) -> List[int]: + return [i_a - i_b for i_a, i_b in zip(a, b)] + + +class _ReaderView(io.IOBase): + def __init__(self, base_stream: io.IOBase, offset: int, len: int): + super().__init__() + self.offset = offset + self.len = len + self.base_stream = base_stream + self.seek(0) + + def seek(self, __offset: int, __whence: int = os.SEEK_SET) -> int: + if __whence == os.SEEK_SET: + __offset = self.offset + __offset + elif __whence == os.SEEK_END: + __whence = os.SEEK_SET + __offset = (self.offset + self.len) - __offset + return self.base_stream.seek(__offset, __whence) + + def tell(self) -> int: + return self.base_stream.tell() - self.offset + + def readable(self) -> bool: + return self.base_stream.readable() + + def seekable(self) -> bool: + return self.base_stream.seekable() + + def readinto(self, b): + return self.base_stream.readinto(b) # type: ignore[attr-defined] + + def read(self, size=-1): + return self.base_stream.read(size) + + +def _create_file_view(file: io.IOBase, offset: int, length: int) -> io.IOBase: + # FIXME (kumpera) core.load fails if we wrap with io.BufferedReader + return _ReaderView(file, offset, length) + + +def _normalize_device_info(device_type: str, device_id: int) -> str: + """Device info normalization.""" + if device_type == "cpu": + return "cpu" + return f"{device_type}:{device_id}" + + +# TODO: integrate with distributed logging flag +ENABLE_PROFILE = False + + +@contextmanager +def _profile(): + # Only log the profiling when it is enable and is on rank0 or dist is not + # avaiable. + if ENABLE_PROFILE and (not dist.is_available() or dist.get_rank() == 0): + profiler = cProfile.Profile() + profiler.enable() + try: + yield + finally: + profiler.disable() + stats = Stats(profiler) + stats.sort_stats("time").print_stats(10) + else: + yield + + +def _api_bc_check(func): + @wraps(func) + def inner_func(*args, **kwargs) -> Any: + if len(args) == 2: + warnings.warn( + f"The argument order of {func.__name__} has been changed. " + "Please check the document to avoid future breakages." + ) + sig = inspect.signature(func) + kwonlyargs = [ + p.name for p in sig.parameters.values() if p.kind == p.KEYWORD_ONLY + ] + if "storage_writer" in kwonlyargs: + assert "storage_writer" not in kwargs, (args, kwargs) + kwargs["storage_writer"] = args[1] + elif "storage_reader" in kwonlyargs: + assert "storage_reader" not in kwargs, (args, kwargs) + kwargs["storage_reader"] = args[1] + else: + raise RuntimeError(f"Unexpected kwonlyargs = {kwonlyargs}") + return func(args[0], **kwargs) + else: + return func(*args, **kwargs) + + return inner_func diff --git a/mindnlp/core/distributed/collective_utils.py b/mindnlp/core/distributed/collective_utils.py new file mode 100644 index 000000000..712e9061e --- /dev/null +++ b/mindnlp/core/distributed/collective_utils.py @@ -0,0 +1,217 @@ +#!/usr/bin/env python3 + + +""" +A set of primitive functions for performing collective ops. + +Each should also handle single rank scenario. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable, cast, Generic, List, Optional, Tuple, TypeVar, Union + +from mindnlp import core.distributed as dist + + +T = TypeVar("T") + + +@dataclass +class SyncPayload(Generic[T]): + stage_name: Optional[str] + success: bool + payload: T + exception: Optional[Exception] = None + + +def broadcast( + data_or_fn: Union[T, Callable[[], T]], + *, + success: bool = True, + stage_name: Optional[str] = None, + rank: int = 0, + pg: Optional[dist.ProcessGroup] = None, +) -> T: + """ + Broadcasts the data payload from rank 0 to all other ranks. + Or if a function is passed, execute it in rank 0 and broadcast result to all other ranks. + + Can be used to broadcast a failure signal to stop all ranks. + + If the function raises an exception, all ranks will raise. + + Args: + data_or_fn: the data to broadcast or function to execute and broadcast result. + success: False to stop all ranks. + stage_name: the name of the logical stage for synchronization and debugging + rank: rank to broadcast data or execute function and broadcast resutls. + pg: the process group for sync + Throws: + RuntimeError from original exception trace + Returns: + the value after synchronization + + Example usage: + >> id = broadcast(data_or_fn=allocate_id, rank=0, pg=ext_pg.my_pg) + """ + + if not success and data_or_fn is not None: + raise AssertionError( + "Data or Function is expected to be None if not successful" + ) + + payload: Optional[T] = None + exception: Optional[Exception] = None + # if no pg is passed then execute if rank is 0 + if (pg is None and rank == 0) or (pg is not None and pg.rank() == rank): + # determine if it is an executable function or data payload only + if callable(data_or_fn): + try: + payload = data_or_fn() + except Exception as e: + success = False + exception = e + else: + payload = data_or_fn + + # broadcast the exception type if any to all ranks for failure categorization + sync_obj = SyncPayload( + stage_name=stage_name, + success=success, + payload=payload, + exception=exception, + ) + + if pg is not None: + broadcast_list = [sync_obj] + dist.broadcast_object_list(broadcast_list, src=rank, group=pg) + assert len(broadcast_list) == 1 + sync_obj = broadcast_list[0] + + # failure in any rank will trigger a throw in every rank. + if not sync_obj.success: + error_msg = f"Rank {rank} failed" + if stage_name is not None: + error_msg += f": stage {sync_obj.stage_name}" + if sync_obj.exception is not None: + error_msg += f": exception {sync_obj.exception}" + raise RuntimeError(error_msg) from sync_obj.exception + + return cast(T, sync_obj.payload) + + +def all_gather( + data_or_fn: Union[T, Callable[[], T]], + stage_name: Optional[str] = None, + pg: Optional[dist.ProcessGroup] = None, +) -> List[T]: + """ + A simple all_gather primitive with basic synchronization guard logic, + by checking payload from all ranks has the same stage name. + + Args: + data_or_fn: the data to be all gathered across ranks or function to be executed + stage_name: the sync stage name for out-of-sync protection + pg: the process group for sync + Throws: + RuntimeError from original exception trace + Returns: + a list of synced data from all ranks + + Example usage: + >> all_ids = all_gather(data_or_fn=allocate_id, pg=ext_pg.my_pg) + """ + payload: Optional[T] = None + exception: Optional[Exception] = None + success = True + # determine if it is an executable function or data payload only + if callable(data_or_fn): + try: + payload = data_or_fn() + except Exception as e: + success = False + exception = e + else: + payload = data_or_fn + + sync_obj = SyncPayload( + stage_name=stage_name, + success=success, + payload=payload, + exception=exception, + ) + + if pg is not None: + # List of success/failure across all ranks. + total_list = [None] * dist.get_world_size(pg) + all_gather_object_enforce_type(pg, total_list, sync_obj) + # Each rank will throw RuntimeError in case of failure on any rank. + stage_name = cast(SyncPayload[T], total_list[0]).stage_name + exception_list: List[Tuple[int, Exception]] = [] + ret_list: List[T] = [] + error_msg: str = "" + + for i, sp in enumerate(cast(List[SyncPayload[T]], total_list)): + if sp.stage_name != stage_name: + error_msg += ( + f"Unexpected stage name received from rank {i}: {sp.stage_name} " + ) + continue + if not sp.success and sp.exception is not None: + exception_list.append((i, sp.exception)) + continue + ret_list.append(sp.payload) + + if len(exception_list) > 0: + raise RuntimeError( # type: ignore[misc] + error_msg, exception_list + ) from exception_list[0] + return ret_list + else: + if not sync_obj.success: + raise RuntimeError( + f"all_gather failed with exception {sync_obj.exception}", + ) from sync_obj.exception + return [sync_obj.payload] # type: ignore[list-item] + + +# Note: use Any for typing for now so users can pass in +# either a list of None or target type placeholders +# otherwise pyre would complain +def all_gather_object_enforce_type( + pg: dist.ProcessGroup, + # pyre-fixme[2]: Parameter must have a type that does not contain `Any` + object_list: List[Any], + # pyre-fixme[2]: Parameter must have a type other than `Any` + obj: Any, + # pyre-fixme[2]: Parameter must have a type that does not contain `Any` + type_checker: Callable[[Any, Any], bool] = lambda x, y: type(x) == type(y), +) -> None: + """ + Similar to plain all_gather_object but with additional type checking + AFTER gather is done to ensure basic consistency. + If check does not pass, all ranks will fail with exception. + + This is generally to prevent conditional logic leading to + unexpected messages being received. This is considered fatal code error, + but due to logic stacks this might happen implicitly in practice. + + The default check does not check sub type (considered different) + or covariance (considered same) but users can pass in custom checker + if more complicated check is needed. + """ + dist.all_gather_object(object_list, obj, group=pg) + + # conservative check + list_len = len(object_list) + if list_len == 0: + return + first_obj = object_list[0] + for i in range(1, list_len): + if not type_checker(first_obj, object_list[i]): + raise TypeError( + f"Object type at index {i} is {type(object_list[i])}, " + f"while first object type is {type(first_obj)}" + ) diff --git a/mindnlp/core/distributed/constants.py b/mindnlp/core/distributed/constants.py new file mode 100644 index 000000000..076c1a9e9 --- /dev/null +++ b/mindnlp/core/distributed/constants.py @@ -0,0 +1,19 @@ +from datetime import timedelta +from typing import Optional + +# from core._C._distributed_c10d import _DEFAULT_PG_TIMEOUT +_DEFAULT_PG_TIMEOUT = timedelta(seconds=5000) + +__all__ = ["default_pg_timeout", "default_pg_nccl_timeout"] + +# Default process group wide timeout, if applicable. +# This only applies to the non-nccl backends +# To make an attempt at backwards compatibility with THD, we use an +# extraordinarily high default timeout, given that THD did not have timeouts. +default_pg_timeout: timedelta = _DEFAULT_PG_TIMEOUT +# Separate timeout for PGNCCL mainly becuase it's always been that way in the C++ layer, but until recently +# there was one default that applied across all backends in the python layer. +# Later, we could consider merging them back together at the c++ layer if we can align on a same value. +# (only if TORCH_NCCL_BLOCKING_WAIT or TORCH_NCCL_ASYNC_ERROR_HANDLING is set to 1). + +default_pg_nccl_timeout: Optional[timedelta] = _DEFAULT_PG_TIMEOUT diff --git a/mindnlp/core/distributed/device_mesh.py b/mindnlp/core/distributed/device_mesh.py new file mode 100644 index 000000000..607c7182a --- /dev/null +++ b/mindnlp/core/distributed/device_mesh.py @@ -0,0 +1,999 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import logging +import math +import threading +from functools import reduce +from itertools import chain +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union + +from mindnlp import core +from core.distributed import is_available +from core.utils._typing_utils import not_none + + +__all__ = ["init_device_mesh", "DeviceMesh"] + + +if not is_available(): + import sys + + # We need to create the stubs when distributed is not available. + # Otherwise, we would fail the doc tests (```./.ci/pytorch/docs-test.sh```), + # since it would try to import ``core.distributed.device_mesh`` or + # ``core.distributed.init_device_mesh`` but cannot find them. + + class _DeviceMeshStub: + pass + + def _init_device_mesh_stub(): + pass + + sys.modules["core.distributed.device_mesh"].DeviceMesh = _DeviceMeshStub # type: ignore[attr-defined] + sys.modules[ + "core.distributed.device_mesh" + ].init_device_mesh = _init_device_mesh_stub # type: ignore[attr-defined] + + +else: + from .c10d import Backend as C10dBackend + from core.distributed.distributed_c10d import ( + _find_pg_by_ranks_and_tag, + _get_default_group, + _get_group_tag, + get_backend, + get_process_group_ranks, + get_rank, + get_world_size, + init_process_group, + is_initialized, + new_group, + ProcessGroup, + split_group, + ) + + logger = logging.getLogger(__name__) + + # only import numpy typing when type checking + if TYPE_CHECKING: + try: + from numpy.typing import ArrayLike + except ImportError: + logger.warning( + "DeviceMesh requires numpy >= 1.21 to be installed for type checking" + ) + + class _MeshEnv(threading.local): + def __init__(self) -> None: + self.mesh_stack: List[DeviceMesh] = [] + self.child_to_root_mapping: Dict[DeviceMesh, DeviceMesh] = {} + self.mesh_dim_group_options: Dict[ + int, Tuple[str, Optional[C10dBackend.Options]] + ] = {} + self.root_to_flatten_mapping: Dict[DeviceMesh, Dict[str, DeviceMesh]] = {} + # Record flatten mesh name to its mesh dim index in root mesh. + self.flatten_name_to_root_dims: Dict[ + DeviceMesh, Dict[str, Tuple[int, ...]] + ] = {} + + def get_current_mesh(self) -> "DeviceMesh": + if len(self.mesh_stack) == 0: + raise RuntimeError("No device mesh is currently active!") + return self.mesh_stack[-1] + + def create_sub_mesh( + self, + device_mesh: "DeviceMesh", + submesh_dim_names: Tuple[str, ...], + submesh_dims: List[Tuple[int, ...]], + ) -> "DeviceMesh": + # Get the submesh dim size from the submesh_dims. + # For example, if we have a 3D mesh with mesh_shape (2, 2, 2) mesh_dim_names ("dp", "cp", "tp") and we want + # to slice out mesh["dp_cp"], then submesh_dims = [(0, 1), (2,)] and submesh_dim_size = [2 * 2, 2] = [4, 2]. + # If we want to slice out mesh["dp", "cp"], then submesh_dims = [(0,), (1,)] and submesh_dim_size = [2, 2]. + slice_dim_size = [ + reduce( + lambda x, y: x * device_mesh.mesh.size(y), + mesh_dim, + 1, + ) + for mesh_dim in submesh_dims + ] + + mesh_tensor = device_mesh.mesh + # slice_dim_idx could be differnt from submesh_dims, as we may need to flatten out some dims. + slice_dim_idx = [] + slice_dim_group_info = [] + # keep track of the number of dims that have been flattened so we can get the correct slice_dim_idx in the + # flattened mesh tensor. + num_dims_flatten = 0 + for mesh_dim_indices, mesh_dim_name in zip(submesh_dims, submesh_dim_names): + # Currently, this only allows slicing out a contiguous flattened dim. + # TODO: we need to handle reconstructing a non-contiguous flattened dim. + if len(mesh_dim_indices) > 1: + # We need to move the start_dim and end_dim to the left if some dims are already flattened. + mesh_tensor = mesh_tensor.flatten( + start_dim=mesh_dim_indices[0] - num_dims_flatten, + end_dim=mesh_dim_indices[-1] - num_dims_flatten, + ) + # If some dims are already flattened, we need to adjust the slice_dim_idx accordingly. + # For example, if the submesh_dims = [(0, 1), (2,), (3, 4)] with 0-1 flattened and 3-4 flattened, + # then the final slice_dim_idx should be [0, 1, 2]. + slice_dim_idx.append(mesh_dim_indices[0] - num_dims_flatten) + num_dims_flatten += len(mesh_dim_indices) - 1 + slice_dim_group_info.append( + self.root_to_flatten_mapping[device_mesh][ + mesh_dim_name + ]._dim_group_infos[0] + ) + else: + slice_dim_idx.append(mesh_dim_indices[0] - num_dims_flatten) + slice_dim_group_info.append( + device_mesh._dim_group_infos[mesh_dim_indices[0]] + ) + + # mesh_tensor has already been flattened if needed. So mesh_tensor.ndim <= device_mesh.mesh.ndim now. + mesh_dims_remained_idx = list(range(mesh_tensor.ndim)) + for idx in slice_dim_idx: + mesh_dims_remained_idx.remove(idx) + + # pg_ranks_by_dim is the size of [number of local ranks of the outermost submesh dimension, *slice_dim_idx] + # This means on each local rank of the outermost slice mesh dim, we have a tensor of submesh size with + # the pg ranks of the submesh. From this, we can extract the submesh mesh tensor contains the current rank. + pg_ranks_by_dim = mesh_tensor.permute( + *mesh_dims_remained_idx, *slice_dim_idx + ).reshape(-1, *slice_dim_size) + + cur_rank = device_mesh.get_rank() + for mesh_nd in pg_ranks_by_dim: + submesh = DeviceMesh( + device_mesh.device_type, + mesh_nd, + mesh_dim_names=submesh_dim_names, + _init_backend=False, + ) + if cur_rank in mesh_nd: + res_submesh = submesh + + res_submesh._dim_group_infos = slice_dim_group_info # type: ignore[possibly-undefined] + self.child_to_root_mapping[res_submesh] = device_mesh + + return res_submesh + + def create_flatten_mesh( + self, device_mesh: "DeviceMesh", mesh_dim_name: Optional[str] = None + ) -> "DeviceMesh": + root_mesh = _mesh_resources.get_root_mesh(device_mesh) + + flatten_dims_in_root = [ + not_none(root_mesh.mesh_dim_names).index(flattened_mesh_dim_name) + for flattened_mesh_dim_name in not_none(device_mesh.mesh_dim_names) + ] + + if not mesh_dim_name: + mesh_dim_name = "_".join( + [ + not_none(root_mesh.mesh_dim_names)[dim] + for dim in flatten_dims_in_root + ] + ) + + # Check whether the mesh_dim_name for flattened mesh is valid. + self.flatten_name_to_root_dims.setdefault(root_mesh, {}) + invalid_dim_names = chain( + *list(not_none(root_mesh.mesh_dim_names)), + *self.flatten_name_to_root_dims[root_mesh].keys(), + ) + if mesh_dim_name in invalid_dim_names: + raise RuntimeError( + f"{mesh_dim_name} already exists for submesh of the {root_mesh}. ", + f"The mesh_dim_names of submesh and flattened mesh are {invalid_dim_names}. " + f"Please specify another valid mesh_dim_name.", + ) + + # Quick return if the flatten mesh has been created before. + # TODO: If we decide to restrict flatten initialization once, we should remove + # this check and throw an error if the flatten mesh is already created before. + if ( + root_mesh in self.root_to_flatten_mapping + and mesh_dim_name in self.root_to_flatten_mapping[root_mesh] + ): + return self.root_to_flatten_mapping[root_mesh][mesh_dim_name] + + flattened_mesh_dim_size = math.prod(device_mesh.mesh.size()) + + remained_dims_in_root = list(range(root_mesh.mesh.ndim)) + for flatten_dim_in_root in flatten_dims_in_root: + remained_dims_in_root.remove(flatten_dim_in_root) + + pg_ranks_by_dim = root_mesh.mesh.permute( + *remained_dims_in_root, *flatten_dims_in_root + ).reshape(-1, flattened_mesh_dim_size) + + cur_rank = root_mesh.get_rank() + for mesh_nd in pg_ranks_by_dim: + # need to init backend here since the flattened pg doesn't exist in root mesh. + flattened_mesh = DeviceMesh( + root_mesh.device_type, + mesh_nd, + mesh_dim_names=(mesh_dim_name,), + ) + if cur_rank in mesh_nd: + res_flattened_mesh = flattened_mesh + self.child_to_root_mapping[res_flattened_mesh] = root_mesh # type: ignore[possibly-undefined] + self.root_to_flatten_mapping.setdefault(root_mesh, {})[mesh_dim_name] = res_flattened_mesh # type: ignore[possibly-undefined] + self.flatten_name_to_root_dims[root_mesh][mesh_dim_name] = tuple(flatten_dims_in_root) # type: ignore[possibly-undefined] + + return res_flattened_mesh + + def get_root_mesh(self, device_mesh: "DeviceMesh") -> "DeviceMesh": + # If a mesh could not be found in the child_to_root_mapping, it is a root mesh itself. + # A root mesh is not created through slicing. + # We considers the root mesh of a root mesh is itself. + root_mesh = self.child_to_root_mapping.get(device_mesh, None) + return device_mesh if not root_mesh else root_mesh + + def get_root_mesh_dim(self, device_mesh: "DeviceMesh") -> Optional[int]: + """ + Returns the index of the mesh dim in the root mesh. + The device_mesh passed in needs to be sliced out from the root mesh + or submesh of the root mesh. + """ + root_mesh = self.get_root_mesh(device_mesh) + child_mesh_dim_names = device_mesh.mesh_dim_names + if root_mesh and child_mesh_dim_names: + assert ( + len(child_mesh_dim_names) == 1 + ), "The submesh can only be a 1D mesh." + child_mesh_dim_name = child_mesh_dim_names[0] + return self.get_mesh_dim_by_name(root_mesh, child_mesh_dim_name) + return None + + @staticmethod + def num_devices_per_host(device_type: str) -> int: + return _get_device_handle(device_type).device_count() + + @staticmethod + def num_hosts(device_type: str) -> int: + # ProcessGroup can't tell us this info so we have to infer it, assume + # homogeneous hardware for now + return get_world_size() // _MeshEnv.num_devices_per_host(device_type) + + def get_mesh_dim_by_name( + self, device_mesh: "DeviceMesh", mesh_dim_name: str + ) -> int: + if ( + device_mesh.mesh_dim_names is None + or len(device_mesh.mesh_dim_names) == 0 + ): + raise KeyError( + "No `mesh_dim_names` found.", + ) + if mesh_dim_name not in device_mesh.mesh_dim_names: + raise KeyError( + f"Mesh dimension '{mesh_dim_name}' does not exist.", + f"Available mesh dimensions are: mesh_dim_names={device_mesh.mesh_dim_names}", + ) + return not_none(device_mesh.mesh_dim_names.index(mesh_dim_name)) + + def _set_mesh_dim_group_options( + self, + dim: int, + backend: str, + pg_options: Optional[C10dBackend.Options] = None, + ) -> None: + self.mesh_dim_group_options[dim] = (backend, pg_options) + + def _get_slice_mesh_dims( + self, device_mesh, mesh_dim_names + ) -> List[Tuple[int, ...]]: + """ + Validate whether the mesh_dim_names is valid for slicing the given device_mesh. + If valid, return dim indexes of the slice mesh in the device mesh. + """ + if device_mesh != self.get_root_mesh(device_mesh): + raise RuntimeError("Cannot create a submesh from a submesh.") + + # The slice mesh_dim_names should consist either the device_mesh's mesh_dim_names + # or its flattened mesh's mesh_dim_names. + self.flatten_name_to_root_dims.setdefault(device_mesh, {}) + flatten_name_to_root_dims = self.flatten_name_to_root_dims[device_mesh] + valid_mesh_dim_names = [ + *device_mesh.mesh_dim_names, + *flatten_name_to_root_dims, + ] + + if not all( + mesh_dim_name in valid_mesh_dim_names + for mesh_dim_name in mesh_dim_names + ): + raise KeyError( + f"Invalid mesh_dim_names {mesh_dim_names} specified. " + f"Valid mesh_dim_names are {valid_mesh_dim_names}." + ) + + # Validate the order of the slice mesh dim indices. + # This needs to be in ascending order. + curr_idx = -1 + slice_mesh_dims = [] + for mesh_dim_name in mesh_dim_names: + if mesh_dim_name in flatten_name_to_root_dims: + mesh_indices = flatten_name_to_root_dims[mesh_dim_name] + # TODO: this doesn't allow non-contiguous slicing with flatten dim yet. next_idx + # should be mesh_indices[0] once we support non-contiguous slicing with flatten dim. + next_idx = mesh_indices[-1] + slice_mesh_dims.append(mesh_indices) + else: + next_idx = device_mesh.mesh_dim_names.index(mesh_dim_name) + slice_mesh_dims.append((next_idx,)) + if next_idx <= curr_idx: + raise KeyError( + f"Invalid mesh_dim_names {mesh_dim_names} specified. ", + f"Found mesh dim indices to slice: {slice_mesh_dims}. ", + "Mesh dim indices should be in ascending order.", + ) + curr_idx = next_idx + + return slice_mesh_dims + + def _get_all_submeshes( + self, device_mesh: "DeviceMesh", mesh_dim_name: str + ) -> List["DeviceMesh"]: + """ + Return all the submeshes of a given mesh dimension of the device mesh. + """ + mesh_dim = self.get_mesh_dim_by_name(device_mesh, mesh_dim_name) + pg_ranks_by_dim = device_mesh.mesh.swapdims(-1, mesh_dim).reshape( + -1, device_mesh.mesh.size(mesh_dim) + ) + + cur_rank = device_mesh.get_rank() + res_submeshes = [] + for mesh_1d in pg_ranks_by_dim: + submesh = DeviceMesh( + device_mesh.device_type, + mesh_1d, + mesh_dim_names=(mesh_dim_name,), + _init_backend=False, + ) + submesh._dim_group_infos = ( + [device_mesh._dim_group_infos[mesh_dim]] + if cur_rank in mesh_1d + else [] + ) + res_submeshes.append(submesh) + + return res_submeshes + + _mesh_resources: _MeshEnv = _MeshEnv() + + def _get_device_handle(device_type: str = "cuda"): + """ + Get the module corresponding to the device_type which is cuda or cuda-like device. + For example, when the device_type is cuda, the module `core.cuda` is returned. + Return None when there is no corresponding module for device_type, otherwise + return the corresponding module. + """ + return getattr(torch, device_type, None) + + class DeviceMesh: + """ + DeviceMesh represents a mesh of devices, where layout of devices could be + represented as a n-d dimension array, and each value of the n-d dimensional + array is the global id of the default process group ranks. + + DeviceMesh could be used to describe the layout of devices across the cluster, + and serves as a proxy for communication among the device lists within the cluster. + + DeviceMesh can be used as a context manager. + + .. note:: + DeviceMesh follows SPMD programming model, which means the same PyTorch Python program + is running on all processes/ranks in the cluster. Therefore, users need to make sure the + `mesh` array (which describes the layout of devices) should be identical across all ranks. + Inconsistent `mesh` will lead to silent hang. + + Args: + device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like". + mesh (ndarray): A multi-dimensional array or an integer tensor describing the layout + of devices, where the IDs are global IDs of the default process group. + + Returns: + DeviceMesh: A :class:`DeviceMesh` object representing the device layout. + + The following program runs on each process/rank in an SPMD manner. In this example, we have 2 + hosts with 4 GPUs each. + A reduction over the first dimension of mesh will reduce across + columns (0, 4), .. and (3, 7), a reduction over the second dimension + of mesh reduces across rows (0, 1, 2, 3) and (4, 5, 6, 7). + + Example:: + >>> # xdoctest: +SKIP("no rank") + >>> from core.distributed.device_mesh import DeviceMesh + >>> + >>> # Initialize device mesh as (2, 4) to represent the topology + >>> # of cross-host(dim 0), and within-host (dim 1). + >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]]) + """ + + device_type: str + mesh: core.Tensor + mesh_dim_names: Optional[Tuple[str, ...]] + + def __init__( + self, + device_type: str, + mesh: Union[core.Tensor, "ArrayLike"], + *, + mesh_dim_names: Optional[Tuple[str, ...]] = None, + _init_backend: bool = True, + ) -> None: + self.device_type = device_type + if isinstance(mesh, core.Tensor) and mesh.device.type != "cpu": + raise ValueError(f"`mesh` must be a CPU tensor, got {mesh}") + self.mesh = ( + mesh.detach().to(dtype=core.int) + if isinstance(mesh, core.Tensor) + else core.tensor(mesh, device="cpu", dtype=core.int) + ) + self.mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None + + # private field to pre-generate DeviceMesh's hash + self._flatten_mesh_list = tuple(self.mesh.flatten().tolist()) + self._thread_id = None + + # Skip process group initialization if xla device or init backend is False + # TODO(yeounoh) implement DeviceMesh backend and register XLA backend. + if device_type != "xla": + # always try to create default (world) pg, even if it is not initialized + # already. The world pg is used for device mesh identity (rank) on each + # process (we need to know if the current global rank is in the mesh or not). + if _init_backend: + self._get_or_create_default_group() + self._init_process_groups() + + if is_initialized() and get_backend() == "threaded": + self._thread_id = threading.get_ident() + + # calculate the coordinates of the current global rank on the mesh + rank_coords = (self.mesh == get_rank()).nonzero() + assert rank_coords.size(0) in (0, 1) + self._coordinate_on_dim: Optional[List[int]] = ( + rank_coords[0].tolist() if rank_coords.size(0) > 0 else None + ) + + def _get_or_create_default_group(self): + default_initialized = is_initialized() + if not default_initialized: + init_process_group() + + world_size = get_world_size() + if self.mesh.numel() > world_size: + raise RuntimeError( + f"Mesh should not be bigger than default world size {world_size}, but found {self.mesh.numel()} ranks!" + ) + + device_handle = _get_device_handle(self.device_type) + # TODO: if user want to pass pg_options, offer a way to do it + if not default_initialized and device_handle: + # automatically set the current cuda/cuda-like device base on num of gpu devices available in each host + # NOTE: This device selection would only work for homogeneous hardware. + num_devices_per_host = device_handle.device_count() + if ( + world_size > num_devices_per_host + and world_size % num_devices_per_host != 0 + ): + raise RuntimeError( + f"DeviceMesh only support homogeneous hardware, but found " + f"{world_size} ranks and {num_devices_per_host} {self.device_type} devices!" + ) + device_handle.set_device(get_rank() % num_devices_per_host) + + return _get_default_group() + + def _init_process_groups(self): + # tag/ranks/group_name associated with each mesh dimension, each + # mesh dimension should have one sub-group per rank + # + # TODO(yifu): remove tag and ranks once we fully migrate to native + # functional collectives. See details in: + # https://github.com/pytorch/pytorch/issues/93173#issuecomment-1907095208 + dim_group_infos: List[Tuple[str, List[int], str]] = [] + default_group = _get_default_group() + + if self.mesh.ndim == 1 and self.mesh.numel() == get_world_size(): + # Append the default pg to the first dim groups only if the default pg is compatible with `self.device_type`. + # Otherwise, create new pg. + ranks = list(range(get_world_size())) + dim_group = ( + new_group( + backend="cpu:gloo,cuda:nccl", + ranks=ranks, + group_desc="mesh_default", + ) + if core.cuda.is_available() + and get_backend(default_group) == "gloo" + else default_group + ) + dim_group_infos.append( + ( + _get_group_tag(dim_group), + ranks, + dim_group.group_name, + ) + ) + else: + # create sub pgs base on the mesh argument specified + for dim in range(self.mesh.ndim): + # swap the current dim to the last dim + # then reshape to flatten out other dims + pg_ranks_by_dim = self.mesh.swapdims(-1, dim).reshape( + -1, self.mesh.size(dim) + ) + + # Respect dim group options specified via _MeshEnv.set_dim_group_options(). + # Inherit from the parent group if no options are specified for the group. + if dim in _mesh_resources.mesh_dim_group_options: + ( + backend, + pg_options, + ) = _mesh_resources.mesh_dim_group_options[dim] + else: + backend, pg_options = None, None + + # If we have a 2D mesh with mesh_dim_names ("dp", "tp"), the group description + # of the subgroups would be `mesh_dim_dp` and `mesh_name_tp`. + # If the mesh doesn't not have a mesh_dim_names, then the group description of the + # subgroup would be `mesh_dim_0` and `mesh_dim_1`. + group_desc = ( + f"mesh_{self.mesh_dim_names[dim]}" + if self.mesh_dim_names + else f"mesh_dim_{dim}" + ) + + # If bound_device_id exists, it means the nccl communicator has been eagerly initialized + # so that we can use `split_group` to create subgroups through `ncclCommSplit`. + # In this case, we only need to make one API call (`split_group``) for the subgroup creation + # for each mesh dimension. In a 2 * 4 mesh, we only need to make 2 API calls per ranks to create + # all the subgroups. + # Otherwise, we need to make more than one API call (`new_group`) for subgroup creations. The + # numbers of API calls are equal to the number of subgroups for each mesh dimension. In a 2 * 4 + # mesh, we need to make 2 + 4 = 6 API calls per ranks to create all the subgroups. + dim_group = None + if ( + bound_device_id := getattr( + default_group, "bound_device_id", None + ) + ) is not None: + dim_group = split_group( + parent_pg=default_group, + pg_options=pg_options, + split_ranks=pg_ranks_by_dim.tolist(), + group_desc=group_desc, + ) + + # If the subgroup has been already created through `split_group`, we simply loop over `pg_ranks_by_dim` + # and append the `(group_tag, subgroup_ranks, and group_name)` tuple to the `dim_group_infos` list when + # the current rank is in the subgroup. + # Otherwise, we use `new_group` instead of `split_group` to create subgroups by looping over `pg_ranks_by_dim` + # along with appending information to the `dim_group_infos` list whenever necessary. + for dim_mesh in pg_ranks_by_dim: + subgroup_ranks = dim_mesh.tolist() + + # We temporarily revert the re-use subgroup, since it breaks two internal tests. + # Temporarily reverting to resolve test timeout while root-causing. + # TODO: Add two tests to cover internal tests scenarios and re-enable reuse subgroup if exists. + if bound_device_id is None: + dim_group = new_group( + ranks=subgroup_ranks, + backend=backend, + pg_options=pg_options, + group_desc=group_desc, + ) + + # only add to dim_groups if the current rank in the subgroup + if self.get_rank() in subgroup_ranks: + if len(dim_group_infos) > dim: + raise RuntimeError( + f"Each device mesh dimension should get only one process group, but got {self.get_rank()} " + f"in {subgroup_ranks}!" + ) + dim_group_infos.append( + ( + _get_group_tag(not_none(dim_group)), + subgroup_ranks, + dim_group.group_name, + ) + ) + self._dim_group_infos = dim_group_infos + + def __enter__(self) -> "DeviceMesh": + # set this mesh as the current mesh in mesh env + _mesh_resources.mesh_stack.append(self) + return self + + # pyre-fixme[2]: Parameter must be annotated. + def __exit__(self, exc_type, exc_value, exc_traceback) -> None: + # pop this mesh from mesh env + _mesh_resources.mesh_stack.pop() + + def __repr__(self) -> str: + device_mesh_repr = ( + f"DeviceMesh('{self.device_type}', {self.mesh.tolist()})" + if not self.mesh_dim_names + else f"DeviceMesh('{self.device_type}', {self.mesh.tolist()}, mesh_dim_names={self.mesh_dim_names})" + ) + return device_mesh_repr + + def __hash__(self): + # lazily compute hash + self._hash = getattr(self, "_hash", None) + if not self._hash: + self._hash = hash( + ( + self._flatten_mesh_list, + self.mesh.shape, + self.device_type, + self.mesh_dim_names, + self._thread_id, + ) + ) + return self._hash + + def __eq__(self, other: object) -> bool: + if not isinstance(other, DeviceMesh): + return False + if id(self) == id(other): + return True + else: + return ( + self._flatten_mesh_list == other._flatten_mesh_list + and self.mesh.shape == other.mesh.shape + and self.device_type == other.device_type + and self.mesh_dim_names == other.mesh_dim_names + and self._thread_id == other._thread_id + ) + + def __getitem__( + self, mesh_dim_names: Union[str, Tuple[str, ...]] + ) -> "DeviceMesh": + """ + Slice the current DeviceMesh based on the mesh_dim_names given to create a submesh. + The submesh created consists of the dimensions and the communicators indicated by + ``mesh_dim_names`` + + Args: + mesh_dim_names (Union[str, Tuple[str]]): the name or the tuple of names of the + mesh dimension of the DeviceMesh to create the submesh for. + Returns: + A :class:`DeviceMesh` object + + The following program runs on each process/rank in an SPMD manner in a world size of 8. + In the first example: + Calling mesh_2d["tp"] on rank 0, 1, 2, 3 returns a 1D submesh of DeviceMesh:([0, 1, 2, 3]). + Calling mesh_2d["tp"] on rank 4, 5, 6, 7 returns a 1D submesh of DeviceMesh:([4, 5, 6, 7]). + Calling mesh_2d["dp"] on rank 0, 4 returns a 1D submesh of DeviceMesh:([0, 4]). + Calling mesh_2d["dp"] on rank 1, 5 returns a 1D submesh of DeviceMesh:([1, 5]). + Calling mesh_2d["dp"] on rank 2, 6 returns a 1D submesh of DeviceMesh:([2, 6]). + Calling mesh_2d["dp"] on rank 3, 7 returns a 1D submesh of DeviceMesh:([3, 7]). + + In the second example: + Calling mesh_3d["dp", "cp"] on rank 0, 1, 4, 5 returns a 2D submesh of DeviceMesh:([[0, 1], [4, 5]]). + Calling mesh_3d["dp", "cp"] on rank 2, 3, 6, 7 returns a 2D submesh of DeviceMesh:([[2, 3], [6, 7]]). + Calling mesh_3d["cp", "dp"] on rank 0, 1, 4, 5 returns a 2D submesh of DeviceMesh:([[0, 4], [1, 5]]). + Calling mesh_3d["cp", "dp"] on rank 2, 3, 6, 7 returns a 2D submesh of DeviceMesh:([[2, 6], [3, 7]]). + + Example:: + >>> # xdoctest: +SKIP("no rank") + >>> from core.distributed.device_mesh import DeviceMesh + >>> + >>> # Initialize a 2D device mesh as (2, 4) to represent the topology + >>> # of cross-host(dim 0), and within-host (dim 1). + >>> mesh_2d = init_device_mesh(device_type="cuda", (2,4), mesh_dim_names=("dp", "tp")) + >>> tp_mesh = mesh_2d["tp"] + >>> dp_mesh = mesh_2d["dp"] + >>> + >>> # Initialize a 3D mesh. + >>> mesh_3d = init_device_mesh(device_type="cuda", (2,2,2), mesh_dim_names=("dp", "pp", "cp")) + >>> # The order of the mesh_dim_names provided deteremines the order of dimensions in the submesh. + >>> dp_cp_mesh = mesh_3d["dp", "cp"] + >>> cp_dp_mesh = mesh_3d["cp", "dp"] + """ + if not self.mesh_dim_names: + raise RuntimeError("Cannot slice a DeviceMesh without mesh_dim_names!") + + mesh_dim_names = ( + (mesh_dim_names,) if isinstance(mesh_dim_names, str) else mesh_dim_names + ) + + if mesh_dim_names == self.mesh_dim_names: + return self + else: + slice_mesh_dims = _mesh_resources._get_slice_mesh_dims( + self, mesh_dim_names + ) + submesh = _mesh_resources.create_sub_mesh( + self, mesh_dim_names, slice_mesh_dims + ) + return submesh + + def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> ProcessGroup: + """ + Returns the single ProcessGroup specified by mesh_dim, or, if mesh_dim is not specified and the + DeviceMesh is 1-dimensional, returns the only ProcessGroup in the mesh. + + Args: + mesh_dim (str/int, optional): it can be the name of the mesh dimension or the index + of the mesh dimension. Default is None. + + Returns: + A :class:`ProcessGroup` object. + """ + if not hasattr(self, "_dim_group_infos"): + raise RuntimeError("DeviceMesh process groups not initialized!") + + if self.mesh.ndim > 1 and mesh_dim is None: + raise RuntimeError( + f"Found the DeviceMesh have {self.mesh.ndim} dimensions", + "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.", + "If you want to get the list of all the ProcessGroups in the DeviceMesh," + "please use `get_all_groups()` instead.", + ) + + # Quick return if the current device_mesh is a 1D mesh. + if self.mesh.ndim == 1 and mesh_dim is None: + return not_none( + _find_pg_by_ranks_and_tag(*self._dim_group_infos[0][:2]) # type: ignore[index] + ) + + root_mesh = _mesh_resources.get_root_mesh(self) + root_to_flatten_mapping = _mesh_resources.root_to_flatten_mapping.get( + root_mesh, None + ) + if root_to_flatten_mapping and mesh_dim in root_to_flatten_mapping.keys(): + dim_group_infos = root_to_flatten_mapping[mesh_dim]._dim_group_infos[0][:2] # type: ignore[index] + return not_none(_find_pg_by_ranks_and_tag(*dim_group_infos)) + else: + mesh_dim = ( + _mesh_resources.get_mesh_dim_by_name(self, mesh_dim) + if isinstance(mesh_dim, str) + else mesh_dim + ) + return not_none( + _find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim][:2]) # type: ignore[index] + ) + + def get_all_groups(self) -> List[ProcessGroup]: + """ + Returns a list of ProcessGroups for all mesh dimensions. + + Returns: + A list of :class:`ProcessGroup` object. + """ + return [self.get_group(i) for i in range(self.mesh.ndim)] + + @staticmethod + def from_group( + group: Union[ProcessGroup, List[ProcessGroup]], + device_type: str, + mesh: Optional[Union[core.Tensor, "ArrayLike"]] = None, + *, + mesh_dim_names: Optional[Tuple[str, ...]] = None, + ) -> "DeviceMesh": + """ + Constructs a :class:`DeviceMesh` with ``device_type`` from an + existing :class:`ProcessGroup`. + + The constructed device mesh has number of dimensions equal to the + number of groups passed. If more than one group is passed, then the + ``mesh`` argument is required. + """ + if isinstance(group, ProcessGroup): + group_ranks = get_process_group_ranks(group) + if ( + isinstance(mesh, core.Tensor) and mesh.tolist() != group_ranks + ) or ( + mesh is not None + and not isinstance(mesh, core.Tensor) + and mesh != group_ranks + ): + raise ValueError( + f"Invalid mesh {str(mesh)} for ProcessGroup with ranks {group_ranks}" + ) + mesh = core.tensor(group_ranks, device="cpu", dtype=core.int) + device_mesh = DeviceMesh( + device_type, + mesh, + mesh_dim_names=mesh_dim_names, + _init_backend=False, + ) + device_mesh._dim_group_infos = [ + (_get_group_tag(group), group_ranks, group.group_name) + ] + return device_mesh + groups = list(group) + if len(groups) == 0: + raise ValueError("Expects at least one ProcessGroup to be passed") + if mesh is None: + raise ValueError("Must pass mesh if passing multiple ProcessGroups") + mesh = ( + mesh.detach().to(dtype=core.int, device="cpu") + if isinstance(mesh, core.Tensor) + else core.tensor(mesh, device="cpu", dtype=core.int) + ) + if mesh.ndim != len(groups): + raise ValueError( + "Expects mesh with ndim equal to number of ProcessGroups but got " + f"mesh {mesh.tolist()} and {len(groups)} ProcessGroups" + ) + device_mesh = DeviceMesh( + device_type, mesh, mesh_dim_names=mesh_dim_names, _init_backend=False + ) + device_mesh._dim_group_infos = [ + ( + _get_group_tag(group), + get_process_group_ranks(group), + group.group_name, + ) + for group in groups + ] + return device_mesh + + def size(self, mesh_dim: Optional[int] = None) -> int: + return self.mesh.numel() if mesh_dim is None else self.mesh.size(mesh_dim) + + @property + def ndim(self) -> int: + return self.mesh.ndim + + @property + def shape(self) -> Tuple[int, ...]: + return tuple(self.mesh.shape) + + def get_rank(self) -> int: + """ + Returns the current global rank. + """ + return get_rank() + + def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int: + """ + Returns the local rank of the given mesh_dim of the DeviceMesh. + + Args: + mesh_dim (str/int, optional): it can be the name of the mesh dimension or the index + of the mesh dimension. Default is None. + + Returns: + An integer denotes the local rank. + + The following program runs on each process/rank in an SPMD manner. In this example, we have 2 + hosts with 4 GPUs each. + Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 0, 1, 2, 3 would return 0. + Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 4, 5, 6, 7 would return 1. + Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 0, 4 would return 0. + Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 1, 5 would return 1. + Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 2, 6 would return 2. + Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 3, 7 would return 3. + + Example:: + >>> # xdoctest: +SKIP("no rank") + >>> from core.distributed.device_mesh import DeviceMesh + >>> + >>> # Initialize device mesh as (2, 4) to represent the topology + >>> # of cross-host(dim 0), and within-host (dim 1). + >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]]) + """ + if self.ndim > 1 and mesh_dim is None: + raise RuntimeError( + f"Found the DeviceMesh have {self.mesh.ndim} dimensions", + "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.", + ) + elif mesh_dim is None: + mesh_dim = 0 + + mesh_dim_group = not_none(self.get_group(mesh_dim)) + assert isinstance( + mesh_dim_group, ProcessGroup + ), "We expect ProcessGroup before calling `get_rank`!" + return not_none(get_rank(mesh_dim_group)) + + def get_coordinate(self) -> Optional[List[int]]: + """ + Return the relative indices of this rank relative to all + dimensions of the mesh. If this rank is not part of the mesh, return None. + """ + return self._coordinate_on_dim if self._coordinate_on_dim else None + + def _flatten(self, mesh_dim_name: Optional[str] = None) -> "DeviceMesh": + """ + Returns a 1D DeviceMesh by flattening the current DeviceMesh. + + If no mesh_dim_name is provided, the default is a string concatentaing the mesh_dim_names of the + given submesh with each mesh_dim_name separated by "_". For example, if we have a 3D mesh + DeviceMesh([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], mesh_dim_names=("dp", "cp", "tp")), calling + mesh_3d["dp", "cp"]._flatten() will create a 1D submesh DeviceMesh([0, 1, 2, 3], mesh_dim_names=("dp_cp",)) + on rank 0, 1, 2, 3 and a 1D submesh DeviceMesh([4, 5, 6, 7], mesh_dim_names=("dp_cp",)) on rank 4, 5, 6, 7. + + After the flattened dimension is created, to access the flattened dimesnion in mesh_3d, one can use the + existing slicing method to obtain the flattened mesh through calling mesh_3d["dp_cp"]. + """ + if not self.mesh_dim_names: + raise RuntimeError( + "Cannot flatten a DeviceMesh without mesh_dim_names!" + ) + + return _mesh_resources.create_flatten_mesh(self, mesh_dim_name) + + def init_device_mesh( + device_type: str, + mesh_shape: Tuple[int, ...], + *, + mesh_dim_names: Optional[Tuple[str, ...]] = None, + ) -> DeviceMesh: + """ + Initializes a `DeviceMesh` based on `device_type`, `mesh_shape`, and `mesh_dim_names` parameters. + + This creates a DeviceMesh with an n-dimensional array layout, where `n` is the length of `mesh_shape`. + If `mesh_dim_names` is provided, each dimension is labeled as `mesh_dim_names[i]`. + + .. note:: + `init_device_mesh` follows SPMD programming model, meaning the same PyTorch Python program + runs on all processes/ranks in the cluster. Ensure `mesh_shape` (the dimensions of the nD array + describing device layout) is identical across all ranks. Inconsistent `mesh_shape` may lead to hanging. + + .. note:: + If no process group is found, init_device_mesh will initialize distributed process group/groups + required for distributed communications behind the scene. + + Args: + device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like". + Passing in a device type with a GPU index, such as "cuda:0", is not allowed. + mesh_shape (Tuple[int]): A tuple defining the dimensions of the multi-dimensional array + describing the layout of devices. + mesh_dim_names (Tuple[str], optional): A tuple of mesh dimension names to assign to each dimension + of the multi-dimensional array describing the layout of devices. Its length must match the length + of `mesh_shape`. Each string in `mesh_dim_names` must be unique. + + Returns: + DeviceMesh: A :class:`DeviceMesh` object representing the device layout. + + Example:: + >>> # xdoctest: +SKIP("no rank") + >>> from core.distributed.device_mesh import init_device_mesh + >>> + >>> mesh_1d = init_device_mesh("cuda", mesh_shape=(8,)) + >>> mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp")) + + """ + if mesh_dim_names is not None: + if len(set(mesh_dim_names)) != len(mesh_dim_names): + raise RuntimeError( + "Each mesh_dim_name must be unique.", + f"Found repeated mesh_dim_name in mesh_dim_names {mesh_dim_names}", + ) + + if len(mesh_shape) != len(mesh_dim_names): + raise RuntimeError( + "mesh_shape and mesh_dim_names should have same length!", + f"Found len(mesh_dim_names): {len(mesh_dim_names)} and len(mesh_shape):{len(mesh_shape)}.", + ) + + # assume valid device types are all letters + if device_type and not device_type.isalpha(): + raise RuntimeError( + f"Device type with GPU index is not supported but got {device_type}. ", + "If you maintained a 'core.device' object, it's recommended to pass in 'device.type'.", + ) + + # Always initialize the mesh's tensor on CPU, regardless of what the + # external device type has been set to be (e.g. meta) + with core.device("cpu"): + mesh = core.arange(math.prod(mesh_shape), dtype=core.int).view(mesh_shape) + device_mesh = DeviceMesh( + device_type=device_type, + mesh=mesh, + mesh_dim_names=mesh_dim_names, + ) + + return device_mesh diff --git a/mindnlp/core/distributed/distributed_c10d.py b/mindnlp/core/distributed/distributed_c10d.py new file mode 100644 index 000000000..a505e81ca --- /dev/null +++ b/mindnlp/core/distributed/distributed_c10d.py @@ -0,0 +1,4978 @@ +# mypy: allow-untyped-defs +"""Distributed Collective Communication (c10d).""" + +import collections.abc +import contextlib +import ctypes +import hashlib +import io +import itertools +import logging +import os +import pickle +import sys +import time +import warnings +from collections import namedtuple +from datetime import timedelta +from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union +import mindspore.communication._comm_helper +from typing_extensions import deprecated + +import numpy as np +import mindspore +from mindspore.communication import init, GlobalComm, get_group_size, get_process_group_ranks as _get_group_ranks, \ + create_group, get_rank as _get_rank + +from mindnlp import core +# from core._C import _DistStoreError as DistStoreError +from .c10d import ( + # _DistributedBackendOptions, + # _register_process_group, + # _resolve_process_group, + # _unregister_all_process_groups, + # _unregister_process_group, + AllgatherOptions, + AllreduceCoalescedOptions, + AllreduceOptions, + AllToAllOptions, + BarrierOptions, + BroadcastOptions, + # DebugLevel, + GatherOptions, + # get_debug_level, + PrefixStore, + ProcessGroup, + ReduceOp, + ReduceOptions, + ReduceScatterOptions, + ScatterOptions, + Store, + Work, +) +# from core._utils_internal import set_pytorch_distributed_envs_from_justknobs +# from core.monitor import _WaitCounter +from mindnlp.core.utils._typing_utils import not_none + +from .c10d_logger import _exception_logger, _time_logger +from .constants import default_pg_nccl_timeout, default_pg_timeout +# from .rendezvous import register_rendezvous_handler, rendezvous # noqa: F401 + + +__all__ = [ + "Backend", + "BackendConfig", + "GroupMember", + "P2POp", + "all_gather", + "all_gather_coalesced", + "all_gather_object", + "all_reduce", + "all_reduce_coalesced", + "all_to_all", + "all_to_all_single", + "barrier", + "batch_isend_irecv", + "broadcast", + "send_object_list", + "recv_object_list", + "broadcast_object_list", + "destroy_process_group", + "gather", + "gather_object", + "get_backend_config", + "get_backend", + "get_rank", + "get_world_size", + "get_pg_count", + "group", + "init_process_group", + "irecv", + "is_gloo_available", + "is_initialized", + "is_mpi_available", + "is_backend_available", + "is_nccl_available", + "is_torchelastic_launched", + "is_ucc_available", + "isend", + "monitored_barrier", + "new_group", + "new_subgroups", + "new_subgroups_by_enumeration", + "recv", + "reduce", + "reduce_scatter", + "scatter", + "scatter_object_list", + "send", + "supports_complex", + "AllreduceCoalescedOptions", + "AllreduceOptions", + "AllToAllOptions", + "BarrierOptions", + "BroadcastOptions", + "GatherOptions", + "PrefixStore", + "ProcessGroup", + "ReduceOp", + "ReduceOptions", + "ReduceScatterOptions", + "ScatterOptions", + "Store", + # "DebugLevel", + # "get_debug_level", + "Work", + "default_pg_timeout", + "get_group_rank", + "get_global_rank", + "get_process_group_ranks", + "all_gather_into_tensor", + "reduce_scatter_tensor", + "get_node_local_rank", + "split_group", +] + +_MPI_AVAILABLE = True +_NCCL_AVAILABLE = True +_GLOO_AVAILABLE = True +_UCC_AVAILABLE = True + +_pickler = pickle.Pickler +_unpickler = pickle.Unpickler + + +try: + from core._C._distributed_c10d import ProcessGroupMPI + + ProcessGroupMPI.__module__ = "core.distributed.distributed_c10d" + __all__ += ["ProcessGroupMPI"] +except ImportError: + _MPI_AVAILABLE = False + +try: + from core._C._distributed_c10d import ProcessGroupNCCL + + ProcessGroupNCCL.__module__ = "core.distributed.distributed_c10d" + __all__ += ["ProcessGroupNCCL"] +except ImportError: + _NCCL_AVAILABLE = False + +try: + from core._C._distributed_c10d import _ProcessGroupWrapper, ProcessGroupGloo + + ProcessGroupGloo.__module__ = "core.distributed.distributed_c10d" + __all__ += ["ProcessGroupGloo"] +except ImportError: + _GLOO_AVAILABLE = False + +try: + from core._C._distributed_c10d import ProcessGroupUCC + + ProcessGroupUCC.__module__ = "core.distributed.distributed_c10d" + __all__ += ["ProcessGroupUCC"] +except ImportError: + _UCC_AVAILABLE = False + +logger = logging.getLogger(__name__) + +PG_WRAPPER_STORE_PREFIX = "pg_wrapper" + + +# Some reduce ops are not supported by complex numbers and will result in an error. +# We currently provide complex support to the distributed API by viewing +# complex tensors as real (core.view_as_real), meaning that calling +# these unsupported ops will return garbage values rather than error out. +# (e.g. max(2+3i, 3+2i) = 3+3i) +# We'd like calls to unsupported ops to error out accordingly, +# rather than returning garbage values. +def supports_complex(reduceOp: ReduceOp) -> bool: + """Return true if reduce ops is supported. False otherwise.""" + denyList = [ + ReduceOp.MAX, + ReduceOp.MIN, + ReduceOp.PRODUCT, + ReduceOp.BAND, + ReduceOp.BOR, + ReduceOp.BXOR, + ] + return reduceOp not in denyList + + +class Backend(str): + """ + An enum-like class for backends. + + Available backends: GLOO, NCCL, UCC, MPI, and other registered backends. + + The values of this class are lowercase strings, e.g., ``"gloo"``. They can + be accessed as attributes, e.g., ``Backend.NCCL``. + + This class can be directly called to parse the string, e.g., + ``Backend(backend_str)`` will check if ``backend_str`` is valid, and + return the parsed lowercase string if so. It also accepts uppercase strings, + e.g., ``Backend("GLOO")`` returns ``"gloo"``. + + .. note:: The entry ``Backend.UNDEFINED`` is present but only used as + initial value of some fields. Users should neither use it directly + nor assume its existence. + """ + + UNDEFINED = "undefined" + GLOO = "gloo" + NCCL = "nccl" + UCC = "ucc" + MPI = "mpi" + + _BackendPlugin = namedtuple("_BackendPlugin", ["creator_fn", "extended_api"]) + + _plugins: Dict[str, _BackendPlugin] = {} + + backend_list = [UNDEFINED, GLOO, NCCL, UCC, MPI] + + # 3rd-party devices can register the default backend support here + default_device_backend_map: Dict[str, str] = { + "cpu": GLOO, + "cuda": NCCL, + } + + backend_capability: Dict[str, List[str]] = { + GLOO: ["cpu", "cuda"], + NCCL: ["cuda"], + UCC: ["cpu", "cuda"], + MPI: ["cpu", "cuda"], + } + + backend_type_map: Dict[str, ProcessGroup.BackendType] = { + UNDEFINED: ProcessGroup.BackendType.UNDEFINED, + GLOO: ProcessGroup.BackendType.GLOO, + NCCL: ProcessGroup.BackendType.NCCL, + UCC: ProcessGroup.BackendType.UCC, + MPI: ProcessGroup.BackendType.MPI, + } + + def __new__(cls, name: str): + """Create and return a new instance of the class.""" + if not isinstance(name, str): + raise ValueError("Backend constructor parameter must be string-ish") + value = getattr(Backend, name.upper(), Backend.UNDEFINED) + + if value == Backend.UNDEFINED: + value = name.lower() + return value + + @classmethod + def register_backend( + cls, + name, + func, + extended_api=False, + devices: Optional[Union[str, List[str]]] = None, + ) -> None: + """ + Register a new backend with the given name and instantiating function. + + This class method is used by 3rd party ``ProcessGroup`` extension to + register new backends. + + Args: + name (str): Backend name of the ``ProcessGroup`` extension. It + should match the one in ``init_process_group()``. + func (function): Function handler that instantiates the backend. + The function should be implemented in the backend + extension and takes four arguments, including + ``store``, ``rank``, ``world_size``, and ``timeout``. + extended_api (bool, optional): Whether the backend supports extended argument structure. + Default: ``False``. If set to ``True``, the backend + will get an instance of ``c10d::DistributedBackendOptions``, and + a process group options object as defined by the backend implementation. + device (str or list of str, optional): device type this backend + supports, e.g. "cpu", "cuda", etc. If `None`, + assuming both "cpu" and "cuda" + + .. note:: This support of 3rd party backend is experimental and subject to change. + + """ + # Allow UCC plugin if Pytorch is not built with native support. + # TODO: remove this exception once UCC plugin is fully deprecated. + if name != Backend.UCC or (name == Backend.UCC and is_ucc_available()): + assert not hasattr( + Backend, name.upper() + ), f"{name.upper()} c10d backend already exist" + assert ( + name.upper() not in Backend._plugins + ), f"{name.upper()} c10d backend creator function already exist" + + setattr(Backend, name.upper(), name.lower()) + Backend.backend_list.append(name.lower()) + if devices is not None: + for device in devices: + if device != "cpu" and device != "cuda": + Backend.default_device_backend_map[device] = name.lower() + Backend.backend_type_map[name.lower()] = ProcessGroup.BackendType.CUSTOM + + # Update device capability matrix in Backend class + if devices is None: + # This is more of a backward support for groups like `threaded`: + # assume default devices "cpu" and "cuda", but warn + warnings.warn( + f"Device capability of {name} unspecified, assuming `cpu` and " + "`cuda`. Please specify it via the `devices` argument of " + "`register_backend`." + ) + Backend.backend_capability[name.lower()] = ["cpu", "cuda"] + elif isinstance(devices, str): + # Single device string specified. Simply convert to list. + Backend.backend_capability[name.lower()] = [devices] + else: + Backend.backend_capability[name.lower()] = devices + + Backend._plugins[name.upper()] = Backend._BackendPlugin(func, extended_api) + + +class BackendConfig: + """Backend configuration class.""" + + def __init__(self, backend: Backend): + """Init.""" + self.device_backend_map: Dict[str, Backend] = {} + backend = str(backend) + + if backend == Backend.UNDEFINED: + # default config when backend is not specified + # supported since PyTorch 2.0 + for device, default_backend in Backend.default_device_backend_map.items(): + if is_backend_available(default_backend): + if ( + default_backend == Backend.NCCL + and not core.cuda.is_available() + ): + continue + self.device_backend_map[device] = Backend(default_backend) + elif backend.lower() in Backend.backend_list: + # Cases for when backend is a single string (without device types) + # e.g. "nccl", "gloo", "ucc", "mpi" + supported_devices = Backend.backend_capability[backend.lower()] + backend_val = Backend(backend) + self.device_backend_map = dict.fromkeys(supported_devices, backend_val) + elif ":" in backend.lower(): + # Backend specified in "device:backend" format + # make sure the backend string is in the correct format + # "{device_type1}:{backend1},{device_type2}:{backend2}" + # e.g. "cpu:gloo,cuda:nccl" + backend_str_error_message = f"""The custom backend string argument is invalid: {backend}. + Custom backend string is an experimental feature where the backend string must be in the format: + ":,:...". e.g. 'cpu:gloo,cuda:nccl'""" + + # parse the backend string and populate the device_backend_map + for device_backend_pair_str in backend.lower().split(","): + device_backend_pair = device_backend_pair_str.split(":") + if len(device_backend_pair) != 2: + raise ValueError( + f"Invalid device:backend pairing: \ + {device_backend_pair_str}. {backend_str_error_message}" + ) + device, backend = device_backend_pair + if device in self.device_backend_map: + raise ValueError( + f"Duplicate device type {device} \ + in backend string: {backend}. {backend_str_error_message}" + ) + self.device_backend_map[device] = Backend(backend) + else: + # User specified a single backend name whose device capability is + # unknown, assuming it can support the default devices of PyTorch + # (cpu and cuda) + warnings.warn( + f"Device capability of {backend} unknown, assuming `cpu` and " + "`cuda`. You can specify it in `device:backend` format in " + "`init_process_group` call." + ) + backend_val = Backend(backend) + self.device_backend_map = { + "cpu": backend_val, + "cuda": backend_val, + "xpu": backend_val, + } + + logger.info("Using backend config: %s", self.device_backend_map) + + def __repr__(self): + """Return all the device:backend pairs separated by commas.""" + return ",".join( + f"{device}:{backend}" for device, backend in self.device_backend_map.items() + ) + + def get_device_backend_map(self) -> Dict[str, Backend]: + """Return backend map of the device.""" + return self.device_backend_map + + +class P2POp: + """ + A class to build point-to-point operations for ``batch_isend_irecv``. + + This class builds the type of P2P operation, communication buffer, peer rank, + Process Group, and tag. Instances of this class will be passed to + ``batch_isend_irecv`` for point-to-point communications. + + Args: + op (Callable): A function to send data to or receive data from a peer process. + The type of ``op`` is either ``core.distributed.isend`` or + ``core.distributed.irecv``. + tensor (Tensor): Tensor to send or receive. + peer (int, optional): Destination or source rank. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + tag (int, optional): Tag to match send with recv. + group_peer (int, optional): Destination or source rank. + """ + + def __init__( + self, + op: Callable, + tensor: core.Tensor, + peer: Optional[int] = None, + group: Optional[ProcessGroup] = None, + tag: int = 0, + group_peer: Optional[int] = None, + ): + """Init.""" + self.op = op + self.tensor = tensor + self.group = _group_or_default_group(group) + self.peer = _canonicalize_group_rank( + self.group, peer, group_peer, return_global=True + ) + self.tag = tag + self.group_peer = _canonicalize_group_rank(self.group, peer, group_peer) + + def __new__( + cls, + op: Callable, + tensor: core.Tensor, + peer: Optional[int] = None, + group: Optional[ProcessGroup] = None, + tag: int = 0, + group_peer: Optional[int] = None, + ): + """Create and return a new instance of the class.""" + _check_op(op) + _check_single_tensor(tensor, "tensor") + + return object.__new__(cls) + + def __repr__(self): + my_group_rank = get_rank(self.group) + op_name = self.op.__name__ + group_name = self.group.group_name if self.group else "default_pg" + if "send" in op_name: + s = my_group_rank + d = self.group_peer + elif "recv" in op_name: + s = self.group_peer + d = my_group_rank + else: + return super().__repr__() + + return f"P2POp({op_name} pg={group_name}, group_src={s}, group_dst={d}, {self.tensor.shape}, {self.tensor.dtype})" + + +class _CollOp: + """ + A class to capture collective operations. + + Args: + op (Callable): A collective function, e.g. ``core.distributed.all_reduce``. + tensor (Tensor): Tensor to operate on. + dst_tensor (Tensor, optional): Provided when source and destinaton tensors are not the same. + redop (ReduceOp, optional): reduce operation. + root (int, optional): root of broadcast or reduce. + """ + + def __init__( + self, + op: Callable, + tensor: core.Tensor, + dst_tensor: Optional[core.Tensor] = None, + redop: Optional[ReduceOp] = None, + root: Optional[int] = None, + ): + self.op = op + self.tensor = tensor + self.dst_tensor = dst_tensor + self.redop = redop + self.root = root + + +# DO NOT USE THESE FIELDS DIRECTLY. +# Use them through the _world object to make sure the _world override mechanism +_pg_map: Dict[ProcessGroup, Tuple[str, Store]] = {} +_pg_names: Dict[ProcessGroup, str] = {} +_pg_group_ranks: Dict[ProcessGroup, Dict[int, int]] = {} +# For a pg, it is a map from ProcessGroup to BackendConfig +_pg_backend_config: Dict[ProcessGroup, str] = {} +_group_count = 0 +_tags_to_pg: Dict[str, List[ProcessGroup]] = {} +_pg_to_tag: Dict[ProcessGroup, str] = {} +_backend: Optional[str] = None + + +class _World: + """ + Container class for c10d process group state. + + This is used during registration and lookup of PG state. + + .. warning:: This is an experimental API intended to expose the inner workings + of c10d and is subject to change.. + """ + + def __init__(self) -> None: + self._default_pg = None + self._pg_coalesce_state: Dict[ProcessGroup, List[_CollOp]] = {} + + @property + def default_pg(self) -> Optional[ProcessGroup]: + """ + Process group that includes all ranks of the cluster. + + This default ProcessGroup is used by c10d APIs when a ProcessGroup is needed + but None is provided. + """ + return self._default_pg + + @default_pg.setter + def default_pg(self, value) -> None: + self._default_pg = value + + @property + def pg_map(self) -> Dict[ProcessGroup, Tuple[str, Store]]: + """ + Provide Mapping from ProcessGroup to backend name and store. + + For NCCL and GLOO pg, it is a map from ProcessGroup to (Backend, Store) + For MPI pg, it is a map from ProcessGroup to (Backend, None) + + TODO don't expose the map, expose fine grained ops + """ + global _pg_map + return _pg_map + + @property + def pg_names(self) -> Dict[ProcessGroup, str]: + """ + Process group's names, map from ProcessGroup to str. + + TODO don't expose the map, expose fine grained ops + """ + global _pg_names + return _pg_names + + @property + def pg_group_ranks(self) -> Dict[ProcessGroup, Dict[int, int]]: + """ + Process group's global rank to local rank mapping. + + TODO don't expose the map, expose fine grained ops + """ + global _pg_group_ranks + return _pg_group_ranks + + @property + def pg_backend_config(self) -> Dict[ProcessGroup, str]: + """ + Process group's backend config. + + TODO don't expose the map, expose fine grained ops + """ + global _pg_backend_config + return _pg_backend_config + + @property + def group_count(self) -> int: + """ + Process group count for default naming. + + TODO don't expose group_count, use something else instead + """ + global _group_count + return _group_count + + @group_count.setter + def group_count(self, value: int) -> None: + """Use to compute the name of ProcessGroups when using global synchronization.""" + global _group_count + _group_count = value + + @property + def tags_to_pg(self) -> Dict[str, List[ProcessGroup]]: + global _tags_to_pg + return _tags_to_pg + + @property + def pg_to_tag(self) -> Dict[ProcessGroup, str]: + global _pg_to_tag + return _pg_to_tag + + @property + def pg_coalesce_state(self) -> Dict[ProcessGroup, List[_CollOp]]: + return self._pg_coalesce_state + + @property + def pg_config_info(self) -> List[Dict[str, Any]]: + """ + Return a list of dict with process groups and backends. + + Along with their unique IDs and configurations (types and ranks). + """ + config_info: List[Dict[str, Any]] = [] + default_pg_size = _get_group_size(None) + for pg in self.pg_map.keys(): + ranks = self.pg_group_ranks[pg] + config_info.append( + { + "pg_name": self.pg_names[pg], + "pg_desc": pg.group_desc, + "backend_config": self.pg_backend_config[pg], + "ranks": ( + list(ranks.keys()) if len(ranks) != default_pg_size else [] + ), # 'ranks' is an empty list when all ranks are involved in a pg + "group_size": len(ranks), + "group_count": self.group_count, + } + ) + return config_info + + +_world = _World() +"""Holds the singleton instance of ``_World`` used by c10. Experimental extension point to override it""" + + +class _WorldMeta(type): + """ + Meta class of ``group`` and ``GroupMember``. + + Allows them to have the class property ``WORLD``. + """ + + # Points to the default PG once initialized. + @property + def WORLD(cls) -> Optional[ProcessGroup]: + return _world.default_pg + + @WORLD.setter + def WORLD(cls, pg: Optional[ProcessGroup]): + _world.default_pg = pg + + +class group(metaclass=_WorldMeta): + """Group class. Placeholder.""" + + +class GroupMember(metaclass=_WorldMeta): + """Group member class.""" + + NON_GROUP_MEMBER = -100 + + +def _get_default_timeout(backend: Backend) -> timedelta: + # see note on nccl vs other backend timeout (constants.py) + if backend == Backend.NCCL: + if not isinstance(default_pg_nccl_timeout, timedelta): + # TODO moco benchmark on CPU initializes pgnccl backend today, triggered this assert in CI before it was + # changed to be a warning. We should fix the moco model. + warnings.warn( + "Attempted to get default timeout for nccl backend, but NCCL support is not compiled" + ) + return default_pg_timeout + return default_pg_nccl_timeout + else: + return default_pg_timeout + + +def _check_valid_timeout(timeout: Any) -> None: + if not isinstance(timeout, timedelta): + raise TypeError( + f"Expected timeout argument to be of type datetime.timedelta, got {timeout}" + ) + + +# Default process group state +_default_pg_init_method: Optional[str] = None + +STORE_BASED_BARRIER_PREFIX = "store_based_barrier_key" + + +def _get_object_coll_device(group: Optional[ProcessGroup] = None) -> str: + """ + .. note:: This is an internal helper and does not have backward + compatibility, please use with caution. + + Return the device type to use with ``group`` for object collectives or + barrier. + + There are selection rules: + 1. If user specifies exactly one backend in ``init_process_group`` call: + use that backend + 2. Else if user specifies multiple "device:backend" pairs in init_process_group: + If "cpu" is among those pairs, use "cpu" (because the object is in cpu memory); + Otherwise, use the first backend (sort of a random pick). + + Args: + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + + Returns: + str: The device type to use for object collective with ``group``. + + """ + group = group or _get_default_group() + + if not isinstance(group, ProcessGroup): + warnings.warn( + f"You are using a Backend {type(group)} as a ProcessGroup. " + "This usage is deprecated since PyTorch 2.0. Please use a public API " + "of PyTorch Distributed instead.", + ) + # Provide backward compatibility to cases where `group` passed in is + # actually a Backend (like `ProcessGroupGloo`) rather than a + # `ProcessGroup` in PT 2.0 sense + if isinstance(group, ProcessGroupGloo): + # RPC uses Gloo for object collectives + return "cpu" + else: + raise ValueError(f"Expecting a ProcessGroup, but got a {type(group)}.") + + """ + ``group._device_types`` is a property pybind that returns the devices + ("cpu", "cuda", etc) supported by ``group``. Can be multiple if the + ``group`` supports multiple devices. + """ + devices = group._device_types + + if len(devices) == 1: + # User fixed exactly one backend in `init_process_group` + return devices[0].type + elif len(devices) == 0: + # No backend has been registered with this PG (maybe because no + # collective has been run?) We pick cpu as the default and hopefully + # this would lazily init Gloo or other available cpu backend. + return "cpu" + elif core.device("cpu") in devices: + # There are multiple backends in this PG and cpu is among them. + # cpu is preferred as the object is in cpu memory. No need for device + # copy. + return "cpu" + else: + # No cpu in the backend list. Randomly pick the first backend + return devices[0].type + + +def _device_capability(group: Optional[ProcessGroup] = None) -> List[str]: + """ + Return the device type(s) supported by ``group``. + + Args: + group (ProcessGroup, optional): The process group to query. If None, + the default process group will be used. + + Returns: + List[str]: A list of device types supported by ``group``. + """ + group = group or _get_default_group() + return [device.type for device in group._device_types] + + +def _store_based_barrier( + rank, + store, + group_name, + rendezvous_count, + timeout, + logging_interval=timedelta(seconds=10), +) -> None: + """ + Store based barrier for synchronizing processes. + + Barrier based on store which is used for synchronizing processes after + ``init_process_group`` or ``new_group``. Intended to be used only with + those two methods and is not a generic alternative to ``barrier()``. + """ + store_key = f"{STORE_BASED_BARRIER_PREFIX}:{group_name}" + store.add(store_key, 1) + logger.debug("Added key: %s to store for rank: %s", store_key, rank) + + # Now wait for all workers to check in with the store. + world_size = rendezvous_count + worker_count = store.add(store_key, 0) + + last_worker_key = f"{store_key}:last_worker" + if worker_count == world_size: + store.set(last_worker_key, "1") + + # adjust the timeout to be at least 10secs + 1sec per thousand ranks to reduce the odds of timeout + # this value was empirically found while scale testing. + logging_interval = max(logging_interval, timedelta(seconds=10 + world_size / 1000)) + + start = time.time() + while True: + try: + # This will throw an exception after the logging_interval in which we print out + # the status of the group or time out officially, throwing runtime error + store.wait([last_worker_key], logging_interval) + break + except RuntimeError as e: + worker_count = store.add(store_key, 0) + # Print status periodically to keep track. + logger.debug( + "Waiting in store based barrier to initialize process group for %s seconds" + "rank: %s, key: %s (world_size=%s, num_workers_joined=%s, timeout=%s error=%s)", + time.time() - start, + rank, + store_key, + world_size, + worker_count, + timeout, + e, + ) + + if timedelta(seconds=(time.time() - start)) > timeout: + raise DistStoreError( # noqa: B904 + "Timed out initializing process group in store based barrier on " + f"rank {rank}, for key: {store_key} (world_size={world_size}, " + f"num_workers_joined={worker_count}, timeout={timeout} error={e})" + ) + + logger.info( + "Rank %s: Completed store-based barrier for key:%s with %s nodes.", + rank, + store_key, + world_size, + ) + + +def _rank_not_in_group(group: Optional[ProcessGroup]) -> bool: + """Check if the current process's rank is not in a given group.""" + if group is None: + return False + return group == GroupMember.NON_GROUP_MEMBER + + +def _warn_not_in_group(op_name) -> None: + global_rank = -1 if GroupMember.WORLD is None else GroupMember.WORLD.rank() + warnings.warn( + f"Running {op_name} on global rank {global_rank} which does not " + "belong to the given group." + ) + + +def get_group_rank(group: ProcessGroup, global_rank: int) -> int: + """ + Translate a global rank into a group rank. + + ``global_rank`` must be part of ``group`` otherwise this raises RuntimeError. + + Args: + group (ProcessGroup): ProcessGroup to find the relative rank. + global_rank (int): Global rank to query. + + Returns: + Group rank of ``global_rank`` relative to ``group`` + + N.B. calling this function on the default process group returns identity + """ + if group is GroupMember.WORLD: + return global_rank + if group not in _world.pg_group_ranks: + raise ValueError( + f"Group {group} is not registered, please create group with core.distributed.new_group API" + ) + group_ranks = _world.pg_group_ranks[group] + if global_rank not in group_ranks: + raise ValueError(f"Global rank {global_rank} is not part of group {group}") + + return group_ranks[global_rank] + + +def get_global_rank(group: ProcessGroup, group_rank: int) -> int: + """ + Translate a group rank into a global rank. + + ``group_rank`` must be part of `group` otherwise this raises RuntimeError. + + Args: + group (ProcessGroup): ProcessGroup to find the global rank from. + group_rank (int): Group rank to query. + + Returns: + Global rank of ``group_rank`` relative to ``group`` + + N.B. calling this function on the default process group returns identity + """ + if group is GroupMember.WORLD: + return group_rank + if group not in _world.pg_group_ranks: + raise ValueError( + f"Group {group} is not registered, please create group with core.distributed.new_group API" + ) + for rank, grp_rank in _world.pg_group_ranks[group].items(): + if grp_rank == group_rank: + return rank + raise ValueError(f"Group rank {group_rank} is not part of group {group}") + + +# TODO: remove this once the ecosystem moves away from it. +@deprecated( + "`core.distributed.distributed_c10d._get_global_rank` is deprecated, " + "please use `core.distributed.distributed_c10d.get_global_rank` instead", + category=FutureWarning, +) +def _get_global_rank(group, rank) -> int: + """Use get_global_rank as this method is deprecated.""" + return get_global_rank(group, rank) + + +def get_process_group_ranks(group: ProcessGroup) -> List[int]: + """ + Get all ranks associated with ``group``. + + Args: + group (ProcessGroup): ProcessGroup to get all ranks from. + + Returns: + List of global ranks ordered by group rank. + """ + return list(_world.pg_group_ranks[group].keys()) + + +def _get_group_size(group) -> int: + """Get a given group's world size.""" + if group is GroupMember.WORLD or group is None: + default_pg = _get_default_group() + return default_pg.size() + return group.size() + + +def _get_group_size_by_name(group_name: str) -> int: + group = _resolve_process_group(group_name) + return group.size() + + +def _resolve_group_name_by_ranks_and_tag(ranks: List[int], tag: str) -> str: + # TODO(yifu): remove this function once ranks + tag is not a supported + # identifier for process group for functional collectives. + group = _find_pg_by_ranks_and_tag(tag, ranks) + if group is None: + raise ValueError("") + return group.group_name + + +def _check_single_tensor(param, param_name) -> None: + """Check that the parameter ``param_name`` is a single tensor.""" + if not isinstance(param, core.Tensor): + raise TypeError( + f"""Invalid function argument. Expected parameter `{param_name}` of type core.Tensor + but got {type(param)} instead.""" + ) + + +def _check_tensor_list(param, param_name) -> None: + """Check that the parameter ``param_name`` is a list of tensors.""" + if not isinstance(param, list): + raise TypeError( + f"""Invalid function argument. Expected parameter `{param_name}` of type List[core.Tensor] + but got {type(param)} instead.""" + ) + elif not all(isinstance(p, core.Tensor) for p in param): + raise TypeError( + f"""Invalid function argument. Expected parameter `{param_name}` of type List[core.Tensor] + but got {type(param)} with elements of type {[type(p) for p in param]}.""" + ) + + +def _group_or_default_group(group: Optional[ProcessGroup] = None) -> ProcessGroup: + if group is None or group is GroupMember.WORLD: + group = _get_default_group() + return group + + +def _canonicalize_group_rank( + group: ProcessGroup, + global_rank: Optional[int] = None, + group_rank: Optional[int] = None, + return_global: bool = False, +) -> int: + """ + Helper method to take _either_ a global rank or a group rank and produce a group rank. + + If 'return_global' is true, produce a global rank instead of a group rank. + """ + + if group_rank is not None: + if global_rank is not None: + raise ValueError("Can't specify both group_rank and global_rank") + global_rank = get_global_rank(group, group_rank) + else: + if global_rank is None: + raise ValueError("Must specify global_rank or group_rank") + group_rank = get_group_rank(group, global_rank) + return global_rank if return_global else group_rank + + +def _check_not_self_rank(group: ProcessGroup, rank: int, rank_type: str): + if group.rank() == rank: + raise ValueError( + f"Invalid {rank_type} rank: {rank_type} rank should not be the same as " + "the rank of the current process." + ) + + +def _as_iterable(obj) -> collections.abc.Iterable: + return obj if isinstance(obj, list) else (obj,) + + +def _ensure_all_tensors_same_dtype(*tensors) -> None: + last_dtype = None + for tensor in itertools.chain.from_iterable(map(_as_iterable, tensors)): + tensor_dtype = tensor.dtype + # Mixing complex and its element type is allowed + # if tensor_dtype.is_complex: + # tensor_dtype = ( + # core.float32 if tensor_dtype == core.complex64 else core.complex128 + # ) + + if last_dtype is None: + last_dtype = tensor_dtype + else: + if last_dtype != tensor_dtype: + raise ValueError( + "Invalid usage of tensors with different dtypes" + f"Found {last_dtype} and {tensor.dtype}" + ) + + +def _check_op(op) -> None: + """Check that the ``op`` is either isend or irecv.""" + if op not in [isend, irecv]: + raise ValueError( + "Invalid ``op``. Expected ``op`` " + "to be of type ``core.distributed.isend`` or " + "``core.distributed.irecv``." + ) + + +def _check_p2p_op_list(p2p_op_list) -> None: + """ + Check that the ``p2p_op_list`` is a list of P2POp instances. + + Also, check that all ops use the same group. + """ + if not isinstance(p2p_op_list, list) or not all( + isinstance(p2p_op, P2POp) for p2p_op in p2p_op_list + ): + raise ValueError( + "Invalid ``p2p_op_list``. Each op is expected to " + "to be of type ``core.distributed.P2POp``." + ) + + group = p2p_op_list[0].group + if not all(group == p2p_op.group for p2p_op in p2p_op_list): + raise ValueError("All ops need to use the same group.") + + +def is_mpi_available() -> bool: + """Check if the MPI backend is available.""" + return _MPI_AVAILABLE + + +def is_nccl_available() -> bool: + """Check if the NCCL backend is available.""" + return _NCCL_AVAILABLE + + +def is_gloo_available() -> bool: + """Check if the Gloo backend is available.""" + return _GLOO_AVAILABLE + + +def is_ucc_available() -> bool: + """Check if the UCC backend is available.""" + return _UCC_AVAILABLE + + +def is_backend_available(backend: str) -> bool: + """ + Check backend availability. + + Checks if the given backend is available and supports the built-in backends or + third-party backends through function ``Backend.register_backend``. + + Args: + backend (str): Backend name. + Returns: + bool: Returns true if the backend is available otherwise false. + """ + # If the backend has an ``is_backend_available`` function, return the result of that function directly + available_func = getattr(core.distributed, f"is_{backend.lower()}_available", None) + if available_func: + return available_func() + + return backend.lower() in Backend.backend_list + + +def is_initialized() -> bool: + """Check if the default process group has been initialized.""" + return GroupMember.WORLD is not None + + +def is_torchelastic_launched() -> bool: + """ + Check whether this process was launched with ``core.distributed.elastic`` (aka torchelastic). + + The existence of ``TORCHELASTIC_RUN_ID`` environment + variable is used as a proxy to determine whether the current process + was launched with torchelastic. This is a reasonable proxy since + ``TORCHELASTIC_RUN_ID`` maps to the rendezvous id which is always a + non-null value indicating the job id for peer discovery purposes.. + """ + return os.getenv("TORCHELASTIC_RUN_ID") is not None + + +def _is_barrier_after_init() -> int: + # Environment variable to control whether process group should perform a + # barrier after its init. Default value is 0, i.e. no barrier. If you + # experience issue with this setting, you may set + # `TORCH_DIST_INIT_BARRIER=1` to add the barrier. + return int(os.getenv("TORCH_DIST_INIT_BARRIER", "0")) + + +def _get_default_group() -> ProcessGroup: + """Get the default process group created by init_process_group.""" + if not is_initialized(): + raise ValueError( + "Default process group has not been initialized, " + "please make sure to call init_process_group." + ) + if TYPE_CHECKING: + return not_none(GroupMember.WORLD) + else: + return GroupMember.WORLD + + +def _get_default_store() -> Store: + """Get the default store created by init_process_group.""" + if not is_initialized(): + raise ValueError( + "Default process group has not been initialized, " + "please make sure to call init_process_group." + ) + default_pg = _get_default_group() + _, default_store = _world.pg_map[default_pg] + return default_store + + +def _update_default_pg(pg) -> None: + _world.default_pg = pg + # rank = pg.rank() if pg is not None and pg != GroupMember.NON_GROUP_MEMBER else -1 + # core._C._distributed_c10d._set_global_rank(rank) + + +def get_backend_config(group: Optional[ProcessGroup] = None) -> str: + """ + Return the backend configuration of the given process group. + + Args: + group (ProcessGroup, optional): The process group to work on. The + default is the general main process group. If another specific group + is specified, the calling process must be part of :attr:`group`. + + Returns: + The backend configuration of the given process group as a lower case string. + + """ + pg = group or _get_default_group() + if _rank_not_in_group(pg): + raise ValueError("Invalid process group specified") + backend_config = _world.pg_backend_config.get(pg) + return str(not_none(backend_config)) + + +def get_backend(group: Optional[ProcessGroup] = None) -> Backend: + """ + Return the backend of the given process group. + + Args: + group (ProcessGroup, optional): The process group to work on. The + default is the general main process group. If another specific group + is specified, the calling process must be part of :attr:`group`. + + Returns: + The backend of the given process group as a lower case string. + + """ + pg = group or _get_default_group() + if _rank_not_in_group(pg): + raise ValueError("Invalid process group specified") + pg_store = _world.pg_map[pg] if pg in _world.pg_map else None + return Backend(not_none(pg_store)[0]) + + +def _get_process_group_uid(pg: ProcessGroup) -> int: + backend = None + try: + backend = pg._get_backend(core.device("cuda")) + except RuntimeError: + pass + if is_nccl_available() and isinstance(backend, ProcessGroupNCCL): + return backend.uid + return -1 + + +def _get_pg_config(group: Optional[ProcessGroup] = None) -> Dict[str, Any]: + """ + Return the pg configuration of the given process group. + + """ + pg = group or _get_default_group() + return { + "pg_name": _get_process_group_name(pg), + "pg_desc": pg.group_desc, + "backend_config": get_backend_config(pg), + "pg_size": _get_group_size(pg), + "ranks": get_process_group_ranks(pg), + } + + +def _get_all_pg_configs() -> List[Dict[str, Any]]: + """ + Return the pg configuration of all the process groups. + + """ + config_info: List[Dict[str, Any]] = [ + _get_pg_config(pg) for pg in _world.pg_map.keys() + ] + return config_info + + +def get_pg_count() -> int: + """ + Return the number of process groups. + + """ + return _world.group_count + + +def get_node_local_rank(fallback_rank: Optional[int] = None) -> int: + """ + Return the local rank of the current process relative to the node. + + Semantically, this is a useful concept for mapping processes to devices. + For example, on a node with 8 accelerator you could use the node local rank to decide + which accelerator device to bind the process to. + + In practice, the actual assignment of node local ranks is handled by the process launcher outside of pytorch, + and communicated via the `LOCAL_RANK` environment variable. + + Torchrun will automatically populate `LOCAL_RANK`, but other launchers may not. If `LOCAL_RANK` is unspecified, + this API will fall back to the provided kwarg 'fallback_rank' if specified, otherwise it will raise an error. The + intent is to allow writing an application that runs either in single or multi device contexts without error. + + """ + if "LOCAL_RANK" in os.environ: + return int(os.environ["LOCAL_RANK"]) + elif fallback_rank is not None: + return int(fallback_rank) + raise RuntimeError( + "LOCAL_RANK is not in the environment. Consider passing fallback_rank to allow `get_node_local_rank` to work, " + "assuming you are not running in a multi-device context and want the code to run locally instead." + ) + + +def _add_ephemeral_timeout_for_all_pgs(timeout: timedelta) -> None: + """ + This API adds an ephemeral timeout extension for all PGs locally + on one rank. The timeout gets reset when the first collective issued + after API called finished. + NOTE: We only support to set timeout for cuda backends for now. + NOTE: While this feature + provides flexibility in specific scenarios, it introduces statefulness + to timeout setting. Therefore, it is advisable to use this API sparingly + and consider alternative approaches, such as directly setting the timeout + or utilizing a barrier collective (one can set any timeout to the barrier), + whenever feasible. + + Args: + timeout (timedelta): The delta of timeout to extend. + + Returns: + None. + """ + for pg in _world.pg_map.keys(): + devices = pg._device_types + if core.device("cuda") in devices: + backend = pg._get_backend(core.device("cuda")) + if is_nccl_available() and isinstance(backend, ProcessGroupNCCL): + backend._add_ephemeral_timeout(timeout) + + +def _set_pg_timeout(timeout: timedelta, group: Optional[ProcessGroup] = None) -> None: + """ + Set the timeout for the given process group when users want to use a different timeout instead of + default values. + + Args: + timeout (timedelta): Timeout for operations executed against the process group which + users want to set. Default value is 10 minutes for NCCL and 30 minutes for other backends. + This is the duration after which collectives will be aborted asynchronously and the process will crash. + This is done since CUDA execution is async and it is no longer safe to continue executing user code since + failed async NCCL operations might result in subsequent CUDA operations running on corrupted data. + When TORCH_NCCL_BLOCKING_WAIT is set, the process will block and wait for this timeout. + + group (ProcessGroup, optional): The process group to work on. The + default is the general main process group. If another specific group + is specified, the calling process must be part of :attr:`group`. + + Returns: + None + """ + if group is None: + group = _get_default_group() + if _rank_not_in_group(group): + raise ValueError("Invalid process group specified") + assert isinstance(group, ProcessGroup) + devices = group._device_types + backends = set() + if core.device("cpu") in devices and is_gloo_available(): + backend = group._get_backend(core.device("cpu")) + if isinstance(backend, ProcessGroupGloo): + backends.add(backend) + if core.device("cuda") in devices: + backend = group._get_backend(core.device("cuda")) + if is_nccl_available() and isinstance(backend, ProcessGroupNCCL): + backends.add(backend) # type: ignore[arg-type] + elif is_gloo_available() and isinstance(backend, ProcessGroupGloo): + backends.add(backend) # type: ignore[arg-type] + if len(backends) == 0: + warnings.warn("Set timeout is now only supported for either nccl or gloo.") + for backend in backends: + backend._set_default_timeout(timeout) + + +@_exception_logger +def init_process_group( + backend: Optional[str] = None, + init_method: Optional[str] = None, + timeout: Optional[timedelta] = None, + world_size: int = -1, + rank: int = -1, + store: Optional[Store] = None, + group_name: str = "", + pg_options: Optional[Any] = None, + device_id: Optional[core.device] = None, +) -> None: + """ + Initialize the default distributed process group. + + This will also initialize the distributed package. + + There are 2 main ways to initialize a process group: + 1. Specify ``store``, ``rank``, and ``world_size`` explicitly. + 2. Specify ``init_method`` (a URL string) which indicates where/how + to discover peers. Optionally specify ``rank`` and ``world_size``, + or encode all required parameters in the URL and omit them. + + If neither is specified, ``init_method`` is assumed to be "env://". + + + Args: + backend (str or Backend, optional): The backend to use. Depending on + build-time configurations, valid values include ``mpi``, ``gloo``, + ``nccl``, and ``ucc``. If the backend is not provided, then both a ``gloo`` + and ``nccl`` backend will be created, see notes below for how multiple + backends are managed. This field can be given as a lowercase string + (e.g., ``"gloo"``), which can also be accessed via + :class:`Backend` attributes (e.g., ``Backend.GLOO``). If using + multiple processes per machine with ``nccl`` backend, each process + must have exclusive access to every GPU it uses, as sharing GPUs + between processes can result in deadlocks. ``ucc`` backend is + experimental. + init_method (str, optional): URL specifying how to initialize the + process group. Default is "env://" if no + ``init_method`` or ``store`` is specified. + Mutually exclusive with ``store``. + world_size (int, optional): Number of processes participating in + the job. Required if ``store`` is specified. + rank (int, optional): Rank of the current process (it should be a + number between 0 and ``world_size``-1). + Required if ``store`` is specified. + store(Store, optional): Key/value store accessible to all workers, used + to exchange connection/address information. + Mutually exclusive with ``init_method``. + timeout (timedelta, optional): Timeout for operations executed against + the process group. Default value is 10 minutes for NCCL and 30 minutes for other backends. + This is the duration after which collectives will be aborted asynchronously and the process will crash. + This is done since CUDA execution is async and it is no longer safe to continue executing user code since + failed async NCCL operations might result in subsequent CUDA operations running on corrupted data. + When TORCH_NCCL_BLOCKING_WAIT is set, the process will block and wait for this timeout. + + group_name (str, optional, deprecated): Group name. This argument is ignored + pg_options (ProcessGroupOptions, optional): process group options + specifying what additional options need to be passed in during + the construction of specific process groups. As of now, the only + options we support is ``ProcessGroupNCCL.Options`` for the ``nccl`` + backend, ``is_high_priority_stream`` can be specified so that + the nccl backend can pick up high priority cuda streams when + there're compute kernels waiting. For other availble options to config nccl, + See https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t + device_id (core.device, optional): a single, specific device + to "bind" this process to, allowing for backend-specific + optimizations. Currently this has two effects, only under + NCCL: the communicator is immediately formed (calling + ``ncclCommInit*`` immediately rather than the normal lazy + call) and sub-groups will use ``ncclCommSplit`` when + possible to avoid unnecessary overhead of group creation. If you + want to know NCCL initialization error early, you can also use this + field. + + .. note:: To enable ``backend == Backend.MPI``, PyTorch needs to be built from source + on a system that supports MPI. + + .. note:: Support for multiple backends is experimental. Currently when no backend is + specified, both ``gloo`` and ``nccl`` backends will be created. The ``gloo`` backend + will be used for collectives with CPU tensors and the ``nccl`` backend will be used + for collectives with CUDA tensors. A custom backend can be specified by passing in + a string with format ":,:", e.g. + "cpu:gloo,cuda:custom_backend". + + """ + global _world + + global _backend + global _default_pg_init_method + + if GroupMember.WORLD is not None: + raise ValueError("trying to initialize the default process group twice!") + + # do mindspore communication init + init(backend_name=backend) + + # Convert string into `Backend` type + backend = Backend(backend) + + if timeout is None: + timeout = _get_default_timeout(backend) + + _check_valid_timeout(timeout) + + group_name = GlobalComm.WORLD_COMM_GROUP + if backend == Backend.MPI: + if world_size != -1 or rank != -1: + warnings.warn( + f"For MPI backend, world_size ({world_size}) and rank ({rank}) " + "are ignored since they are assigned by the " + "MPI runtime." + ) + + default_pg, _ = _new_process_group_helper( + -1, + -1, + [], + backend, + Store(), # Placeholder value since store cannot be None + group_name, + timeout=timeout, + group_desc="default_pg", + ) + _update_default_pg(default_pg) + else: + # backward compatible API + if store is None: + # store, rank, world_size = next(rendezvous_iterator) + rank = _get_rank(group_name) + world_size = get_group_size(group_name) + # store.set_timeout(timeout) + + # Use a PrefixStore to avoid accidental overrides of keys used by + # different systems (e.g. RPC) in case the store is multi-tenant. + store = PrefixStore("default_pg", store) + + default_pg, _ = _new_process_group_helper( + world_size, + rank, + [], + backend, + store, + group_name, + backend_options=pg_options, + timeout=timeout, + device_id=device_id, + group_desc="default_pg", + ) + _update_default_pg(default_pg) + + _world.pg_group_ranks[GroupMember.WORLD] = {i: i for i in range(GroupMember.WORLD.size())} # type: ignore[attr-defined, index] + _backend = _world.pg_map[not_none(GroupMember.WORLD)][0] + _default_pg_init_method = init_method + + old_hook = sys.excepthook + excepthook_prefix = f"[rank{get_rank()}]" + + def _distributed_excepthook(*args): + old_stderr = sys.stderr + sys.stderr = buf = io.StringIO() + try: + old_hook(*args) + finally: + sys.stderr = old_stderr + msg = buf.getvalue() + msg = "\n".join( + f"{excepthook_prefix}: {s}" if s != "" else "" for s in msg.split("\n") + ) + sys.stderr.write(msg) + sys.stderr.flush() + + sys.excepthook = _distributed_excepthook + + if _is_barrier_after_init() == 1: + # barrier at the end to ensure that once we return from this method, all + # process groups including global variables (if any) are updated + # correctly on all ranks. + # Update 04/2023: for large-scale runs, this barrier (esp. store-based + # barrier) may be costly and/or unscalable. Also, in a lot of cases, + # these barriers may be unnecessary, as proven by a green CI after + # removal. An environment variable `TORCH_DIST_INIT_BARRIER` has been + # added which enables this barrier only when set to 1. + logger.debug( + "Performing barrier after ProcessGroup initialization since " + "TORCH_DIST_INIT_BARRIER = 1" + ) + if backend == Backend.MPI: + # MPI backend doesn't use store. + barrier() + else: + # Use store based barrier here since barrier() used a bunch of + # default devices and messes up NCCL internal state. + _store_based_barrier(rank, store, group_name, world_size, timeout) + + +def _get_split_source(pg): + split_from = None + if pg.bound_device_id: + split_from = pg._get_backend(pg.bound_device_id) + elif pg is _world.default_pg: + try: + split_from = pg._get_backend(core.device("cuda")) + except RuntimeError: + # no cuda device associated with this backend + pass + + if not split_from or not split_from.supports_splitting: + return None + + # If necessary, find a backend to split from by peeling process + # group wrappers from our potentially wrapped process group. + while _GLOO_AVAILABLE and isinstance(split_from, _ProcessGroupWrapper): + split_from = split_from.wrapped_pg + + return split_from + + +def _shutdown_backend(pg): + """ + Try to shut down the backend of a process group. + Currently, only ProcessGroupNCCL backend is supported. + No op for other backends. + """ + backend = None + try: + backend = pg._get_backend(core.device("cuda")) + except RuntimeError: + pass + if is_nccl_available() and isinstance(backend, ProcessGroupNCCL): + # explictly call shutdown to ensure that NCCL resources are released + backend._shutdown() + + +def _abort_backend(pg: ProcessGroup): + """ + Abort the backend of a process group. + Currently, only ProcessGroupNCCL backend is supported. + No op for other backends. + """ + try: + backend = pg._get_backend(core.device("cuda")) + except RuntimeError: + backend = None + if isinstance(backend, ProcessGroupNCCL): + backend.abort() + + +def _new_process_group_helper( + group_size, + group_rank, + global_ranks_in_group, + backend, + store, + group_name, + backend_options=None, + timeout=None, + pg_tag=None, + device_id=None, + group_desc=None, +): + """ + Create a new distributed process group. + + This function must be called by ALL processes in the global group, even if + the calling process is not part of the newly created group. In that case, + this function returns GroupMember.NON_GROUP_MEMBER. + + This function is called with ``global_ranks_in_group == []`` for the default group. + """ + global _world + + if group_name in _world.pg_names.values(): + raise ValueError( + "The specified group name has already been " + "created, please use a different group name" + ) + + if device_id is not None and (device_id.index is None or device_id.type != "cuda"): + raise ValueError( + "init_process_group device_id parameter must be a cuda device with an " + "id, e.g. cuda:0, not just cuda or cpu" + ) + + # Note: _new_process_group_helper is only called from init_process_group, which always provides a timeout value + _check_valid_timeout(timeout) + + if pg_tag not in [None, ""]: + # creating with the same tag and rank set results in the same underlying PG + existing_group = _find_pg_by_ranks_and_tag(pg_tag, global_ranks_in_group) + if existing_group: + _, prefix_store = _world.pg_map[existing_group] + return existing_group, prefix_store + + group_desc = "undefined" if group_desc is None else group_desc + + # The list of group ranks is empty if we're creating the default group. + is_default_group = len(global_ranks_in_group) == 0 + + # nccl and potentially other backends allow creation of + # communicators based on pre-existing ones, which can save + # initialization time. Due to lazy initialization of + # communicators in some backends, we have to be careful and only + # split when we *know* the default PG has already started communicator initialization. + # We know this if we have bound a device id to the default pg (eager initialized). + # if is_initialized() and _get_default_group().bound_device_id: + # split_from = _get_split_source(_get_default_group()) + # else: + split_from = None + + # If this is a subgroup (which means group_ranks is specified), + # we check if the current process is a member of the new group. + if not is_default_group: + global_rank = _get_default_group().rank() + if global_rank not in global_ranks_in_group: + # If we are using `ncclCommSplit` (or similar split from + # other APIs) to create the communicator, we will need to + # call `ncclCommSplit` on *all* ranks in this new group's + # parent group, even those not in the new group. This is + # a requirement of the NCCL API as otherwise we would get + # out of sync. + if split_from: + split_from.perform_nocolor_split(_get_default_group().bound_device_id) + return GroupMember.NON_GROUP_MEMBER, None + + prefix_store = PrefixStore(f"{group_name}/", store) + # The backend for PG will be set later based on what's inside BackendConfig + # and timeout are set in each backend's option. + pg: ProcessGroup = ProcessGroup( + prefix_store, + group_rank, + group_size, + ) + + device = 'npu' if backend == 'hccl' else 'cpu' + pg._register_backend(core.device(device), backend, backend) + + # update global state + _world.pg_map[pg] = (backend, prefix_store) + _world.pg_names[pg] = group_name + if not is_default_group: + create_group(group_name, global_ranks_in_group) + + # _world.pg_backend_config[pg] = str(backend_config) + # "" is the default tag for user PGs + if pg_tag in [None, ""]: + pg_tag = f"ptd:{group_name}" + _world.tags_to_pg.setdefault("", []).append(pg) + else: + pg_tag = f"user:{pg_tag}" + + _world.tags_to_pg.setdefault(pg_tag, []).append(pg) + _world.pg_to_tag[pg] = pg_tag + return pg, prefix_store + + +def destroy_process_group(group: Optional[ProcessGroup] = None): + """ + Destroy a given process group, and deinitialize the distributed package. + + Args: + group (ProcessGroup, optional): The process group to be destroyed, if + group.WORLD is given, all process + groups including the default one will + be destroyed. + """ + global _world + + if group == GroupMember.NON_GROUP_MEMBER: + return + + if group is None: + pg = GroupMember.WORLD + else: + pg = group + + assert pg is not None + if _world.pg_map.get(pg, None) is None: + raise ValueError("Invalid process group specified") + + # When users register Python onCompletion hooks, those hooks will run on a + # different thread than the main thread. Today, the ProcessGroup dtor does + # wait for that thread. However, the dtor might finish after the Python + # Interpreter exits. After that grabbing the GIL for the Python hook will crash. + # We can either revive the interpreter when running hooks or keep the main one + # alive until all works and hooks are done. The current implementation does the + # latter. Therefore, we explicitly call _wait_for_pending_works() here to wait + # for the pending hooks to finish. + if pg.name().lower() == "nccl" and pg._has_hooks(): + pg._wait_for_pending_works() + + if group is None or group == GroupMember.WORLD: + # shutdown all backends in the order of pg names. shutting down in order because + # ncclCommAbort() was a 'collective' call in some versions of NCCL. + for pg_to_shutdown in sorted( + _world.pg_names, key=lambda x: _world.pg_names[x], reverse=True + ): + _shutdown_backend(pg_to_shutdown) + + _update_default_pg(None) + _world.pg_map.clear() + _world.pg_names.clear() + _world.pg_group_ranks.clear() + _world.pg_backend_config.clear() + _world.pg_to_tag.clear() + _world.tags_to_pg.clear() + _world.pg_coalesce_state.clear() + _unregister_all_process_groups() + + # when process group doesn't have an explicit name (only WORLD (default) + # process group can have an explicit name), we use global _world.group_count + # to generate the name. We need to reset the counter on destruction to + # allow consistent value to be generated when we re-create process + # groups after some trainers recover from failure + # + # We only reset this when WORLD is being destroyed because if this + # process group is in good state, we aren't dealing with failures. + _world.group_count = 0 + else: + _shutdown_backend(pg) + del _world.pg_map[pg] + del _world.pg_names[pg] + del _world.pg_group_ranks[pg] + del _world.pg_backend_config[pg] + if pg in _world.pg_coalesce_state.keys(): + warnings.warn( + "Some coalesced collectives haven't been launched when " + "ProcessGroup is destroyed. They will be cleaned." + ) + del _world.pg_coalesce_state[pg] + + tag = _world.pg_to_tag.get(pg) + del _world.pg_to_tag[pg] + if tag is not None: + try: + _world.tags_to_pg[tag].remove(pg) + if tag.startswith("ptd:"): + _world.tags_to_pg[""].remove(pg) + except Exception: + pass + _unregister_process_group(pg.group_name) + + +def _abort_process_group(group: Optional[ProcessGroup] = None): + """ + Abort a given process group. If group.WORLD (i.e. `None`) is given, all + process groups including the default one will be aborted. + + Args: + group (ProcessGroup, optional): The process group to be aborted. + + .. note:: this API is experimental and currently only works with the NCCL + backend. + + .. note:: this API should be used with `TORCH_NCCL_ASYNC_ERROR_HANDLING` + turned off (i.e. set to 0). Otherwise, ProcessGroupNCCL's watchdog may + automatically handle errors or timeouts for you including aborting the + ProcessGroup. + """ + global _world + + if group == GroupMember.NON_GROUP_MEMBER: + return + + pg = group or GroupMember.WORLD + + assert pg is not None + if _world.pg_map.get(pg, None) is None: + raise ValueError("Invalid process group specified or has been destroyed.") + + try: + backend = pg._get_backend(core.device("cuda")) + except RuntimeError: + backend = None + + if not isinstance(backend, ProcessGroupNCCL): + logger.warning( + "`abort_process_group` currently only has implementation for ProcessGroupNCCL; " + "however, no NCCL backend is found. This call will be a no-op." + ) + return + + if group == GroupMember.WORLD: + # Abort all backends within a ncclGroupStart|End semantic. + # This ensures that different NCCL communicators' abort calls won't + # deadlock each other. + # For details, please see: https://github.com/pytorch/pytorch/issues/119797 + backend._group_start() + for pg_to_abort in sorted( + _world.pg_names, key=lambda x: _world.pg_names[x], reverse=True + ): + _abort_backend(pg_to_abort) + backend._group_end() + + _update_default_pg(None) + _world.pg_map.clear() + _world.pg_names.clear() + _world.pg_group_ranks.clear() + _world.pg_backend_config.clear() + _world.pg_to_tag.clear() + _world.tags_to_pg.clear() + _world.pg_coalesce_state.clear() + _unregister_all_process_groups() + + # when process group doesn't have an explicit name (only WORLD (default) + # process group can have an explicit name), we use global _world.group_count + # to generate the name. We need to reset the counter on destruction to + # allow consistent value to be generated when we re-create process + # groups after some trainers recover from failure + # + # We only reset this when WORLD is being destroyed because if this + # process group is in good state, we aren't dealing with failures. + _world.group_count = 0 + else: + _abort_backend(pg) + del _world.pg_map[pg] + del _world.pg_names[pg] + del _world.pg_group_ranks[pg] + del _world.pg_backend_config[pg] + if pg in _world.pg_coalesce_state.keys(): + warnings.warn( + "Some coalesced collectives haven't been launched when " + "ProcessGroup is aborted. They will be cleaned." + ) + del _world.pg_coalesce_state[pg] + + tag = _world.pg_to_tag.get(pg) + del _world.pg_to_tag[pg] + if tag is not None: + try: + _world.tags_to_pg[tag].remove(pg) + if tag.startswith("ptd:"): + _world.tags_to_pg[""].remove(pg) + except Exception: + pass + _unregister_process_group(pg.group_name) + + +def get_rank(group: Optional[ProcessGroup] = None) -> int: + """ + Return the rank of the current process in the provided ``group``, default otherwise. + + Rank is a unique identifier assigned to each process within a distributed + process group. They are always consecutive integers ranging from 0 to + ``world_size``. + + Args: + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + + Returns: + The rank of the process group + -1, if not part of the group + + """ + if _rank_not_in_group(group): + return -1 + + default_pg = _get_default_group() + if group is None or group is GroupMember.WORLD: + return default_pg.rank() + + return get_group_rank(group, default_pg.rank()) + + +def get_world_size(group: Optional[ProcessGroup] = None) -> int: + """ + Return the number of processes in the current process group. + + Args: + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + + Returns: + The world size of the process group + -1, if not part of the group + + """ + if _rank_not_in_group(group): + return -1 + + return _get_group_size(group) + + +def isend( + tensor: core.Tensor, + dst: Optional[int] = None, + group: Optional[ProcessGroup] = None, + tag: int = 0, + group_dst: Optional[int] = None, +) -> Optional[Work]: + """ + Send a tensor asynchronously. + + .. warning:: + Modifying ``tensor`` before the request completes causes undefined + behavior. + + .. warning:: + ``tag`` is not supported with the NCCL backend. + + Unlike send, which is blocking, isend allows src == dst rank, i.e. send to self. + + Args: + tensor (Tensor): Tensor to send. + dst (int): Destination rank on global process group (regardless of ``group`` argument) + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + tag (int, optional): Tag to match send with remote recv + group_dst (int, optional): Destination rank on ``group``. Invalid to specify both ``dst`` and ``group_dst`` + + Returns: + A distributed request object. + None, if not part of the group + + """ + group = _group_or_default_group(group) + group_dst = _canonicalize_group_rank(group, dst, group_dst) + _check_single_tensor(tensor, "tensor") + if _rank_not_in_group(group): + _warn_not_in_group("isend") + return None + + if tensor.is_complex(): + tensor = core.view_as_real(tensor) + + return group.send([tensor], group_dst, tag) + + +def irecv( + tensor: core.Tensor, + src: Optional[int] = None, + group: Optional[ProcessGroup] = None, + tag: int = 0, + group_src: Optional[int] = None, +) -> Optional[Work]: + """ + Receives a tensor asynchronously. + + .. warning:: + ``tag`` is not supported with the NCCL backend. + + Unlike recv, which is blocking, irecv allows src == dst rank, i.e. recv from self. + + Args: + tensor (Tensor): Tensor to fill with received data. + src (int, optional): Source rank on global process group (regardless of ``group`` argument). + Will receive from any process if unspecified. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + tag (int, optional): Tag to match recv with remote send + group_src (int, optional): Destination rank on ``group``. Invalid to specify both ``src`` and ``group_src``. + + Returns: + A distributed request object. + None, if not part of the group + + """ + _check_single_tensor(tensor, "tensor") + if _rank_not_in_group(group): + _warn_not_in_group("irecv") + return None + + if tensor.is_complex(): + tensor = core.view_as_real(tensor) + + group = _group_or_default_group(group) + if src is None and group_src is None: + return group.recv_anysource([tensor], tag) + else: + group_src = _canonicalize_group_rank(group, src, group_src) + return group.recv([tensor], group_src, tag) + + +@_exception_logger +def send( + tensor: core.Tensor, + dst: Optional[int] = None, + group: Optional[ProcessGroup] = None, + tag: int = 0, + group_dst: Optional[int] = None, +) -> None: + """ + Send a tensor synchronously. + + .. warning:: + ``tag`` is not supported with the NCCL backend. + + Args: + tensor (Tensor): Tensor to send. + dst (int): Destination rank on global process group (regardless of ``group`` argument). + Destination rank should not be the same as the rank of the current process. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + tag (int, optional): Tag to match send with remote recv + group_dst (int, optional): Destination rank on ``group``. Invalid to specify both ``dst`` and ``group_dst``. + + """ + group = _group_or_default_group(group) + group_dst = _canonicalize_group_rank(group, dst, group_dst) + _check_not_self_rank(group, group_dst, "destination") + work = isend(tensor, group=group, tag=tag, group_dst=group_dst) + if work is not None: + work.wait() + + +@_exception_logger +def recv( + tensor: core.Tensor, + src: Optional[int] = None, + group: Optional[ProcessGroup] = None, + tag: int = 0, + group_src: Optional[int] = None, +) -> int: + """ + Receives a tensor synchronously. + + .. warning:: + ``tag`` is not supported with the NCCL backend. + + Args: + tensor (Tensor): Tensor to fill with received data. + src (int, optional): Source rank on global process group (regardless of ``group`` argument). + Will receive from any process if unspecified. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + tag (int, optional): Tag to match recv with remote send + group_src (int, optional): Destination rank on ``group``. Invalid to specify both ``src`` and ``group_src``. + Returns: + Sender rank + -1, if not part of the group + + """ + work = irecv(tensor, src=src, group=group, tag=tag, group_src=group_src) + if work is None: + return -1 + work.wait() + if src is None: + if group_src is None: + group_src = work._source_rank() + group = _group_or_default_group(group) + _check_not_self_rank(group, group_src, "source") + src = get_global_rank(group, group_src) + return src + + +class _IllegalWork(Work): + def __getattribute__(self, name): + if name in [ + "is_success", + "exception", + "wait", + "source_rank", + "_source_rank", + "result", + "synchronize", + ]: + raise ValueError(f"Illegal to call {name} on IllegalWork object") + + +class _CoalescingManager: + def __init__(self) -> None: + self.works: List[Work] = [] + + def append(self, work: Work): + if work: + self.works.append(work) + + def wait(self): + for work in self.works: + work.wait() + + +@contextlib.contextmanager +def _coalescing_manager( + group: Optional[ProcessGroup] = None, + device: Optional[core.device] = None, + async_ops: Optional[bool] = False, +): + """ + Context manager used to coalesce collectives or P2P operations when possible. + + Args: + group (`ProcessGroup`, optional): The process group to work on. If None, + the default process group will be used. + device (`core.device`, optional): Default is None, set to a device if + there isn't a `**_coalesced` implementation by the backend. + async_ops (`bool`, optional): whether the coalesced ops are async ops. + + Examples: + >>> # xdoctest: +SKIP("no rank") + >>> # Synchronous ops + >>> with _coalescing_manager(): + >>> for i in range(num_colls): + >>> dist.all_reduce(tensors[i]) + >>> # Asynchronous ops + >>> with _coalescing_manager(async_ops=True) as cm: + >>> for i in range(num_colls): + >>> dist.all_reduce(tensors[i]) + >>> cm.wait() + + .. warning:: + :func:`_coalescing_manager` currently do not support coalescing + all-reduces with different reduce operators, e.g. `ReduceOp.SUM` mixed + with `ReduceOp.PRODUCT`. + """ + group = group or _get_default_group() + op_list = _world.pg_coalesce_state.setdefault(group, []) + if op_list: + raise ValueError( + "ProcessGroup has non-empty op list at the start of coalescing" + ) + if device: + group._start_coalescing(device) + cm = _CoalescingManager() + yield cm + op_list = _world.pg_coalesce_state.pop(group) + if op_list: + # Collectives supporting "Fast Path" coalescing are captured. + # See implementation in corresponding collective APIs. + # Currently supported: + # - coalesced `all_reduce` + # - coalesced `all_gather_into_tensor` + # - coalesced `reduce_scatter_tensor` + op0 = op_list[0].op + if op0 == all_reduce: + tensors = [op.tensor for op in op_list] + all_reduce_opts = AllreduceCoalescedOptions() + all_reduce_opts.reduceOp = not_none(op_list[0].redop) + work = group.allreduce_coalesced(tensors, all_reduce_opts) + elif op0 == all_gather_into_tensor: + inputs = [] + outputs = [] + for op in op_list: + inputs.append(op.tensor) + outputs.append(not_none(op.dst_tensor)) + work = group.allgather_into_tensor_coalesced(outputs, inputs) + elif op0 == reduce_scatter_tensor: + inputs = [] + outputs = [] + for op in op_list: + inputs.append(op.tensor) + outputs.append(not_none(op.dst_tensor)) + reduce_opts = ReduceScatterOptions() + reduce_opts.reduceOp = not_none(op_list[0].redop) + work = group.reduce_scatter_tensor_coalesced(outputs, inputs, reduce_opts) + else: + raise AssertionError( + f"Coalescing manager does not support fast-path coalescing of {op0}, " + f"yet {op0} is still recorded in op list. This is an internal error of c10d." + ) + + if device: + # Old style of letting each coll inside the context manager to call into C++ counterpart via python binding + work = group._end_coalescing(device) + + if async_ops: + cm.append(work) # type: ignore[possibly-undefined] + else: + work.wait() # type: ignore[possibly-undefined] + + +def batch_isend_irecv(p2p_op_list: List[P2POp]) -> List[Work]: + """ + Send or Receive a batch of tensors asynchronously and return a list of requests. + + Process each of the operations in ``p2p_op_list`` and return the corresponding + requests. NCCL, Gloo, and UCC backend are currently supported. + + Args: + p2p_op_list: A list of point-to-point operations(type of each operator is + ``core.distributed.P2POp``). The order of the isend/irecv in the list + matters and it needs to match with corresponding isend/irecv on the + remote end. + + Returns: + A list of distributed request objects returned by calling the corresponding + op in the op_list. + + Examples: + >>> # xdoctest: +SKIP("no rank") + >>> send_tensor = core.arange(2, dtype=core.float32) + 2 * rank + >>> recv_tensor = core.randn(2, dtype=core.float32) + >>> send_op = dist.P2POp(dist.isend, send_tensor, (rank + 1)%world_size) + >>> recv_op = dist.P2POp(dist.irecv, recv_tensor, (rank - 1 + world_size)%world_size) + >>> reqs = batch_isend_irecv([send_op, recv_op]) + >>> for req in reqs: + >>> req.wait() + >>> recv_tensor + tensor([2, 3]) # Rank 0 + tensor([0, 1]) # Rank 1 + + .. note:: Note that when this API is used with the NCCL PG backend, users must set + the current GPU device with `core.cuda.set_device`, otherwise it will + lead to unexpected hang issues. + + In addition, if this API is the first collective call in the ``group`` + passed to ``dist.P2POp``, all ranks of the ``group`` must participate in + this API call; otherwise, the behavior is undefined. If this API call is + not the first collective call in the ``group``, batched P2P operations + involving only a subset of ranks of the ``group`` are allowed. + """ + _check_p2p_op_list(p2p_op_list) + group = p2p_op_list[0].group + # device = p2p_op_list[0].tensor.device + + def peer_kwarg(op: P2POp) -> Dict[str, int]: + key = "group_dst" if op.op == isend else "group_src" + return {key: op.group_peer} + + # if device.type == "cuda": + # NCCL style coalescing + with _coalescing_manager(group, None, async_ops=True) as cm: + for p2p_op in p2p_op_list: + p2p_op.op( + p2p_op.tensor, + group=p2p_op.group, + tag=p2p_op.tag, + **peer_kwarg(p2p_op), + ) + + return cm.works + # else: + # # Backward support for Gloo + # reqs = [] + # for p2p_op in p2p_op_list: + # work = p2p_op.op( + # p2p_op.tensor, + # group=p2p_op.group, + # tag=p2p_op.tag, + # **peer_kwarg(p2p_op), + # ) + # if work: + # reqs.append(work) + # return reqs + +@_exception_logger +def broadcast( + tensor: core.Tensor, + src: Optional[int] = None, + group: Optional[ProcessGroup] = None, + async_op: bool = False, + group_src: Optional[int] = None, +): + """ + Broadcasts the tensor to the whole group. + + ``tensor`` must have the same number of elements in all processes + participating in the collective. + + Args: + tensor (Tensor): Data to be sent if ``src`` is the rank of current + process, and tensor to be used to save received data otherwise. + src (int): Source rank on global process group (regardless of ``group`` argument). + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op + group_src (int): Source rank on ``group``. Must specify one of ``group_src`` + and ``src`` but not both. + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + """ + group = _group_or_default_group(group) + group_src = _canonicalize_group_rank(group, src, group_src, return_global=False) + # _check_single_tensor(tensor, "tensor") + if _rank_not_in_group(group): + _warn_not_in_group("broadcast") + return + + opts = BroadcastOptions() + opts.rootRank = group_src + opts.rootTensor = 0 + opts.asyncOp = async_op + work = group.broadcast([tensor], opts) + if async_op: + return work + else: + work.wait() + +@_exception_logger +def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False): + """ + Reduces the tensor data across all machines in a way that all get the final result. + + After the call ``tensor`` is going to be bitwise identical in all processes. + + Complex tensors are supported. + + Args: + tensor (Tensor): Input and output of the collective. The function + operates in-place. + op (optional): One of the values from + ``core.distributed.ReduceOp`` + enum. Specifies an operation used for element-wise reductions. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + Examples: + >>> # xdoctest: +SKIP("no rank") + >>> # All tensors below are of core.int64 type. + >>> # We have 2 process groups, 2 ranks. + >>> device = core.device(f'cuda:{rank}') + >>> tensor = core.arange(2, dtype=core.int64, device=device) + 1 + 2 * rank + >>> tensor + tensor([1, 2], device='cuda:0') # Rank 0 + tensor([3, 4], device='cuda:1') # Rank 1 + >>> dist.all_reduce(tensor, op=ReduceOp.SUM) + >>> tensor + tensor([4, 6], device='cuda:0') # Rank 0 + tensor([4, 6], device='cuda:1') # Rank 1 + + >>> # All tensors below are of core.cfloat type. + >>> # We have 2 process groups, 2 ranks. + >>> tensor = core.tensor([1+1j, 2+2j], dtype=core.cfloat, device=device) + 2 * rank * (1+1j) + >>> tensor + tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0 + tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1 + >>> dist.all_reduce(tensor, op=ReduceOp.SUM) + >>> tensor + tensor([4.+4.j, 6.+6.j], device='cuda:0') # Rank 0 + tensor([4.+4.j, 6.+6.j], device='cuda:1') # Rank 1 + + """ + _check_single_tensor(tensor, "tensor") + if _rank_not_in_group(group): + _warn_not_in_group("all_reduce") + return + + if tensor.is_complex(): + if not supports_complex(op): + raise ValueError(f"all_reduce does not support {op} on complex tensors") + tensor = core.view_as_real(tensor) + + opts = AllreduceOptions() + opts.reduceOp = op + if group is None: + group = _get_default_group() + + if group in _world.pg_coalesce_state.keys(): + # We are in coalescing context, do not issue single operation, just append a collective representation + coll = _CollOp(all_reduce, tensor, None, op, None) + _world.pg_coalesce_state[group].append(coll) + if async_op: + return _IllegalWork() + else: + return None + + work = group.allreduce([tensor], opts) + + if async_op: + return work + else: + work.wait() + + +@_exception_logger +@deprecated( + "`core.distributed.all_reduce_coalesced` will be deprecated. If you must " + "use it, please revisit our documentation later at " + "https://pycore.org/docs/main/distributed.html#collective-functions", + category=FutureWarning, +) +def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False): + """ + WARNING: at this time individual shape checking is not implemented across nodes. + + For example, if the rank 0 node passes [core.rand(4), core.rand(2)] and the + rank 1 node passes [core.rand(2), core.rand(2), core.rand(2)], the allreduce + operation will proceed without complaint and return erroneous outputs. This lack + of shape checking results in significant performance improvements but users of this + function should take extra care to ensure that each node passes in tensors whose + shapes match across nodes. + + Reduces each tensor in tensors (residing on the same device) across all machines + in such a way that all get the final result. + + After the call each tensor in tensors is going to bitwise identical + in all processes. + + Complex tensors are supported. + + Args: + tensors (Union[List[Tensor], Tensor]): Input and output of the collective. + The function operates in-place. + op (Optional[ReduceOp]): One of the values from + ``core.distributed.ReduceOp`` enum. Specifies an operation used for + element-wise reductions. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (Optional[bool]): Whether this op should be an async op. + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group. + + """ + if isinstance(tensors, core.Tensor): + tensors = [tensors] + _check_tensor_list(tensors, "tensor") + _ensure_all_tensors_same_dtype(tensors) + if _rank_not_in_group(group): + _warn_not_in_group("all_reduce_coalesced") + return + + if any(t.is_complex() for t in tensors) and not supports_complex(op): + raise ValueError(f"all_reduce does not support {op} on complex tensors") + + tensors = [t if not t.is_complex() else core.view_as_real(t) for t in tensors] + + opts = AllreduceCoalescedOptions() + opts.reduceOp = op + group = group or _get_default_group() + work = group.allreduce_coalesced(tensors, opts) + + if async_op: + return work.get_future() + else: + work.wait() + + +@_exception_logger +def reduce( + tensor: core.Tensor, + dst: Optional[int] = None, + op=ReduceOp.SUM, + group: Optional[ProcessGroup] = None, + async_op: bool = False, + group_dst: Optional[int] = None, +): + """ + Reduces the tensor data across all machines. + + Only the process with rank ``dst`` is going to receive the final result. + + Args: + tensor (Tensor): Input and output of the collective. The function + operates in-place. + dst (int): Destination rank on global process group (regardless of ``group`` argument) + op (optional): One of the values from + ``core.distributed.ReduceOp`` + enum. Specifies an operation used for element-wise reductions. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op + group_dst (int): Destination rank on ``group``. Must specify one of ``group_dst`` + and ``dst`` but not both. + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + """ + group = _group_or_default_group(group) + group_dst = _canonicalize_group_rank(group, dst, group_dst, return_global=False) + _check_single_tensor(tensor, "tensor") + if _rank_not_in_group(group): + _warn_not_in_group("reduce") + return + + opts = ReduceOptions() + opts.reduceOp = op + opts.rootRank = group_dst + out = group.reduce([tensor], opts) + # if async_op: + # return work + # else: + # work.wait() + return out + +def _object_to_tensor(obj, device, group): + f = io.BytesIO() + _pickler(f).dump(obj) + byte_data = f.getvalue() + byte_tensor = core.Tensor(core.Tensor.convert_bytes_to_tensor(byte_data, (len(byte_data),), core.int8)) + # Do not replace `core.ByteTensor` or `core.LongTensor` with core.tensor and specifying dtype. + # Otherwise, it will casue 100X slowdown. + # See: https://github.com/pytorch/pytorch/issues/65696 + local_size = core.Tensor([byte_tensor.numel()], dtype=core.int32) + return byte_tensor, local_size + + +def _tensor_to_object(tensor, tensor_size, group): + buf = tensor.asnumpy().tobytes()[:tensor_size] + return _unpickler(io.BytesIO(buf)).load() + + +@_exception_logger +def all_gather_object(object_list, obj, group=None): + """ + Gathers picklable objects from the whole group into a list. + + Similar to :func:`all_gather`, but Python objects can be passed in. + Note that the object must be picklable in order to be gathered. + + Args: + object_list (list[Any]): Output list. It should be correctly sized as the + size of the group for this collective and will contain the output. + obj (Any): Pickable Python object to be broadcast from current process. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Default is ``None``. + + Returns: + None. If the calling rank is part of this group, the output of the + collective will be populated into the input ``object_list``. If the + calling rank is not part of the group, the passed in ``object_list`` will + be unmodified. + + .. note:: Note that this API differs slightly from the :func:`all_gather` + collective since it does not provide an ``async_op`` handle and thus + will be a blocking call. + + .. note:: For NCCL-based processed groups, internal tensor representations + of objects must be moved to the GPU device before communication takes + place. In this case, the device used is given by + ``core.cuda.current_device()`` and it is the user's responsiblity to + ensure that this is set so that each rank has an individual GPU, via + ``core.cuda.set_device()``. + + .. warning:: + :func:`all_gather_object` uses ``pickle`` module implicitly, which is + known to be insecure. It is possible to construct malicious pickle data + which will execute arbitrary code during unpickling. Only call this + function with data you trust. + + .. warning:: + Calling :func:`all_gather_object` with GPU tensors is not well supported + and inefficient as it incurs GPU -> CPU transfer since tensors would be + pickled. Please consider using :func:`all_gather` instead. + + Example:: + >>> # xdoctest: +SKIP("need process group init") + >>> # Note: Process group initialization omitted on each rank. + >>> from mindnlp import core.distributed as dist + >>> # Assumes world_size of 3. + >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object + >>> output = [None for _ in gather_objects] + >>> dist.all_gather_object(output, gather_objects[dist.get_rank()]) + >>> output + ['foo', 12, {1: 2}] + """ + if _rank_not_in_group(group): + _warn_not_in_group("all_gather_object") + return + + # current_device = _get_object_coll_device(group) + input_tensor, local_size = _object_to_tensor(obj, None, group) + + # Gather all local sizes. This is so that we can find the max size, and index + # until the correct size when deserializing the tensors. + group_size = get_world_size(group=group) + object_sizes_tensor = core.zeros( + group_size, dtype=core.int32 + ) + object_size_list = [ + object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) + ] + # Allgather tensor sizes + object_size_list, _ = all_gather(object_size_list, local_size, group=group) + max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] + # Resize tensor to max size across all ranks. + if max_object_size - input_tensor.shape[0] > 0: + input_tensor = core.concat([input_tensor, core.zeros(max_object_size - input_tensor.shape[0], dtype=input_tensor.dtype)]) + + coalesced_output_tensor = core.empty( + max_object_size * group_size, dtype=core.int8 + ) + # Output tensors are nonoverlapping views of coalesced_output_tensor + output_tensors = [ + coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] + for i in range(group_size) + ] + output_tensors, _ = all_gather(output_tensors, input_tensor, group=group) + # Deserialize outputs back to object. + for i, tensor in enumerate(output_tensors): + tensor = tensor.type(core.int8) + tensor_size = object_size_list[i] + object_list[i] = _tensor_to_object(tensor, tensor_size, group) + + +@_exception_logger +def gather_object( + obj: Any, + object_gather_list: Optional[List[Any]] = None, + dst: Optional[int] = None, + group: Optional[ProcessGroup] = None, + group_dst: Optional[int] = None, +): + """ + Gathers picklable objects from the whole group in a single process. + + Similar to :func:`gather`, but Python objects can be passed in. Note that the + object must be picklable in order to be gathered. + + Args: + obj (Any): Input object. Must be picklable. + object_gather_list (list[Any]): Output list. On the ``dst`` rank, it + should be correctly sized as the size of the group for this + collective and will contain the output. Must be ``None`` on non-dst + ranks. (default is ``None``) + dst (int, optional): Destination rank on global process group (regardless of ``group`` argument). + (If both ``dst`` and ``group_dst`` are None, default is global rank 0) + group: (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Default is ``None``. + group_dst (int, optional): Destination rank on ``group``. Invalid to specify both ``dst`` and ``group_dst`` + + Returns: + None. On the ``dst`` rank, ``object_gather_list`` will contain the + output of the collective. + + .. note:: Note that this API differs slightly from the gather collective + since it does not provide an async_op handle and thus will be a blocking + call. + + .. note:: For NCCL-based processed groups, internal tensor representations + of objects must be moved to the GPU device before communication takes + place. In this case, the device used is given by + ``core.cuda.current_device()`` and it is the user's responsiblity to + ensure that this is set so that each rank has an individual GPU, via + ``core.cuda.set_device()``. + + .. warning:: + :func:`gather_object` uses ``pickle`` module implicitly, which is + known to be insecure. It is possible to construct malicious pickle data + which will execute arbitrary code during unpickling. Only call this + function with data you trust. + + .. warning:: + Calling :func:`gather_object` with GPU tensors is not well supported + and inefficient as it incurs GPU -> CPU transfer since tensors would be + pickled. Please consider using :func:`gather` instead. + + Example:: + >>> # xdoctest: +SKIP("need process group init") + >>> # Note: Process group initialization omitted on each rank. + >>> from mindnlp import core.distributed as dist + >>> # Assumes world_size of 3. + >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object + >>> output = [None for _ in gather_objects] + >>> dist.gather_object( + ... gather_objects[dist.get_rank()], + ... output if dist.get_rank() == 0 else None, + ... dst=0 + ... ) + >>> # On rank 0 + >>> output + ['foo', 12, {1: 2}] + """ + group = _group_or_default_group(group) + if dst is None and group_dst is None: + dst = 0 + global_dst = _canonicalize_group_rank(group, dst, group_dst, return_global=True) + if _rank_not_in_group(group): + _warn_not_in_group("gather_object") + return + + # Ensure object_gather_list is specified appropriately. + my_global_rank = get_rank() + _validate_output_list_for_rank(my_global_rank, global_dst, object_gather_list) + # current_device = _get_object_coll_device(group) + input_tensor, local_size = _object_to_tensor(obj, None, group) + # Gather all local sizes. This is so that we can find the max size, and index + # until the correct size when deserializing the tensors. + group_size = get_world_size(group=group) + object_sizes_tensor = core.zeros( + group_size, dtype=core.int32 + ) + object_size_list = [ + object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) + ] + # Allgather tensor sizes. An all-gather is needed here despite this being a + # gather, since each rank needs to broadcast a tensor of the same (maximal) + # size. + object_size_list, _ = all_gather(object_size_list, local_size, group=group) + max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] + # Resize tensor to max size across all ranks. + if max_object_size - input_tensor.shape[0] > 0: + input_tensor = core.concat([input_tensor, core.zeros(max_object_size - input_tensor.shape[0], dtype=input_tensor.dtype)]) + + # Avoid populating output tensors if the result won't be gathered on this rank. + if my_global_rank == global_dst: + coalesced_output_tensor = core.empty( + max_object_size * group_size, dtype=core.int8 + ) + # Output tensors are nonoverlapping views of coalesced_output_tensor + output_tensors = [ + coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] + for i in range(group_size) + ] + # All ranks call gather with equal-sized tensors. + output_tensors = gather( + input_tensor, + gather_list=output_tensors if my_global_rank == global_dst else None, # type: ignore[possibly-undefined] + dst=global_dst, + group=group, + ) + if my_global_rank != global_dst: + return + + assert object_gather_list is not None, "Must provide object_gather_list on dst rank" + for i, tensor in enumerate(output_tensors): + tensor = tensor.type(core.int8) + tensor_size = object_size_list[i] + object_gather_list[i] = _tensor_to_object(tensor, tensor_size, group) + + +@_exception_logger +def send_object_list( + object_list: List[Any], + dst: Optional[int] = None, + group: Optional[ProcessGroup] = None, + device: Optional[core.device] = None, + group_dst: Optional[int] = None, +): + """ + Sends picklable objects in ``object_list`` synchronously. + + Similar to :func:`send`, but Python objects can be passed in. + Note that all objects in ``object_list`` must be picklable in order to be + sent. + + Args: + object_list (List[Any]): List of input objects to sent. + Each object must be picklable. Receiver must provide lists of equal sizes. + dst (int): Destination rank to send ``object_list`` to. + Destination rank is based on global process group (regardless of ``group`` argument) + group: (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Default is ``None``. + device (``core.device``, optional): If not None, the objects are + serialized and converted to tensors which are moved to the + ``device`` before sending. Default is ``None``. + group_dst (int, optional): Destination rank on ``group``. + Must specify one of ``dst`` and ``group_dst`` but not both + Returns: + ``None``. + + .. note:: For NCCL-based process groups, internal tensor representations + of objects must be moved to the GPU device before communication takes + place. In this case, the device used is given by + ``core.cuda.current_device()`` and it is the user's responsibility to + ensure that this is set so that each rank has an individual GPU, via + ``core.cuda.set_device()``. + + .. warning:: + :func:`send_object_list` uses ``pickle`` module implicitly, which + is known to be insecure. It is possible to construct malicious pickle + data which will execute arbitrary code during unpickling. Only call this + function with data you trust. + + .. warning:: + Calling :func:`send_object_list` with GPU tensors is not well supported + and inefficient as it incurs GPU -> CPU transfer since tensors would be + pickled. Please consider using :func:`send` instead. + + Example:: + >>> # xdoctest: +SKIP("need process group init") + >>> # Note: Process group initialization omitted on each rank. + >>> from mindnlp import core.distributed as dist + >>> # Assumes backend is not NCCL + >>> device = core.device("cpu") + >>> if dist.get_rank() == 0: + >>> # Assumes world_size of 2. + >>> objects = ["foo", 12, {1: 2}] # any picklable object + >>> dist.send_object_list(objects, dst=1, device=device) + >>> else: + >>> objects = [None, None, None] + >>> dist.recv_object_list(objects, src=0, device=device) + >>> objects + ['foo', 12, {1: 2}] + """ + group = _group_or_default_group(group) + group_dst = _canonicalize_group_rank(group, dst, group_dst) + _check_not_self_rank(group, group_dst, "destination") + + if _rank_not_in_group(group): + _warn_not_in_group("send_object_list") + return + + # Current device selection. + # To preserve backwards compatibility, ``device`` is default to ``None`` + # in which case we run current logic of device selection, i.e. + # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the + # case it is not ``None`` we move the size and object tensors to be + # sent to this device. + current_device = device or _get_object_coll_device(group) + # Serialize object_list elements to tensors on src rank. + tensor_list, size_list = zip( + *[_object_to_tensor(obj, current_device, group) for obj in object_list] + ) + object_sizes_tensor = core.cat(size_list) + + # Send object sizes + send(object_sizes_tensor, group_dst=group_dst, group=group) + + # Concatenate and send serialized object tensors + # Note: core.cat will do an extra memory copy to the current device, if the tensor_list + # has only one element, we can skip the copy. + if len(tensor_list) == 1: # type: ignore[possibly-undefined] + object_tensor = tensor_list[0] + else: + object_tensor = core.cat(tensor_list) + + send(object_tensor, group_dst=group_dst, group=group) + + +@_exception_logger +def recv_object_list( + object_list: List[Any], + src: Optional[int] = None, + group: Optional[ProcessGroup] = None, + device: Optional[core.device] = None, + group_src: Optional[int] = None, +): + """ + Receives picklable objects in ``object_list`` synchronously. + + Similar to :func:`recv`, but can receive Python objects. + + Args: + object_list (List[Any]): List of objects to receive into. + Must provide a list of sizes equal to the size of the list being sent. + src (int, optional): Source rank from which to recv ``object_list``. + Source rank is based on global process group (regardless of ``group`` argument) + Will receive from any rank if set to None. Default is ``None``. + group: (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Default is ``None``. + device (``core.device``, optional): If not None, receives on this device. + Default is ``None``. + group_src (int, optional): Destination rank on ``group``. Invalid to specify both ``src`` and ``group_src``. + + Returns: + Sender rank. -1 if rank is not part of the group. If rank is part of the group, + ``object_list`` will contain the sent objects from ``src`` rank. + + .. note:: For NCCL-based process groups, internal tensor representations + of objects must be moved to the GPU device before communication takes + place. In this case, the device used is given by + ``core.cuda.current_device()`` and it is the user's responsibility to + ensure that this is set so that each rank has an individual GPU, via + ``core.cuda.set_device()``. + + .. warning:: + :func:`recv_object_list` uses ``pickle`` module implicitly, which + is known to be insecure. It is possible to construct malicious pickle + data which will execute arbitrary code during unpickling. Only call this + function with data you trust. + + .. warning:: + Calling :func:`recv_object_list` with GPU tensors is not well supported + and inefficient as it incurs GPU -> CPU transfer since tensors would be + pickled. Please consider using :func:`recv` instead. + + Example:: + >>> # xdoctest: +SKIP("need process group init") + >>> # Note: Process group initialization omitted on each rank. + >>> from mindnlp import core.distributed as dist + >>> # Assumes backend is not NCCL + >>> device = core.device("cpu") + >>> if dist.get_rank() == 0: + >>> # Assumes world_size of 2. + >>> objects = ["foo", 12, {1: 2}] # any picklable object + >>> dist.send_object_list(objects, dst=1, device=device) + >>> else: + >>> objects = [None, None, None] + >>> dist.recv_object_list(objects, src=0, device=device) + >>> objects + ['foo', 12, {1: 2}] + """ + if _rank_not_in_group(group): + _warn_not_in_group("recv_object_list") + return -1 + + # Current device selection. + # To preserve backwards compatibility, ``device`` is default to ``None`` + # in which case we run current logic of device selection, i.e. + # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the + # case it is not ``None`` we move the size and object tensors to be + # received to this device. + current_device = device or _get_object_coll_device(group) + object_sizes_tensor = core.empty( + len(object_list), dtype=core.long, device=current_device + ) + + # Receive object sizes + rank_sizes = recv(object_sizes_tensor, src=src, group=group, group_src=group_src) + + # Tensor to receive serialized objects into. + object_tensor = core.empty( # type: ignore[call-overload] + core.sum(object_sizes_tensor).item(), # type: ignore[arg-type] + dtype=core.int8, + device=current_device, + ) + + rank_objects = recv(object_tensor, src=src, group=group, group_src=group_src) + assert ( + rank_sizes == rank_objects + ), "Mismatch in return ranks for object sizes and objects." + # Deserialize objects using their stored sizes. + offset = 0 + for i, obj_size in enumerate(object_sizes_tensor): + obj_view = object_tensor[offset : offset + obj_size] + obj_view = obj_view.type(core.int8) + offset += obj_size + object_list[i] = _tensor_to_object(obj_view, obj_size, group) + return rank_objects + + +@_exception_logger +def broadcast_object_list( + object_list: List[Any], + src: Optional[int] = None, + group: Optional[ProcessGroup] = None, + device: Optional[core.device] = None, + group_src: Optional[int] = None, +): + """ + Broadcasts picklable objects in ``object_list`` to the whole group. + + Similar to :func:`broadcast`, but Python objects can be passed in. + Note that all objects in ``object_list`` must be picklable in order to be + broadcasted. + + Args: + object_list (List[Any]): List of input objects to broadcast. + Each object must be picklable. Only objects on the ``src`` rank will + be broadcast, but each rank must provide lists of equal sizes. + src (int): Source rank from which to broadcast ``object_list``. + Source rank is based on global process group (regardless of ``group`` argument) + group: (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Default is ``None``. + device (``core.device``, optional): If not None, the objects are + serialized and converted to tensors which are moved to the + ``device`` before broadcasting. Default is ``None``. + group_src (int): Source rank on ``group``. Must not specify one of ``group_src`` + and ``src`` but not both. + + Returns: + ``None``. If rank is part of the group, ``object_list`` will contain the + broadcasted objects from ``src`` rank. + + .. note:: For NCCL-based process groups, internal tensor representations + of objects must be moved to the GPU device before communication takes + place. In this case, the device used is given by + ``core.cuda.current_device()`` and it is the user's responsibility to + ensure that this is set so that each rank has an individual GPU, via + ``core.cuda.set_device()``. + + .. note:: Note that this API differs slightly from the :func:`broadcast` + collective since it does not provide an ``async_op`` handle and thus + will be a blocking call. + + .. warning:: + :func:`broadcast_object_list` uses ``pickle`` module implicitly, which + is known to be insecure. It is possible to construct malicious pickle + data which will execute arbitrary code during unpickling. Only call this + function with data you trust. + + .. warning:: + Calling :func:`broadcast_object_list` with GPU tensors is not well supported + and inefficient as it incurs GPU -> CPU transfer since tensors would be + pickled. Please consider using :func:`broadcast` instead. + + Example:: + >>> # xdoctest: +SKIP("need process group init") + >>> # Note: Process group initialization omitted on each rank. + >>> from mindnlp import core.distributed as dist + >>> if dist.get_rank() == 0: + >>> # Assumes world_size of 3. + >>> objects = ["foo", 12, {1: 2}] # any picklable object + >>> else: + >>> objects = [None, None, None] + >>> # Assumes backend is not NCCL + >>> device = core.device("cpu") + >>> dist.broadcast_object_list(objects, src=0, device=device) + >>> objects + ['foo', 12, {1: 2}] + """ + group = _group_or_default_group(group) + if src is None and group_src is None: + src = 0 + global_src = _canonicalize_group_rank(group, src, group_src, return_global=True) + if _rank_not_in_group(group): + _warn_not_in_group("broadcast_object_list") + return + + # Current device selection. + # To preserve backwards compatibility, ``device`` is default to ``None`` + # in which case we run current logic of device selection, i.e. + # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the + # case it is not ``None`` we move the size and object tensors to be + # broadcasted to this device. + current_device = device or _get_object_coll_device(group) + my_global_rank = get_rank() + # Serialize object_list elements to tensors on src rank. + if my_global_rank == global_src: + tensor_list, size_list = zip( + *[_object_to_tensor(obj, current_device, group) for obj in object_list] + ) + object_sizes_tensor = core.cat(size_list) + else: + object_sizes_tensor = core.empty( + len(object_list), dtype=core.long, device=current_device + ) + + # Broadcast object sizes + broadcast(object_sizes_tensor, src=global_src, group=group) + + # Concatenate and broadcast serialized object tensors + # Note: core.cat will do an extra memory copy to the current device, if the tensor_list + # has only one element, we can skip the copy. + if my_global_rank == global_src: + if len(tensor_list) == 1: # type: ignore[possibly-undefined] + object_tensor = tensor_list[0] + else: + object_tensor = core.cat(tensor_list) + else: + object_tensor = core.empty( # type: ignore[call-overload] + core.sum(object_sizes_tensor).item(), # type: ignore[arg-type] + dtype=core.int8, + device=current_device, + ) + + broadcast(object_tensor, src=global_src, group=group) + # Deserialize objects using their stored sizes. + offset = 0 + if my_global_rank != global_src: + for i, obj_size in enumerate(object_sizes_tensor): + obj_view = object_tensor[offset : offset + obj_size] + obj_view = obj_view.type(core.int8) + offset += obj_size + object_list[i] = _tensor_to_object(obj_view, obj_size, group) + + +@_exception_logger +def scatter_object_list( + scatter_object_output_list: List[Any], + scatter_object_input_list: Optional[List[Any]] = None, + src: Optional[int] = None, + group: Optional[ProcessGroup] = None, + group_src: Optional[int] = None, +): + """ + Scatters picklable objects in ``scatter_object_input_list`` to the whole group. + + Similar to :func:`scatter`, but Python objects can be passed in. On + each rank, the scattered object will be stored as the first element of + ``scatter_object_output_list``. Note that all objects in + ``scatter_object_input_list`` must be picklable in order to be scattered. + + Args: + scatter_object_output_list (List[Any]): Non-empty list whose first + element will store the object scattered to this rank. + scatter_object_input_list (List[Any], optional): List of input objects to scatter. + Each object must be picklable. Only objects on the ``src`` rank will + be scattered, and the argument can be ``None`` for non-src ranks. + src (int): Source rank from which to scatter ``scatter_object_input_list``. + Source rank is based on global process group (regardless of ``group`` argument). + (If both ``src`` and ``group_src`` are None, default is global rank 0) + group: (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Default is ``None``. + group_src (int, optional): Source rank on ``group``. Invalid to specify both ``src`` and ``group_src`` + + Returns: + ``None``. If rank is part of the group, ``scatter_object_output_list`` + will have its first element set to the scattered object for this rank. + + .. note:: Note that this API differs slightly from the scatter collective + since it does not provide an ``async_op`` handle and thus will be a + blocking call. + + .. warning:: + :func:`scatter_object_list` uses ``pickle`` module implicitly, which + is known to be insecure. It is possible to construct malicious pickle + data which will execute arbitrary code during unpickling. Only call this + function with data you trust. + + .. warning:: + Calling :func:`scatter_object_list` with GPU tensors is not well supported + and inefficient as it incurs GPU -> CPU transfer since tensors would be + pickled. Please consider using :func:`scatter` instead. + + Example:: + >>> # xdoctest: +SKIP("need process group init") + >>> # Note: Process group initialization omitted on each rank. + >>> from mindnlp import core.distributed as dist + >>> if dist.get_rank() == 0: + >>> # Assumes world_size of 3. + >>> objects = ["foo", 12, {1: 2}] # any picklable object + >>> else: + >>> # Can be any list on non-src ranks, elements are not used. + >>> objects = [None, None, None] + >>> output_list = [None] + >>> dist.scatter_object_list(output_list, objects, src=0) + >>> # Rank i gets objects[i]. For example, on rank 2: + >>> output_list + [{1: 2}] + """ + group = _group_or_default_group(group) + if src is None and group_src is None: + src = 0 + global_src = _canonicalize_group_rank(group, src, group_src, return_global=True) + if _rank_not_in_group(group): + _warn_not_in_group("scatter_object_list") + return + + if ( + not isinstance(scatter_object_output_list, list) + or len(scatter_object_output_list) < 1 + ): + raise ValueError( + "Expected argument scatter_object_output_list to be a list of size at least 1." + ) + + my_global_rank = get_rank() + # pg_device = _get_object_coll_device(group) + if my_global_rank == global_src: + if scatter_object_input_list is None: + raise ValueError( + "source rank must provide non-None scatter_object_input_list" + ) + tensor_list, tensor_sizes = zip( + *[ + _object_to_tensor(obj, None, group) + for obj in scatter_object_input_list + ] + ) + tensor_list, tensor_sizes = list(tensor_list), list(tensor_sizes) + + # Src rank broadcasts the maximum tensor size. This is because all ranks are + # expected to call into scatter() with equal-sized tensors. + max_tensor_size = max(tensor_sizes) # type: ignore[possibly-undefined] + for i in range(len(tensor_list)): # type: ignore[possibly-undefined] + # tensor.resize_(max_tensor_size) + tensor = tensor_list[i] + if max_tensor_size - tensor.shape[0] > 0: + tensor = core.concat([tensor, core.zeros(max_tensor_size.item() - tensor.shape[0], dtype=tensor.dtype)]) + tensor_list[i] = tensor + else: + max_tensor_size = core.tensor([0], dtype=core.long) + max_tensor_size = broadcast(max_tensor_size, src=global_src, group=group) + + # Scatter actual serialized objects + output_tensor = core.empty( + max_tensor_size.item(), dtype=core.int8 + ) + output_tensor = scatter( + output_tensor, + scatter_list=None if my_global_rank != global_src else tensor_list, # type: ignore[possibly-undefined] + src=global_src, + group=group, + ) + # Scatter per-object sizes to trim tensors when deserializing back to object + obj_tensor_size = core.tensor([0], dtype=core.int32) + obj_tensor_size = scatter( + obj_tensor_size, + scatter_list=None if my_global_rank != global_src else tensor_sizes, # type: ignore[possibly-undefined] + src=global_src, + group=group, + ) + + # Deserialize back to object + scatter_object_output_list[0] = _tensor_to_object( + output_tensor, obj_tensor_size, group + ) + + +@_exception_logger +def all_gather(tensor_list, tensor, group=None, async_op=False): + """ + Gathers tensors from the whole group in a list. + + Complex and uneven sized tensors are supported. + + Args: + tensor_list (list[Tensor]): Output list. It should contain + correctly-sized tensors to be used for output of the collective. + Uneven sized tensors are supported. + tensor (Tensor): Tensor to be broadcast from current process. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + Examples: + >>> # xdoctest: +SKIP("need process group init") + >>> # All tensors below are of core.int64 dtype. + >>> # We have 2 process groups, 2 ranks. + >>> device = core.device(f'cuda:{rank}') + >>> tensor_list = [core.zeros(2, dtype=core.int64, device=device) for _ in range(2)] + >>> tensor_list + [tensor([0, 0], device='cuda:0'), tensor([0, 0], device='cuda:0')] # Rank 0 + [tensor([0, 0], device='cuda:1'), tensor([0, 0], device='cuda:1')] # Rank 1 + >>> tensor = core.arange(2, dtype=core.int64, device=device) + 1 + 2 * rank + >>> tensor + tensor([1, 2], device='cuda:0') # Rank 0 + tensor([3, 4], device='cuda:1') # Rank 1 + >>> dist.all_gather(tensor_list, tensor) + >>> tensor_list + [tensor([1, 2], device='cuda:0'), tensor([3, 4], device='cuda:0')] # Rank 0 + [tensor([1, 2], device='cuda:1'), tensor([3, 4], device='cuda:1')] # Rank 1 + + >>> # All tensors below are of core.cfloat dtype. + >>> # We have 2 process groups, 2 ranks. + >>> tensor_list = [core.zeros(2, dtype=core.cfloat, device=device) for _ in range(2)] + >>> tensor_list + [tensor([0.+0.j, 0.+0.j], device='cuda:0'), tensor([0.+0.j, 0.+0.j], device='cuda:0')] # Rank 0 + [tensor([0.+0.j, 0.+0.j], device='cuda:1'), tensor([0.+0.j, 0.+0.j], device='cuda:1')] # Rank 1 + >>> tensor = core.tensor([1+1j, 2+2j], dtype=core.cfloat, device=device) + 2 * rank * (1+1j) + >>> tensor + tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0 + tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1 + >>> dist.all_gather(tensor_list, tensor) + >>> tensor_list + [tensor([1.+1.j, 2.+2.j], device='cuda:0'), tensor([3.+3.j, 4.+4.j], device='cuda:0')] # Rank 0 + [tensor([1.+1.j, 2.+2.j], device='cuda:1'), tensor([3.+3.j, 4.+4.j], device='cuda:1')] # Rank 1 + + """ + _check_tensor_list(tensor_list, "tensor_list") + _check_single_tensor(tensor, "tensor") + _ensure_all_tensors_same_dtype(tensor_list, tensor) + if _rank_not_in_group(group): + _warn_not_in_group("all_gather") + return + + tensor_list = [ + t if not t.is_complex() else core.view_as_real(t) for t in tensor_list + ] + tensor = tensor if not tensor.is_complex() else core.view_as_real(tensor) + + group = group or _get_default_group() + work = group.allgather([tensor_list], [tensor]) + + if async_op: + return work + else: + work.wait() + + +@_exception_logger +def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=False): + """ + Gather tensors from all ranks and put them in a single output tensor. + + This function requires all tensors to be the same size on each process. + + Args: + output_tensor (Tensor): Output tensor to accommodate tensor elements + from all ranks. It must be correctly sized to have one of the + following forms: + (i) a concatenation of all the input tensors along the primary + dimension; for definition of "concatenation", see ``core.cat()``; + (ii) a stack of all the input tensors along the primary dimension; + for definition of "stack", see ``core.stack()``. + Examples below may better explain the supported output forms. + input_tensor (Tensor): Tensor to be gathered from current rank. + Different from the ``all_gather`` API, the input tensors in this + API must have the same size across all ranks. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + Examples: + >>> # xdoctest: +SKIP("need process group init") + >>> # All tensors below are of core.int64 dtype and on CUDA devices. + >>> # We have two ranks. + >>> device = core.device(f'cuda:{rank}') + >>> tensor_in = core.arange(2, dtype=core.int64, device=device) + 1 + 2 * rank + >>> tensor_in + tensor([1, 2], device='cuda:0') # Rank 0 + tensor([3, 4], device='cuda:1') # Rank 1 + >>> # Output in concatenation form + >>> tensor_out = core.zeros(world_size * 2, dtype=core.int64, device=device) + >>> dist.all_gather_into_tensor(tensor_out, tensor_in) + >>> tensor_out + tensor([1, 2, 3, 4], device='cuda:0') # Rank 0 + tensor([1, 2, 3, 4], device='cuda:1') # Rank 1 + >>> # Output in stack form + >>> tensor_out2 = core.zeros(world_size, 2, dtype=core.int64, device=device) + >>> dist.all_gather_into_tensor(tensor_out2, tensor_in) + >>> tensor_out2 + tensor([[1, 2], + [3, 4]], device='cuda:0') # Rank 0 + tensor([[1, 2], + [3, 4]], device='cuda:1') # Rank 1 + + .. warning:: + The Gloo backend does not support this API. + + """ + _check_single_tensor(input_tensor, "input_tensor") + _check_single_tensor(output_tensor, "output_tensor") + if _rank_not_in_group(group): + _warn_not_in_group("all_gather_into_tensor") + return + + output_tensor = ( + output_tensor + if not output_tensor.is_complex() + else core.view_as_real(output_tensor) + ) + input_tensor = ( + input_tensor + if not input_tensor.is_complex() + else core.view_as_real(input_tensor) + ) + + opts = AllgatherOptions() + opts.asyncOp = async_op + + group = group or _get_default_group() + + if group in _world.pg_coalesce_state.keys(): + # We are in coalescing context, do not issue single operation, just append a collective representation + coll = _CollOp(all_gather_into_tensor, input_tensor, output_tensor) + _world.pg_coalesce_state[group].append(coll) + if async_op: + return _IllegalWork() + else: + return None + + work = group._allgather_base(output_tensor, input_tensor, opts) + + if async_op: + return work + else: + work.wait() + + +@_exception_logger +@deprecated( + "`core.distributed._all_gather_base` is a private function and will be deprecated. " + "Please use `core.distributed.all_gather_into_tensor` instead.", + category=FutureWarning, +) +def _all_gather_base(output_tensor, input_tensor, group=None, async_op=False): + """ + Single tensor all gather. Gathers a single tensor from all ranks, and puts them in a single output tensor. + + Args: + output_tensor (Tensor): Output tensor. It should contain + correctly-sized tensors to be used for output of the collective. + input_tensor (Tensor): Tensor to be broadcast from current process. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + .. warning:: + `_all_gather_base` is a private function. Users should use + `all_gather_into_tensor` instead. + + """ + return all_gather_into_tensor(output_tensor, input_tensor, group, async_op) + + +@_exception_logger +@deprecated( + "`core.distributed.all_gather_coalesced` will be deprecated. If you must use it, " + "please revisit our documentation later at " + "https://pycore.org/docs/main/distributed.html#collective-functions", + category=FutureWarning, +) +def all_gather_coalesced( + output_tensor_lists, input_tensor_list, group=None, async_op=False +): + """ + Gathers input tensors from the whole group in a list in a coalesced manner. + + Complex tensors are supported. + + Args: + output_tensor_lists (list[list[Tensor]]): Output list. It should contain + correctly-sized tensors to be used for output of the collective. + input_tensor_list (list[Tensor]): Tensors to be broadcast from + current process. At least one tensor has to be non empty. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op. + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + Example: + we have 2 process groups, 2 ranks. + rank 0 passes: + input_tensor_list = [[[1, 1], [1, 1]], [2], [3, 3]] + output_tensor_lists = + [[[[-1, -1], [-1, -1]], [-1], [-1, -1]], + [[[-1, -1], [-1, -1]], [-1], [-1, -1]]] + rank 1 passes: + input_tensor_list = [[[3, 3], [3, 3]], [5], [1, 1]] + output_tensor_lists = + [[[[-1, -1], [-1, -1]], [-1], [-1, -1]], + [[[-1, -1], [-1, -1]], [-1], [-1, -1]]] + both rank 0 and 1 get: + output_tensor_lists = + [[[1, 1], [1, 1]], [2], [3, 3]], + [[3, 3], [3, 3]], [5], [1, 1]]]. + + WARNING: at this time individual shape checking is not implemented across nodes. + For example, if the rank 0 node passes [core.rand(4), core.rand(2)] and the + rank 1 node passes [core.rand(2), core.rand(2), core.rand(2)], the + all_gather_coalesced operation will proceed without complaint and return + erroneous outputs. This lack of shape checking results in significant + performance improvements but users of this function should take extra care + to ensure that each node passes in tensors whose shapes match across nodes. + """ + # We only check basic compatibility with C++ params here, C++ code will + # do shape and type checking. + if _rank_not_in_group(group): + _warn_not_in_group("all_gather_coalesced") + return + _check_tensor_list(input_tensor_list, "input_tensor_list") + _ensure_all_tensors_same_dtype(input_tensor_list) + if not isinstance(output_tensor_lists, list): + raise TypeError( + "Invalid function argument: output_tensor_lists should be a list" + ) + for output_tensor_list in output_tensor_lists: + _check_tensor_list(output_tensor_list, "output_tensor_lists") + _ensure_all_tensors_same_dtype(output_tensor_list) + + output_tensor_lists = [ + [t if not t.is_complex() else core.view_as_real(t) for t in l] + for l in output_tensor_lists + ] + input_tensor_list = [ + t if not t.is_complex() else core.view_as_real(t) for t in input_tensor_list + ] + + group = group or _get_default_group() + work = group.allgather_coalesced(output_tensor_lists, input_tensor_list) + + if async_op: + return work.get_future() + else: + work.wait() + + +def _validate_output_list_for_rank(my_rank, dst, gather_list): + if dst == my_rank: + if not gather_list: + raise ValueError( + "Argument ``gather_list`` must be specified on destination rank." + ) + elif gather_list: + raise ValueError( + "Argument ``gather_list`` must NOT be specified " + "on non-destination ranks." + ) + + +@_exception_logger +def gather( + tensor: core.Tensor, + gather_list: Optional[List[core.Tensor]] = None, + dst: Optional[int] = None, + group: Optional[ProcessGroup] = None, + async_op: bool = False, + group_dst: Optional[int] = None, +): + """ + Gathers a list of tensors in a single process. + + This function requires all tensors to be the same size on each process. + + Args: + tensor (Tensor): Input tensor. + gather_list (list[Tensor], optional): List of appropriately, + same-sized tensors to use for gathered data + (default is None, must be specified on the destination rank) + dst (int, optional): Destination rank on global process group (regardless of ``group`` argument). + (If both ``dst`` and ``group_dst`` are None, default is global rank 0) + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op + group_dst (int, optional): Destination rank on ``group``. Invalid to specify both ``dst`` and ``group_dst`` + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + .. note:: Note that all Tensors in gather_list must have the same size. + + Example:: + >>> # xdoctest: +SKIP("no rank") + >>> # We have 2 process groups, 2 ranks. + >>> tensor_size = 2 + >>> device = core.device(f'cuda:{rank}') + >>> tensor = core.ones(tensor_size, device=device) + rank + >>> if dist.get_rank() == 0: + >>> gather_list = [core.zeros_like(tensor, device=device) for i in range(2)] + >>> else: + >>> gather_list = None + >>> dist.gather(tensor, gather_list, dst=0) + >>> # Rank 0 gets gathered data. + >>> gather_list + [tensor([1., 1.], device='cuda:0'), tensor([2., 2.], device='cuda:0')] # Rank 0 + None # Rank 1 + + """ + _check_single_tensor(tensor, "tensor") + + # Parameter ``gather_list`` may be left unspecified on non-dst ranks. + if gather_list: + _check_tensor_list(gather_list, "gather_list") + else: + gather_list = [] + _ensure_all_tensors_same_dtype(tensor, gather_list) + group = _group_or_default_group(group) + if _rank_not_in_group(group): + _warn_not_in_group("gather") + return + if dst is None and group_dst is None: + dst = 0 + global_dst = _canonicalize_group_rank(group, dst, group_dst, return_global=True) + group_dst = _canonicalize_group_rank(group, dst, group_dst, return_global=False) + my_global_rank = get_rank() + _validate_output_list_for_rank(my_global_rank, global_dst, gather_list) + output_tensors = [gather_list] if global_dst == my_global_rank else [] + input_tensors = [tensor] + + opts = GatherOptions() + opts.rootRank = global_dst + opts.groupRank = group_dst + work = group.gather(output_tensors, input_tensors, opts) + + if async_op: + return work + else: + work.wait() + + +@_exception_logger +def scatter( + tensor: core.Tensor, + scatter_list: Optional[List[core.Tensor]] = None, + src: Optional[int] = None, + group: Optional[ProcessGroup] = None, + async_op: bool = False, + group_src: Optional[int] = None, +): + """ + Scatters a list of tensors to all processes in a group. + + Each process will receive exactly one tensor and store its data in the + ``tensor`` argument. + + Complex tensors are supported. + + Args: + tensor (Tensor): Output tensor. + scatter_list (list[Tensor]): List of tensors to scatter (default is + None, must be specified on the source rank) + src (int): Source rank on global process group (regardless of ``group`` argument). + (If both ``src`` and ``group_src`` are None, default is global rank 0) + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op + group_src (int, optional): Source rank on ``group``. Invalid to specify both ``src`` and ``group_src`` + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + .. note:: Note that all Tensors in scatter_list must have the same size. + + Example:: + >>> # xdoctest: +SKIP("need process group init") + >>> # Note: Process group initialization omitted on each rank. + >>> from mindnlp import core.distributed as dist + >>> tensor_size = 2 + >>> device = core.device(f'cuda:{rank}') + >>> output_tensor = core.zeros(tensor_size, device=device) + >>> if dist.get_rank() == 0: + >>> # Assumes world_size of 2. + >>> # Only tensors, all of which must be the same size. + >>> t_ones = core.ones(tensor_size, device=device) + >>> t_fives = core.ones(tensor_size, device=device) * 5 + >>> scatter_list = [t_ones, t_fives] + >>> else: + >>> scatter_list = None + >>> dist.scatter(output_tensor, scatter_list, src=0) + >>> # Rank i gets scatter_list[i]. + >>> output_tensor + tensor([1., 1.], device='cuda:0') # Rank 0 + tensor([5., 5.], device='cuda:1') # Rank 1 + + """ + _check_single_tensor(tensor, "tensor") + # Parameter ``scatter_list`` may be left unspecified on non-src ranks. + if scatter_list: + _check_tensor_list(scatter_list, "scatter_list") + else: + scatter_list = [] + _ensure_all_tensors_same_dtype(tensor, scatter_list) + group = _group_or_default_group(group) + if src is None and group_src is None: + src = 0 + global_src = _canonicalize_group_rank(group, src, group_src, return_global=True) + group_src = _canonicalize_group_rank(group, src, group_src, return_global=False) + if _rank_not_in_group(group): + _warn_not_in_group("scatter") + return + scatter_list = [ + t if not t.is_complex() else core.view_as_real(t) for t in scatter_list + ] + tensor = tensor if not tensor.is_complex() else core.view_as_real(tensor) + + my_global_rank = get_rank() + if global_src == my_global_rank: + if not scatter_list: + raise ValueError( + "Argument ``scatter_list`` must be specified on source rank." + ) + input_tensors = [scatter_list] + output_tensors = [tensor] + else: + if scatter_list: + raise ValueError( + "Argument ``scatter_list`` must NOT be specified " + "on non-source ranks." + ) + input_tensors = [] + output_tensors = [tensor] + + opts = ScatterOptions() + opts.rootRank = global_src + opts.groupRank = group_src + opts.asyncOp = async_op + work = group.scatter(output_tensors, input_tensors, opts) + + if async_op: + return work + else: + work.wait() + + +@_exception_logger +def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=False): + """ + Reduces, then scatters a list of tensors to all processes in a group. + + Args: + output (Tensor): Output tensor. + input_list (list[Tensor]): List of tensors to reduce and scatter. + op (optional): One of the values from + ``core.distributed.ReduceOp`` + enum. Specifies an operation used for element-wise reductions. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op. + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group. + + """ + _check_single_tensor(output, "output") + _check_tensor_list(input_list, "input_list") + _ensure_all_tensors_same_dtype(output, input_list) + if _rank_not_in_group(group): + _warn_not_in_group("reduce_scatter") + return + + opts = ReduceScatterOptions() + opts.reduceOp = op + + group = group or _get_default_group() + work = group.reduce_scatter([output], [input_list], opts) + + if async_op: + return work + else: + work.wait() + + +@_exception_logger +def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=False): + """ + Reduces, then scatters a tensor to all ranks in a group. + + Args: + output (Tensor): Output tensor. It should have the same size across all + ranks. + input (Tensor): Input tensor to be reduced and scattered. Its size + should be output tensor size times the world size. The input tensor + can have one of the following shapes: + (i) a concatenation of the output tensors along the primary + dimension, or + (ii) a stack of the output tensors along the primary dimension. + For definition of "concatenation", see ``core.cat()``. + For definition of "stack", see ``core.stack()``. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op. + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group. + + Examples: + >>> # xdoctest: +SKIP("need process group init") + >>> # All tensors below are of core.int64 dtype and on CUDA devices. + >>> # We have two ranks. + >>> device = core.device(f'cuda:{rank}') + >>> tensor_out = core.zeros(2, dtype=core.int64, device=device) + >>> # Input in concatenation form + >>> tensor_in = core.arange(world_size * 2, dtype=core.int64, device=device) + >>> tensor_in + tensor([0, 1, 2, 3], device='cuda:0') # Rank 0 + tensor([0, 1, 2, 3], device='cuda:1') # Rank 1 + >>> dist.reduce_scatter_tensor(tensor_out, tensor_in) + >>> tensor_out + tensor([0, 2], device='cuda:0') # Rank 0 + tensor([4, 6], device='cuda:1') # Rank 1 + >>> # Input in stack form + >>> tensor_in = core.reshape(tensor_in, (world_size, 2)) + >>> tensor_in + tensor([[0, 1], + [2, 3]], device='cuda:0') # Rank 0 + tensor([[0, 1], + [2, 3]], device='cuda:1') # Rank 1 + >>> dist.reduce_scatter_tensor(tensor_out, tensor_in) + >>> tensor_out + tensor([0, 2], device='cuda:0') # Rank 0 + tensor([4, 6], device='cuda:1') # Rank 1 + + .. warning:: + The Gloo backend does not support this API. + + """ + _check_single_tensor(output, "output") + _check_single_tensor(input, "input") + + if _rank_not_in_group(group): + _warn_not_in_group("reduce_scatter_tensor") + return + + opts = ReduceScatterOptions() + opts.reduceOp = op + opts.asyncOp = async_op + + group = group or _get_default_group() + + # Check if we are in coalescing context + # If we are, do not issue single operation, just append a collective representation + if group in _world.pg_coalesce_state.keys(): + coll = _CollOp(reduce_scatter_tensor, input, output, op, None) + _world.pg_coalesce_state[group].append(coll) + if async_op: + return _IllegalWork() + else: + return None + + work = group._reduce_scatter_base(output, input, opts) + + if async_op: + return work + else: + work.wait() + + +@deprecated( + "`core.distributed._reduce_scatter_base` is a private function and will be deprecated. " + "Please use `core.distributed.reduce_scatter_tensor` instead.", + category=FutureWarning, +) +def _reduce_scatter_base(output, input, op=ReduceOp.SUM, group=None, async_op=False): + """ + Reduces, then scatters a flattened tensor to all processes in a group. + + Args: + output (Tensor): Output tensor. + input (Tensor): Input tensor that is of size output tensor size times world size + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op. + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group. + + .. warning:: + `_reduce_scatter_base` is a private function. Users should use + `reduce_scatter_tensor` instead. + + """ + return reduce_scatter_tensor(output, input, op, group, async_op) + + +@_exception_logger +def all_to_all_single( + output, + input, + output_split_sizes=None, + input_split_sizes=None, + group=None, + async_op=False, +): + """ + Split input tensor and then scatter the split list to all processes in a group. + + Later the received tensors are concatenated from all the processes in the group + and returned as a single output tensor. + + Complex tensors are supported. + + Args: + output (Tensor): Gathered concatenated output tensor. + input (Tensor): Input tensor to scatter. + output_split_sizes: (list[Int], optional): Output split sizes for dim 0 + if specified None or empty, dim 0 of ``output`` tensor must divide + equally by ``world_size``. + input_split_sizes: (list[Int], optional): Input split sizes for dim 0 + if specified None or empty, dim 0 of ``input`` tensor must divide + equally by ``world_size``. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op. + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group. + + .. warning:: + `all_to_all_single` is experimental and subject to change. + + Examples: + >>> # xdoctest: +SKIP("Undefined rank") + >>> input = core.arange(4) + rank * 4 + >>> input + tensor([0, 1, 2, 3]) # Rank 0 + tensor([4, 5, 6, 7]) # Rank 1 + tensor([8, 9, 10, 11]) # Rank 2 + tensor([12, 13, 14, 15]) # Rank 3 + >>> output = core.empty([4], dtype=core.int64) + >>> dist.all_to_all_single(output, input) + >>> output + tensor([0, 4, 8, 12]) # Rank 0 + tensor([1, 5, 9, 13]) # Rank 1 + tensor([2, 6, 10, 14]) # Rank 2 + tensor([3, 7, 11, 15]) # Rank 3 + + >>> # Essentially, it is similar to following operation: + >>> scatter_list = list(input.chunk(world_size)) + >>> gather_list = list(output.chunk(world_size)) + >>> for i in range(world_size): + >>> dist.scatter(gather_list[i], scatter_list if i == rank else [], src = i) + + >>> # Another example with uneven split + >>> input + tensor([0, 1, 2, 3, 4, 5]) # Rank 0 + tensor([10, 11, 12, 13, 14, 15, 16, 17, 18]) # Rank 1 + tensor([20, 21, 22, 23, 24]) # Rank 2 + tensor([30, 31, 32, 33, 34, 35, 36]) # Rank 3 + >>> input_splits + [2, 2, 1, 1] # Rank 0 + [3, 2, 2, 2] # Rank 1 + [2, 1, 1, 1] # Rank 2 + [2, 2, 2, 1] # Rank 3 + >>> output_splits + [2, 3, 2, 2] # Rank 0 + [2, 2, 1, 2] # Rank 1 + [1, 2, 1, 2] # Rank 2 + [1, 2, 1, 1] # Rank 3 + >>> output = ... + >>> dist.all_to_all_single(output, input, output_splits, input_splits) + >>> output + tensor([ 0, 1, 10, 11, 12, 20, 21, 30, 31]) # Rank 0 + tensor([ 2, 3, 13, 14, 22, 32, 33]) # Rank 1 + tensor([ 4, 15, 16, 23, 34, 35]) # Rank 2 + tensor([ 5, 17, 18, 24, 36]) # Rank 3 + + + >>> # Another example with tensors of core.cfloat type. + >>> input = core.tensor([1+1j, 2+2j, 3+3j, 4+4j], dtype=core.cfloat) + 4 * rank * (1+1j) + >>> input + tensor([1+1j, 2+2j, 3+3j, 4+4j]) # Rank 0 + tensor([5+5j, 6+6j, 7+7j, 8+8j]) # Rank 1 + tensor([9+9j, 10+10j, 11+11j, 12+12j]) # Rank 2 + tensor([13+13j, 14+14j, 15+15j, 16+16j]) # Rank 3 + >>> output = core.empty([4], dtype=core.int64) + >>> dist.all_to_all_single(output, input) + >>> output + tensor([1+1j, 5+5j, 9+9j, 13+13j]) # Rank 0 + tensor([2+2j, 6+6j, 10+10j, 14+14j]) # Rank 1 + tensor([3+3j, 7+7j, 11+11j, 15+15j]) # Rank 2 + tensor([4+4j, 8+8j, 12+12j, 16+16j]) # Rank 3 + """ + if _rank_not_in_group(group): + _warn_not_in_group("all_to_all_single") + return + + opts = AllToAllOptions() + _check_single_tensor(output, "output") + _check_single_tensor(input, "input") + _ensure_all_tensors_same_dtype(output, input) + + if input.is_complex(): + input = core.view_as_real(input) + if output.is_complex(): + output = core.view_as_real(output) + + output_split_sizes = [] if output_split_sizes is None else output_split_sizes + input_split_sizes = [] if input_split_sizes is None else input_split_sizes + + group = group or _get_default_group() + work = group.alltoall_base( + output, input, output_split_sizes, input_split_sizes, opts + ) + + if async_op: + return work + else: + work.wait() + + +@_exception_logger +def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False): + """ + Scatters list of input tensors to all processes in a group and return gathered list of tensors in output list. + + Complex tensors are supported. + + Args: + output_tensor_list (list[Tensor]): List of tensors to be gathered one + per rank. + input_tensor_list (list[Tensor]): List of tensors to scatter one per rank. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op. + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group. + + .. warning:: + `all_to_all` is experimental and subject to change. + + Examples: + >>> # xdoctest: +SKIP("Undefined rank") + >>> input = core.arange(4) + rank * 4 + >>> input = list(input.chunk(4)) + >>> input + [tensor([0]), tensor([1]), tensor([2]), tensor([3])] # Rank 0 + [tensor([4]), tensor([5]), tensor([6]), tensor([7])] # Rank 1 + [tensor([8]), tensor([9]), tensor([10]), tensor([11])] # Rank 2 + [tensor([12]), tensor([13]), tensor([14]), tensor([15])] # Rank 3 + >>> output = list(core.empty([4], dtype=core.int64).chunk(4)) + >>> dist.all_to_all(output, input) + >>> output + [tensor([0]), tensor([4]), tensor([8]), tensor([12])] # Rank 0 + [tensor([1]), tensor([5]), tensor([9]), tensor([13])] # Rank 1 + [tensor([2]), tensor([6]), tensor([10]), tensor([14])] # Rank 2 + [tensor([3]), tensor([7]), tensor([11]), tensor([15])] # Rank 3 + + >>> # Essentially, it is similar to following operation: + >>> scatter_list = input + >>> gather_list = output + >>> for i in range(world_size): + >>> dist.scatter(gather_list[i], scatter_list if i == rank else [], src=i) + + >>> input + tensor([0, 1, 2, 3, 4, 5]) # Rank 0 + tensor([10, 11, 12, 13, 14, 15, 16, 17, 18]) # Rank 1 + tensor([20, 21, 22, 23, 24]) # Rank 2 + tensor([30, 31, 32, 33, 34, 35, 36]) # Rank 3 + >>> input_splits + [2, 2, 1, 1] # Rank 0 + [3, 2, 2, 2] # Rank 1 + [2, 1, 1, 1] # Rank 2 + [2, 2, 2, 1] # Rank 3 + >>> output_splits + [2, 3, 2, 2] # Rank 0 + [2, 2, 1, 2] # Rank 1 + [1, 2, 1, 2] # Rank 2 + [1, 2, 1, 1] # Rank 3 + >>> input = list(input.split(input_splits)) + >>> input + [tensor([0, 1]), tensor([2, 3]), tensor([4]), tensor([5])] # Rank 0 + [tensor([10, 11, 12]), tensor([13, 14]), tensor([15, 16]), tensor([17, 18])] # Rank 1 + [tensor([20, 21]), tensor([22]), tensor([23]), tensor([24])] # Rank 2 + [tensor([30, 31]), tensor([32, 33]), tensor([34, 35]), tensor([36])] # Rank 3 + >>> output = ... + >>> dist.all_to_all(output, input) + >>> output + [tensor([0, 1]), tensor([10, 11, 12]), tensor([20, 21]), tensor([30, 31])] # Rank 0 + [tensor([2, 3]), tensor([13, 14]), tensor([22]), tensor([32, 33])] # Rank 1 + [tensor([4]), tensor([15, 16]), tensor([23]), tensor([34, 35])] # Rank 2 + [tensor([5]), tensor([17, 18]), tensor([24]), tensor([36])] # Rank 3 + + >>> # Another example with tensors of core.cfloat type. + >>> input = core.tensor([1+1j, 2+2j, 3+3j, 4+4j], dtype=core.cfloat) + 4 * rank * (1+1j) + >>> input = list(input.chunk(4)) + >>> input + [tensor([1+1j]), tensor([2+2j]), tensor([3+3j]), tensor([4+4j])] # Rank 0 + [tensor([5+5j]), tensor([6+6j]), tensor([7+7j]), tensor([8+8j])] # Rank 1 + [tensor([9+9j]), tensor([10+10j]), tensor([11+11j]), tensor([12+12j])] # Rank 2 + [tensor([13+13j]), tensor([14+14j]), tensor([15+15j]), tensor([16+16j])] # Rank 3 + >>> output = list(core.empty([4], dtype=core.int64).chunk(4)) + >>> dist.all_to_all(output, input) + >>> output + [tensor([1+1j]), tensor([5+5j]), tensor([9+9j]), tensor([13+13j])] # Rank 0 + [tensor([2+2j]), tensor([6+6j]), tensor([10+10j]), tensor([14+14j])] # Rank 1 + [tensor([3+3j]), tensor([7+7j]), tensor([11+11j]), tensor([15+15j])] # Rank 2 + [tensor([4+4j]), tensor([8+8j]), tensor([12+12j]), tensor([16+16j])] # Rank 3 + + """ + if _rank_not_in_group(group): + _warn_not_in_group("all_to_all") + return + + opts = AllToAllOptions() + _check_tensor_list(output_tensor_list, "output_tensor_list") + _check_tensor_list(input_tensor_list, "input_tensor_list") + _ensure_all_tensors_same_dtype(output_tensor_list, input_tensor_list) + + input_tensor_list = [ + t if not t.is_complex() else core.view_as_real(t) for t in input_tensor_list + ] + output_tensor_list = [ + t if not t.is_complex() else core.view_as_real(t) for t in output_tensor_list + ] + + group = group or _get_default_group() + work = group.alltoall(output_tensor_list, input_tensor_list, opts) + + if async_op: + return work + else: + work.wait() + + +@_exception_logger +def barrier( + group: Optional[ProcessGroup] = GroupMember.WORLD, async_op=False, device_ids=None +): + """ + Synchronize all processes. + + This collective blocks processes until the whole group enters this function, + if async_op is False, or if async work handle is called on wait(). + + Args: + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op + device_ids ([int], optional): List of device/GPU ids. + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + .. note:: `ProcessGroupNCCL` now blocks the cpu thread till the completion of the barrier collective. + """ + if _rank_not_in_group(group): + _warn_not_in_group("barrier") + return + + opts = BarrierOptions() + # opts.device = core.device(_get_object_coll_device(group)) + if device_ids is not None: + if isinstance(device_ids, list): + opts.device_ids = device_ids + else: + raise TypeError( + "Invalid function argument: device_ids type should be List[int]" + ) + + group = group or _get_default_group() + work = group.barrier(opts=opts) + + # wait for new op + if async_op: + return work + else: + work.wait() + + +def monitored_barrier( + group: Optional[ProcessGroup] = GroupMember.WORLD, + timeout=None, + wait_all_ranks=False, +): + """ + Synchronize processes similar to ``core.distributed.barrier``, but consider a configurable timeout. + + It is able to report ranks that did not pass this barrier within the provided timeout. + Specifically, for non-zero ranks, will block until a send/recv is processed from rank 0. + Rank 0 will block until all send /recv from other ranks are processed, and will report + failures for ranks that failed to respond in time. Note that if one rank does not reach the + monitored_barrier (for example due to a hang), all other ranks would fail in monitored_barrier. + + This collective will block all processes/ranks in the group, until the + whole group exits the function successfully, making it useful for debugging + and synchronizing. However, it can have a performance impact and should only + be used for debugging or scenarios that require full synchronization points + on the host-side. For debugging purposes, this barrier can be inserted + before the application's collective calls to check if any ranks are + desynchronized. + + .. note:: Note that this collective is only supported with the GLOO backend. + + Args: + group (ProcessGroup, optional): The process group to work on. If + ``None``, the default process group will be used. + timeout (datetime.timedelta, optional): Timeout for monitored_barrier. + If ``None``, the default process group timeout will be used. + wait_all_ranks (bool, optional): Whether to collect all failed ranks or + not. By default, this is ``False`` and ``monitored_barrier`` on rank 0 + will throw on the first failed rank it encounters in order to fail + fast. By setting ``wait_all_ranks=True`` ``monitored_barrier`` will + collect all failed ranks and throw an error containing information + about all failed ranks. + + Returns: + ``None``. + + Example:: + >>> # xdoctest: +SKIP("need process group init") + >>> # Note: Process group initialization omitted on each rank. + >>> from mindnlp import core.distributed as dist + >>> if dist.get_rank() != 1: + >>> dist.monitored_barrier() # Raises exception indicating that + >>> # rank 1 did not call into monitored_barrier. + >>> # Example with wait_all_ranks=True + >>> if dist.get_rank() == 0: + >>> dist.monitored_barrier(wait_all_ranks=True) # Raises exception + >>> # indicating that ranks 1, 2, ... world_size - 1 did not call into + >>> # monitored_barrier. + """ + # Need to call rank not in group before using the group, otherwise + # "Invalid process group" error is raised. + if _rank_not_in_group(group): + _warn_not_in_group("monitored_barrier") + return + + if get_backend(group) != Backend.GLOO: + raise ValueError("monitored_barrier is only implemented for GLOO backend.") + + if timeout is None: + timeout = _get_default_timeout(get_backend(group)) + elif isinstance(timeout, float): + # TODO(whc) aparently some existing test case for monitored_barrier passes in a timeout in float format? + warnings.warn( + "Please specify timeout arg as a timedelta. " + f"Converting current value of {timeout} assuming it represents seconds", + ) + timeout = timedelta(seconds=timeout) + + _check_valid_timeout(timeout) + + group_to_use = _get_default_group() if group is None else group + return group_to_use.monitored_barrier( # type:ignore[attr-defined] + timeout, wait_all_ranks=wait_all_ranks + ) + + +def _create_process_group_wrapper( + wrapped_pg: mindspore.communication._comm_helper.Backend, + store_prefix: str, + store: Store, + rank: int, + world_size: int, + timeout: timedelta = default_pg_timeout, +): + assert _GLOO_AVAILABLE, "ProcessGroupWrapper unsupported without GLOO backend." + + # (whc) this appears to be just for the gloo backend? if so, `default_pg_timeout` is appropriate... + + # Create a separate prefix store for the helper process group. + prefix = f"{PG_WRAPPER_STORE_PREFIX}:{store_prefix}" + store = PrefixStore(prefix, store) + helper_pg = ProcessGroupGloo(store, rank, world_size, timeout=timeout) + # Wrap the underlying pg with ProcessGroupWrapper. + wrapped_pg = _ProcessGroupWrapper(wrapped_pg, helper_pg) + return wrapped_pg + + +# helper function for deterministically hashing a list of ranks to a unique +# string +def _hash_ranks_to_str(ranks: List[int]) -> str: + rank_join: str = "_".join(map(str, ranks)) + # In case there is already a PG with the same rank composition + unique_str = "_".join([rank_join, str(len(_world.pg_names))]) + return hashlib.sha1(bytes(unique_str, "utf-8")).hexdigest() + + +# Takes a list of ranks and computes an integer color +def _process_group_color(ranks: List[int]) -> int: + # Convert list to tuple to make it hashable + ranks = tuple(ranks) + hash_value = hash(ranks) + # Split color must be: + # - a non-negative integer; + # - a type compatible with C's int because we are pybinding to the latter. + # Thus, we limit the hash value within c_int's max value. + max_c_int = 2 ** (ctypes.sizeof(ctypes.c_int) * 8 - 1) + color = abs(hash_value) % max_c_int + return color + + +def _process_group_name(ranks, use_hashed_name): + # Create name for a process group. + global _world + if use_hashed_name: + pg_name = _hash_ranks_to_str(ranks) + else: + pg_name = str(_world.group_count) + _world.group_count += 1 + # TODO: why is group count incremented only in the else path? + return pg_name + + +def _get_backend_from_str(backend: Optional[str] = None) -> Backend: + # Default to the same backend as the global process group + # if backend is not specified. + if not backend: + backend = get_backend(_get_default_group()) + return Backend(backend) + + +def _is_safe_to_split() -> bool: + """ + Checks if it is safe to split the any process group in the world. + This is only safe if the default pg has a bound device id, otherwise + users must be aware that a pg is only splittable after the first collective is + issued. + """ + return False if _get_default_group().bound_device_id is None else True + + +def split_group( + parent_pg: Optional[ProcessGroup] = None, + split_ranks: Optional[list] = None, + timeout: Optional[timedelta] = None, + pg_options: Optional[Any] = None, + group_desc: Optional[str] = None, +) -> Optional[ProcessGroup]: + """ + Create a new process group splitted from the given parent process group. + + warning:: This is an experimental API and only the ``NCCL`` backend supports this API. + Other backends will raise an error. + Users of this API must gurantee that all ranks in the parent group enter this API call, + and the split of the sub groups is the same accross all ranks in the parent group. + + Args: + parent_pg (ProcessGroup, optional): The parent process group. If None, + the default process group will be used. Users need to gurantee that + the parent group is fully initialized (e.g, communicators are initialized) + split_ranks (list[list[int]]): the split ranks, which is a list of list of ranks. + Users need to make sure the validity of the split ranks such that one + split (represented by one inner list of ints) does not overlap with any other split. + Note that the ranks in each split is the group rank (instead of global rank) + in the parent pg. For example, if the parent group has 4 ranks, and split_ranks can be + [[0, 1], [2, 3]]. Note [[0,1]] is also a valid split, in which case ranks 2, 3 would + return a non-group member. + timeout (timedelta, optional): see `init_process_group` for details and default value. + pg_options (ProcessGroupOptions, optional): only ProcessGroupNCCLOptions is supported now. + specifying what additional options need to be passed in during + the construction of specific process groups. i.e.``is_high_priority_stream`` + can be specified so that process group can pick up high priority cuda streams. + For other availble options to config nccl, + See https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t + group_desc (str, optional): a string to describe the process group. + + Returns: + ProcessGroup if the current rank is within one split/subgroup given by split_ranks, + or None if the current rank is not part of any split_ranks`. + + """ + # check inputs + if split_ranks is None: + raise ValueError("split_ranks cannot be None") + + global _world + default_pg = _get_default_group() + device_id = default_pg.bound_device_id + if not device_id: + raise RuntimeError( + "No device associated with the default pg, not safe to split any process groups" + ) + _default_backend, default_store = _world.pg_map[default_pg] + global_rank = default_pg.rank() + global_world_size = default_pg.size() + + if not parent_pg: + parent_pg = default_pg + if parent_pg not in _world.pg_group_ranks: + raise ValueError(f"Group {parent_pg} is not registered") + + parent_global_to_group_ranks = _world.pg_group_ranks[parent_pg] + parent_group_to_global_ranks = { + group_rank: global_rank + for global_rank, group_rank in parent_global_to_group_ranks.items() + } + + if global_rank not in parent_global_to_group_ranks: + raise ValueError( + f"Global rank {global_rank} is not part of the parent group {parent_pg}" + ) + + parent_group_rank = parent_global_to_group_ranks[global_rank] + parent_backend = parent_pg._get_backend(core.device("cuda")) + + # if the parent backend does not support splitting, raise error + # currently this API only support NCCL backend + if ( + not parent_backend + or not parent_backend.supports_splitting + or not isinstance(parent_backend, ProcessGroupNCCL) + ): + raise RuntimeError( + "No backend for the parent process group or its backend does not support splitting" + ) + + # set the group_desc before the color or no_cloor split + group_desc = ( + f"{parent_pg.group_desc}:split:{parent_backend.comm_split_count()}" + if group_desc is None + else group_desc + ) + + parent_backend_str, _ = _world.pg_map[parent_pg] + # same type of backend as the parent process group + backend = Backend(parent_backend_str) + backend_config = BackendConfig(backend) + + if pg_options is not None: + assert isinstance( + pg_options, ProcessGroupNCCL.Options + ), "Expected pg_options argument to be of type ProcessGroupNCCL.Options" + else: + # default pg_options same as the parent process group + pg_options = parent_backend.options + + # this timeout defaulting/validation is used for all the new_groups/new_subgroups variants, + # which may just pass their timeout value (or None) + if timeout is None: + timeout = _get_default_timeout(backend) + _check_valid_timeout(timeout) + + # find my group of ranks and my group local rank in split_ranks + my_group = None + group_rank = -1 + + for split_group in split_ranks: + if len(split_group) == 0: + raise ValueError("the split group cannot be empty") + if len(split_group) > global_world_size: + raise ValueError( + "the split group's size should be less or equal to the world_size set by init_process_group" + ) + if len(split_group) != len(set(split_group)): + raise ValueError("the split group cannot have duplicate ranks") + split_group = sorted(split_group) + if parent_group_rank in split_group: + my_group = split_group + group_rank = split_group.index(parent_group_rank) + break + # if my rank does not belong to any sub group, + # no_color split should be called + if my_group is None or group_rank == -1: + parent_backend.perform_nocolor_split(device_id) + return None + + group_name = _process_group_name(my_group, use_hashed_name=False) + global_ranks_in_my_group = [parent_group_to_global_ranks[rank] for rank in my_group] + + prefix_store = PrefixStore(f"{group_name}/", default_store) + # We register the backend after initializing and timeout is set in pg_options. + pg: ProcessGroup = ProcessGroup( + prefix_store, + group_rank, + len(my_group), + ) + backend_type = ProcessGroup.BackendType.NCCL + pg.bound_device_id = device_id + pg._set_default_backend(backend_type) + + pg_options._timeout = timeout + pg_options.split_from = parent_backend + pg_options.split_color = _process_group_color(my_group) + pg_options.global_ranks_in_group = global_ranks_in_my_group + pg_options.group_name = group_name + backend_class = ProcessGroupNCCL( + prefix_store, group_rank, len(my_group), pg_options + ) + backend_class._set_sequence_number_for_group() + + pg._register_backend(core.device("cuda"), backend_type, backend_class) + + # set group_name and group_desc to backend + assert group_name is not None + assert group_desc is not None + pg._set_group_name(group_name) + pg._set_group_desc(group_desc) + + # always eagerly initialize the backend in split_group + eager_backend = pg._get_backend(device_id) + eager_backend.eager_connect_single_device(device_id) + + # update global state + _world.pg_map[pg] = (backend, prefix_store) + _world.pg_names[pg] = group_name + _register_process_group(group_name, pg) + _world.pg_backend_config[pg] = str(backend_config) + pg_tag = f"ptd:{group_name}" + _world.tags_to_pg.setdefault(pg_tag, []).append(pg) + _world.pg_to_tag[pg] = pg_tag + + # Create the global rank to group rank mapping + _world.pg_group_ranks[pg] = { + global_rank: group_rank + for group_rank, global_rank in enumerate(global_ranks_in_my_group) + } + + return pg + + +def new_group( + ranks=None, + timeout=None, + backend=None, + pg_options=None, + use_local_synchronization=False, + group_desc=None, + device_id: Optional[core.device] = None, +): + """ + Create a new distributed group. + + This function requires that all processes in the main group (i.e. all + processes that are part of the distributed job) enter this function, even + if they are not going to be members of the group. Additionally, groups + should be created in the same order in all processes. + + .. warning:: + Safe concurrent usage: + When using multiple process groups with the ``NCCL`` backend, the user + must ensure a globally consistent execution order of collectives across + ranks. + + If multiple threads within a process issue collectives, explicit + synchronization is necessary to ensure consistent ordering. + + When using async variants of core.distributed communication APIs, + a work object is returned and the communication kernel is + enqueued on a separate CUDA stream, allowing overlap of communication + and computation. Once one or more async ops have been issued on one process + group, they must be synchronized with other cuda streams by calling `work.wait()` + before using another process group. + + See `Using multiple NCCL communicators concurrently `_ for more details. + + Args: + ranks (list[int]): List of ranks of group members. If ``None``, will be + set to all ranks. Default is ``None``. + timeout (timedelta, optional): see `init_process_group` for details and default value. + backend (str or Backend, optional): The backend to use. Depending on + build-time configurations, valid values are ``gloo`` and ``nccl``. + By default uses the same backend as the global group. This field + should be given as a lowercase string (e.g., ``"gloo"``), which can + also be accessed via :class:`Backend` attributes (e.g., + ``Backend.GLOO``). If ``None`` is passed in, the backend + corresponding to the default process group will be used. Default is + ``None``. + pg_options (ProcessGroupOptions, optional): process group options + specifying what additional options need to be passed in during + the construction of specific process groups. i.e. for the ``nccl`` + backend, ``is_high_priority_stream`` can be specified so that + process group can pick up high priority cuda streams. For other availble options to config nccl, + See https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t + use_local_synchronization (bool, optional): perform a group-local + barrier at the end of the process group creation. This is different + in that non-member ranks don't need to call into API and don't + join the barrier. + group_desc (str, optional): a string to describe the process group. + device_id (core.device, optional): a single, specific device + to "bind" this process to, The `new_group` call will try to initialize + a communication backend immediately for the device if this field is given. + + Returns: + A handle of distributed group that can be given to collective calls or + GroupMember.NON_GROUP_MEMBER if the rank is not part of ``ranks``. + + N.B. use_local_synchronization doesn't work with MPI. + + N.B. While use_local_synchronization=True can be significantly faster with larger + clusters and small process groups, care must be taken since it changes cluster behavior + as non-member ranks don't join the group barrier(). + + N.B. use_local_synchronization=True can lead to deadlocks when each rank creates + multiple overlaping process groups. To avoid that, make sure all ranks follow the + same global creation order. + """ + return _new_group_with_tag( + ranks, + timeout, + backend, + pg_options, + None, + use_local_synchronization=use_local_synchronization, + group_desc=group_desc, + device_id=device_id, + ) + + +def _new_group_with_tag( + ranks=None, + timeout=None, + backend=None, + backend_options=None, + pg_tag=None, + use_local_synchronization=False, + group_desc=None, + device_id: Optional[core.device] = None, +): + """ + Variant of ``new_group`` that exposes tag creation. + + :: N.B. The mechanism is experimental and tied to the functional collectives effort, see + ``core.distributed._functional_collectives`` for reference on how to use it. + """ + global _world + + default_pg = _get_default_group() + # if device_id is None: + # device_id = default_pg.bound_device_id + # elif default_pg.bound_device_id is not None: + # assert ( + # device_id == default_pg.bound_device_id + # ), "Mismatched bound device between new pg and the default pg." + default_backend, default_store = _world.pg_map[default_pg] + global_rank = default_pg.rank() + global_world_size = default_pg.size() + + # Default to the same backend as the global process group + # if the backend is not specified. + if not backend: + backend = default_backend + backend = Backend(backend) + + # this timeout defaulting/validation is used for all the new_groups/new_subgroups variants, + # which may just pass their timeout value (or None) + if timeout is None: + timeout = _get_default_timeout(backend) + _check_valid_timeout(timeout) + + if use_local_synchronization: + # MPI backend doesn't have have a way for us to perform a partial sync + if backend == Backend.MPI: + raise ValueError( + "MPI backend doesn't support use_local_synchronization=True" + ) + if ranks is not None and get_rank() not in ranks: + return None + + # checks the input ranks + if ranks is not None: + ranks = sorted(ranks) + group_world_size = len(ranks) + if group_world_size > global_world_size: + raise ValueError( + "the new group's world size should be less or " + "equal to the world size set by " + "init_process_group" + ) + # check ranks' sanity + for rank in ranks: + if rank < 0 or rank >= global_world_size: + raise ValueError( + "The new group's rank should be within " + "the world_size set by init_process_group" + ) + if global_rank in ranks: + group_rank = ranks.index(global_rank) + else: + group_rank = None + else: + ranks = list(range(global_world_size)) + group_world_size = global_world_size + group_rank = global_rank + + group_name = _process_group_name(ranks, use_hashed_name=use_local_synchronization) + + pg, pg_store = _new_process_group_helper( + group_world_size, + group_rank, + ranks, + backend, + default_store, + group_name, + backend_options=backend_options, + timeout=timeout, + pg_tag=pg_tag, + device_id=device_id, + group_desc=group_desc, + ) + + # Create the global rank to group rank mapping + _world.pg_group_ranks[pg] = { + global_rank: group_rank for group_rank, global_rank in enumerate(ranks) + } + + if _is_barrier_after_init() == 1: + # barrier at the end to ensure that once we return from this method, all + # process groups including global variables (if any) are updated + # correctly on all ranks. + # Update 04/2023: for large-scale runs, this barrier (esp. store-based + # barrier) may be costly and/or unscalable. Also, in a lot of cases, + # these barriers may be unnecessary, as proven by a green CI after + # removal. An environment variable `TORCH_DIST_INIT_BARRIER` has been + # added which enables this barrier only when set to 1. + logger.info( + "Performing barrier after ProcessGroup initialization since " + "TORCH_DIST_INIT_BARRIER = 1" + ) + if backend == Backend.MPI: + # MPI doesn't have store. + barrier() + else: + barrier_store = pg_store if use_local_synchronization else default_store + world_size = len(ranks) if use_local_synchronization else get_world_size() + # Use store based barrier here since barrier() used a bunch of + # default devices and messes up NCCL internal state. + _store_based_barrier( + global_rank, barrier_store, group_name, world_size, timeout + ) + + return pg + + +def new_subgroups( + group_size=None, + group=None, + timeout=None, + backend=None, + pg_options=None, + group_desc=None, +): + """ + Create subgroups of equal size. + + By default, it creates intra-machine subgroups, + where each of which contains all the ranks of a machine, based on the assumption + that each machine has the same number of devices. + + This is a convenience API that calls ``new_group`` to generate multiple subgroups. + It requires that all processes in the main group (i.e. all + processes that are part of the distributed job) enter this function, even + if they are not going to be members of the group. + + .. warning:: + If ``group_size`` is passed in, the world size must be divisible by ``group_size``. + If no ``group_size`` is passed in, it believe that you are creating a group based + on CUDA and determining the group size by number of CUDA devices, and if not all + the machines have the same number of devices, the subgroup division will be + different across nodes and can cause unexpected behaviors. Therefore, if you are + creating a subgroup that does not depend on CUDA (such as Gloo on CPU), please + pass in ``group_size`` correctly. + + .. warning:: + See warning `Safe concurrent usage` for `new_group` API for important details about + using multiple process groups concurrently in a safe manner. + + Args: + group_size (int, optional): The size of each subgroup. If ``None``, + the default subgroup size is equal to the number of devices on each machine, + based on the assumption that each machine has exactly the same + number of devices. Default is ``None``. + timeout (timedelta, optional): see `init_process_group` for details and default value. + backend (str or Backend, optional): The backend to use. Depending on + build-time configurations, valid values are ``gloo`` and ``nccl``. + By default uses the same backend as the global group. This field + should be given as a lowercase string (e.g., ``"gloo"``), which can + also be accessed via :class:`Backend` attributes (e.g., + ``Backend.GLOO``). If ``None`` is passed in, the backend + corresponding to the default process group will be used. Default is + ``None``. + pg_options (ProcessGroupOptions, optional): process group options + specifying what additional options need to be passed in during + the construction of specific process groups. i.e. for the ``nccl`` + backend, ``is_high_priority_stream`` can be specified so that + process group can pick up high priority cuda streams. + group_desc (str, optional): A string describing the group. Each subgroup will + inherit its group_desc + + Returns: + The subgroup containing the current rank, and all the subgroups used for cleanup. + + Examples: + >>> # Create intra-machine subgroups. + >>> # xdoctest: +SKIP("need process group init") + >>> cur_subgroup, subgroups = dist.new_subgroups() + >>> # Allreduce within the machine. + >>> rank = dist.get_rank() + >>> tensor = core.ones(1, device=rank) * rank + >>> dist.all_reduce(tensor, group=cur_subgroup) + >>> tensor + tensor([28]) # Assume 8 CUDA devices per machine. 28 is sum(range(8)). + >>> # Cleanup. + >>> for subgroup in subgroups: + >>> dist.destroy_process_group(subgroup) + """ + if group_size is None: + if not core.cuda.is_available(): + raise ValueError( + "Default group size only takes effect when CUDA is available." + "If your subgroup using a backend that does not depend on CUDA," + "please pass in 'group_size' correctly." + ) + group_size = core.cuda.device_count() + if group_size <= 0: + raise ValueError(f"The arg 'group_size' ({group_size}) must be positive") + + world_size = get_world_size() + if world_size < group_size: + raise ValueError( + f"The arg 'group_size' ({group_size}) must not exceed the world size ({world_size})" + ) + if world_size % group_size != 0: + raise ValueError("The world size must be divisible by 'group_size'") + + subgroups = [] + cur_subgroup = None + + for subgroup_id in range(world_size // group_size): + start_rank = subgroup_id * group_size + end_rank = start_rank + group_size + ranks_in_subgroup = list(range(start_rank, end_rank)) + subgroup = new_group( + ranks=ranks_in_subgroup, + timeout=timeout, + backend=backend, + pg_options=pg_options, + group_desc=group_desc, + ) + subgroups.append(subgroup) + + rank = get_rank() + if rank in ranks_in_subgroup: + cur_subgroup = subgroup + logger.info("Rank %s is assigned to subgroup %s", rank, ranks_in_subgroup) + + return cur_subgroup, subgroups + + +def new_subgroups_by_enumeration( + ranks_per_subgroup_list, + timeout=None, + backend=None, + pg_options=None, + group_desc=None, +): + """ + Create subgroups by dividing the global world. + + The division is specified by a nested list of ranks. The subgroups cannot have + overlap, and some ranks may not have to be in any subgroup. + + This is a convenience API that calls ``new_group`` to generate multiple subgroups. + It requires that all processes in the main group (i.e. all + processes that are part of the distributed job) enter this function, even + if they are not going to be members of the group. + + .. warning:: + See warning `Safe concurrent usage` for `new_group` API for important details about + using multiple process groups concurrently in a safe manner. + + Args: + ranks_per_subgroup_list (list[list[int]]): A nested list of ranks of + group members. + timeout (timedelta, optional): see `init_process_group` for details and default value. + backend (str or Backend, optional): The backend to use. Depending on + build-time configurations, valid values are ``gloo`` and ``nccl``. + By default uses the same backend as the global group. This field + should be given as a lowercase string (e.g., ``"gloo"``), which can + also be accessed via :class:`Backend` attributes (e.g., + ``Backend.GLOO``). If ``None`` is passed in, the backend + corresponding to the default process group will be used. Default is + ``None``. + pg_options (ProcessGroupOptions, optional): process group options + specifying what additional options need to be passed in during + the construction of specific process groups. i.e. for the ``nccl`` + backend, ``is_high_priority_stream`` can be specified so that + process group can pick up high priority cuda streams. + group_desc (str, optional): A string describing the group. Each subgroup will + inherit its group_desc. + + Returns: + The subgroup containing the current rank, and all the subgroups used for cleanup. + + Examples: + >>> # Create two subgroups, where each has 2 processes. + >>> # xdoctest: +SKIP("need process group init") + >>> cur_subgroup, subgroups = dist.new_subgroups(ranks=[[0, 2], [1, 3]]) + >>> rank = dist.get_rank() + >>> tensor = core.ones(1, device=rank) * rank + >>> dist.all_reduce(tensor, group=cur_subgroup) + >>> tensor + tensor([2]) # Subgroup 0: ranks 0 and 2 + tensor([4]) # Subgroup 1: ranks 1 and 3 + """ + if ranks_per_subgroup_list is None or len(ranks_per_subgroup_list) == 0: + raise ValueError("The arg 'ranks_per_subgroup_list' cannot be empty") + + subgroups = [] + cur_subgroup = None + # Create a mapping from rank to subgroup to check if there is any subgroup overlap. + rank_to_ranks_dict = {} # type: ignore[var-annotated] + for ranks in ranks_per_subgroup_list: + subgroup = new_group( + ranks=ranks, + timeout=timeout, + backend=backend, + pg_options=pg_options, + group_desc=group_desc, + ) + subgroups.append(subgroup) + my_rank = get_rank() + for rank in ranks: + if rank in rank_to_ranks_dict: + raise ValueError( + f"Rank {rank} has appeared in both subgroup {rank_to_ranks_dict[rank]} and {ranks}" + ) + rank_to_ranks_dict[rank] = ranks + if my_rank == rank: + cur_subgroup = subgroup + logger.info("Rank %s is assigned to subgroup %s", rank, ranks) + + return cur_subgroup, subgroups + + +def _find_pg_by_ranks_and_tag(tag: str, ranks: List[int]) -> Optional[ProcessGroup]: + if len(tag) > 0 and not tag.startswith("ptd:") and not tag.startswith("user:"): + tag = f"user:{tag}" + + for group in _world.tags_to_pg.get(tag, []): + if group.size() != len(ranks): + continue + + group_ranks = get_process_group_ranks(group) + good = all(r in group_ranks for r in ranks) + if good: + return group + return None + + +def _find_or_create_pg_by_ranks_and_tag( + tag: str, ranks: List[int], stride: int +) -> ProcessGroup: + assert ( + len(ranks) % stride == 0 + ), f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})" + + my_rank = get_rank() + my_ranks = None + + if stride == len(ranks): + my_ranks = ranks.copy() + assert my_rank in my_ranks, "rankset doesn't include the current node" + else: + for i in range(0, len(ranks), stride): + rank_set = ranks[i : i + stride] + if my_rank in rank_set: + my_ranks = rank_set + assert my_ranks is not None, "rankset doesn't include the current node" + + my_ranks = sorted(my_ranks) + + pg = _find_pg_by_ranks_and_tag(tag, my_ranks) + if pg is not None: + return pg + if tag == "": + raise ValueError("Cannot automatically create PG with empty tag") + # TODO copy settings and timeout from default PG + return _new_group_with_tag(my_ranks, pg_tag=tag) + + +def _get_group_tag(pg: ProcessGroup) -> str: + """Return the tag associated with ``pg``.""" + tag = _world.pg_to_tag[pg] + if tag.startswith("user:"): + tag = tag[5:] + return tag + + +def _get_process_group_name(pg: ProcessGroup) -> str: + return _world.pg_names.get(pg, "None") + + +def _get_process_group_store(pg: ProcessGroup) -> Store: + return _world.pg_map[pg][1] diff --git a/mindnlp/core/distributed/elastic/__init__.py b/mindnlp/core/distributed/elastic/__init__.py new file mode 100644 index 000000000..427e1745c --- /dev/null +++ b/mindnlp/core/distributed/elastic/__init__.py @@ -0,0 +1,77 @@ +#!/usr/bin/env/python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" + +Torchelastic agent and user worker failover contract: + +**TL;DR;**: + +* TE(torchelastic) expects user workers to finish with the 5 minutes drift +* It is better to design DDP app to fail for all workers, rather than a single one. +* TE does not synchronize number of restarts between agents +* TE re-rendezvous does not trigger restart decrease +* When a single agent finishes its job(successfully or not), it will close rendezvous. + If other agents still have workers in progress, they will be terminated. +* Based on above, scale down does not work if at least single agent finishes the job. +* When Scale up is detected by agents, it will not decrease ``max_restarts`` + + +In general TE(torchelastic) can launch arbitrary user code, but there is some +clarifications need to be done around what failover mechanism torchelastic +provides and what failover mechanism it expects from user workers. + +Torchelastic currently supports DDP style applications. That means that +TE expects *ALL* workers finish approximately at the same time. In practice, +it is nearly to impossible to guarantee that all workers in arbitrary +DDP application finish at the time, so TE provides a finalization barrier +that waits for TIMEOUT(5 minutes) for worker finalization. + +**Worker Failure** + +When worker fails, TE will check the number of restarts +available, if there is more than 0 restarts, TE will start a new rendezvous +round and restart the worker process. New rendezvous round will other +TE agents to terminate their workers. + +.. note:: The TE agent does not synchronize restarts between themselves. + When a single agent performs restart, it will trigger a local ``max_restarts`` + decrease, other agent will not decrease their ``max_restarts``. + the user to run the distributed application locally on a dev host. + +A single worker failure can cause the whole cluster to fail: +If a single worker is constantly failing, it will cause the TE agent +``max_restarts`` to go to zero. This will cause an agent to finish its +work and close rendezvous. If there are any other workers on different +agents, they will be terminated. + + +**Re-Rendezvous** + +Re-rendezvous occurs when TE agents detect a new node +trying to joint a cluster. TE will not decrease ``max_restarts``. TE agents +will terminate its workers and start a new rendezvous round. + +Note about DynamicRendezvous(etcd-v2, c10d-experimental): If the rendezvous +has already max_nodes, the new node won't be added to the wait list right +away since there is no need to tear down a rendezvous that is already fully +utilized. The new node will wait until its timeout (600 secs by default) +and periodically check the number of participants. If the number becomes +less than max_nodes, it will be added to the wait list; otherwise, it will time out after 600 secs. + +*Scale up event*. When scale up event happens, torchelastic rendezvous +will detect that there are new nodes trying to join. Torchelastic agent +will stop all workers and perform re-rendezvous. Note: when scale up event +happens, *``max_restarts``* will *not* decrease. + +*Scale down event*. When scale down event happens, rendezvous will not +notify the torchelastic agent about it. If TE agent launched with ``max_restarts=0`` , +it relies on the underlying scheduler to handle job restart. If the ``max_restarts>0`` , +TE agent will terminate workers and start a new rdzv round, which is a *Scale up event*. + +""" diff --git a/mindnlp/core/distributed/elastic/agent/__init__.py b/mindnlp/core/distributed/elastic/agent/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/mindnlp/core/distributed/elastic/agent/server/__init__.py b/mindnlp/core/distributed/elastic/agent/server/__init__.py new file mode 100644 index 000000000..611451667 --- /dev/null +++ b/mindnlp/core/distributed/elastic/agent/server/__init__.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +The elastic agent is the control plane of torchelastic. + +It is a process that launches and manages underlying worker processes. +The agent is responsible for: + +1. Working with distributed torch: the workers are started with all the + necessary information to successfully and trivially call + ``core.distributed.init_process_group()``. + +2. Fault tolerance: monitors workers and upon detecting worker failures + or unhealthiness, tears down all workers and restarts everyone. + +3. Elasticity: Reacts to membership changes and restarts workers with the new + members. + +The simplest agents are deployed per node and works with local processes. +A more advanced agent can launch and manage workers remotely. Agents can +be completely decentralized, making decisions based on the workers it manages. +Or can be coordinated, communicating to other agents (that manage workers +in the same job) to make a collective decision. +""" + +from .api import ( # noqa: F401 + ElasticAgent, + RunResult, + SimpleElasticAgent, + Worker, + WorkerGroup, + WorkerSpec, + WorkerState, +) +from .local_elastic_agent import TORCHELASTIC_ENABLE_FILE_TIMER, TORCHELASTIC_TIMER_FILE diff --git a/mindnlp/core/distributed/elastic/agent/server/api.py b/mindnlp/core/distributed/elastic/agent/server/api.py new file mode 100644 index 000000000..d86f1a287 --- /dev/null +++ b/mindnlp/core/distributed/elastic/agent/server/api.py @@ -0,0 +1,957 @@ +# mypy: ignore-errors + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import abc +import json +import os +import signal +import socket +import time +import traceback +import warnings +from collections import defaultdict +from contextlib import contextmanager +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +from mindnlp import core.distributed.elastic.rendezvous as rdzv +from mindnlp import core.distributed.elastic.utils.store as store_util +from core.distributed.elastic.events import Event, EventSource, record +from core.distributed.elastic.metrics import prof, put_metric +from core.distributed.elastic.multiprocessing import ProcessFailure, SignalException +from core.distributed.elastic.rendezvous import RendezvousGracefulExitError +from core.distributed.elastic.utils.logging import get_logger + + +__all__ = [ + "WorkerSpec", + "Worker", + "WorkerState", + "WorkerGroup", + "RunResult", + "ElasticAgent", + "SimpleElasticAgent", +] +_TERMINAL_STATE_SYNC_ID = "torchelastic/agent/terminal_state" + +DEFAULT_ROLE = "default" +logger = get_logger(__name__) + + +@dataclass +class WorkerSpec: + """Blueprint information about a particular type of worker. + + For a given role, there must only exist a single worker spec. + Worker spec is expected to be homogeneous across all nodes (machine), + that is each node runs the same number of workers for a particular spec. + + Args: + role: user-defined role for the workers with this spec + local_world_size: number local workers to run + fn: (deprecated use entrypoint instead) + entrypoint: worker function or command + args: arguments to pass to ``entrypoint`` + rdzv_handler: handles rdzv for this set of workers + max_restarts: number of max retries for the workers + monitor_interval: monitor status of workers every ``n`` seconds + master_port: fixed port to run the c10d store on rank 0 + if not specified then will chose a random free port + master_addr: fixed master_addr to run the c10d store on rank 0 + if not specified then will chose hostname on agent rank 0 + redirects: redirect std streams to a file, + selectively redirect for a particular + local rank by passing a map + tee: tees the specified std stream(s) to console + file, + selectively tee for a particular local rank by passing a map, + takes precedence over ``redirects`` settings. + + """ + + role: str + local_world_size: int + rdzv_handler: rdzv.RendezvousHandler + fn: Optional[Callable] = None + # TODO @kiuk - make entrypoint a required field + entrypoint: Union[Callable, str, None] = None + args: Tuple = () + max_restarts: int = 3 + monitor_interval: float = 0.1 + master_port: Optional[int] = None + master_addr: Optional[str] = None + local_addr: Optional[str] = None + + def __post_init__(self): + assert self.local_world_size > 0 + assert self.monitor_interval > 0 + + if self.fn: + warnings.warn( + "WorkerSpec.fn will be deprecated," + " please use WorkerSpec.entrypoint instead", + category=DeprecationWarning, + ) + self.entrypoint = self.fn + assert self.entrypoint + + def get_entrypoint_name(self): + """Get the entry point name. + + If the entrypoint is a function (e.g. ``Callable``) returns its ``__qualname__`` + else if the entrypoint is a binary (e.g. ``str``), returns the binary name. + """ + if isinstance(self.entrypoint, str): + return os.path.basename(self.entrypoint) + else: + assert self.entrypoint is not None + return self.entrypoint.__qualname__ + + +class Worker: + """A worker instance. + + Contrast this with ``WorkerSpec`` that represents the specifications of a + worker. A ``Worker`` is created from a ``WorkerSpec``. A ``Worker`` is to + a ``WorkerSpec`` as an object is to a class. + + The ``id`` of the worker is interpreted + by the specific implementation of ``ElasticAgent``. For a local + agent, it could be the ``pid (int)`` of the worker, for a remote + agent it could be encoded as ``host:port (string)``. + + Args: + id (Any): uniquely identifies a worker (interpreted by the agent) + local_rank (int): local rank of the worker + global_rank (int): global rank of the worker + role_rank (int): rank of the worker across all workers that have the same role + world_size (int): number of workers (globally) + role_world_size (int): number of workers that have the same role + """ + + __slots__ = [ + "id", + "local_rank", + "global_rank", + "role_rank", + "world_size", + "role_world_size", + ] + + def __init__( + self, + local_rank: int, + global_rank: int = -1, + role_rank: int = -1, + world_size: int = -1, + role_world_size: int = -1, + ): + # unique identifier for this worker + self.id: Any = None + + # rank of the worker among workers with the same role being monitored + # by the same ``agent`` instance. + self.local_rank: int = local_rank + + # rank of the worker among all the workers across all roles + # across all ``agent`` instances. + # Global rank is not stable between re-rendezvous. + self.global_rank: int = global_rank + + # rank of the worker among all the workers with the same role + # across all ``agent`` instances. + # Role rank is not stable between re-rendezvous. + self.role_rank: int = role_rank + + # total number of workers (globally). Due to elasticity + # the world size may change between re-rendezvous. + self.world_size: int = world_size + + # total number of workers that share the same role. Due to elasticity + # the role world size may change between re-rendezvous. + self.role_world_size: int = role_world_size + + def __str__(self): + return ( + f"local_rank={self.local_rank},global_rank={self.global_rank}" + f",role_rank={self.role_rank},world_size={self.world_size}" + f",role_world_size={self.role_world_size}" + ) + + def __repr__(self): + return str(self) + + +class WorkerState(str, Enum): + """A state of the ``WorkerGroup``. + + Workers in a worker group change state as a unit. If a single worker + in a worker group fails the entire set is considered failed:: + + UNKNOWN - agent lost track of worker group state, unrecoverable + INIT - worker group object created not yet started + HEALTHY - workers running and healthy + UNHEALTHY - workers running and unhealthy + STOPPED - workers stopped (interrupted) by the agent + SUCCEEDED - workers finished running (exit 0) + FAILED - workers failed to successfully finish (exit !0) + + + A worker group starts from an initial ``INIT`` state, + then progresses to ``HEALTHY`` or ``UNHEALTHY`` states, + and finally reaches a terminal ``SUCCEEDED`` or ``FAILED`` state. + + Worker groups can be interrupted and temporarily put into ``STOPPED`` state + by the agent. Workers in ``STOPPED`` state are scheduled to be restarted + in the near future by the agent. Some examples of workers being put into + ``STOPPED`` state are: + + 1. Worker group failure|unhealthy observed + 2. Membership change detected + + When actions (start, stop, rdzv, retry, etc) on worker group fails + and results in the action being partially applied to the worker group + the state will be ``UNKNOWN``. Typically this happens on uncaught/unhandled + exceptions during state change events on the agent. The agent is not + expected to recover worker groups in ``UNKNOWN`` state and is better off + self terminating and allowing the job manager to retry the node. + """ + + UNKNOWN = "UNKNOWN" + INIT = "INIT" + HEALTHY = "HEALTHY" + UNHEALTHY = "UNHEALTHY" + STOPPED = "STOPPED" + SUCCEEDED = "SUCCEEDED" + FAILED = "FAILED" + + @staticmethod + def is_running(state: "WorkerState") -> bool: + """Return the state of the Worker. + + Returns: + True if the worker state represents workers still running + (e.g. that the process exists but not necessarily healthy). + """ + return state in {WorkerState.HEALTHY, WorkerState.UNHEALTHY} + + +class WorkerGroup: + """A set of ``Worker`` instances. + + The class defines a set of ``Worker`` instances for the given ``WorkerSpec`` managed by ``ElasticAgent``. Whether the worker + group contains cross instance workers or not depends on the implementation of the agent. + """ + + __slots__ = [ + "spec", + "workers", + "store", + "group_rank", + "group_world_size", + "state", + "master_addr", + "master_port", + ] + + def __init__(self, spec: WorkerSpec): + self.spec = spec + self.workers = [Worker(local_rank=i) for i in range(self.spec.local_world_size)] + + # assigned after rdzv + self.store = None + self.group_rank = None + self.group_world_size = None + self.master_addr = None + self.master_port = None + + self.state = WorkerState.INIT + + +class _RoleInstanceInfo: + """The class is used by the agent to exchange the information with other agents. + + The information is used to determine the rank of the workers that agent + manages in heterogeneous environments, where different agents can have + different number of workers. + """ + + __slots__ = ["role", "rank", "local_world_size"] + + def __init__(self, role: str, rank: int, local_world_size: int): + r"""Initialize the agent class instance. + + Args: + role (str): user-defined role for the workers with this spec + rank (int): the rank of the agent + local_world_size (int): number of local workers to run + """ + self.role = role + self.rank = rank + self.local_world_size = local_world_size + + def serialize(self) -> bytes: + dict_data = { + "role": self.role, + "rank": self.rank, + "local_world_size": self.local_world_size, + } + return json.dumps(dict_data).encode(encoding="UTF-8") + + @staticmethod + def deserialize(data: bytes): + dict_data = json.loads(data.decode(encoding="UTF-8")) + return _RoleInstanceInfo( + dict_data["role"], dict_data["rank"], dict_data["local_world_size"] + ) + + @staticmethod + def compare(obj1, obj2) -> int: + if obj1.role == obj2.role: + return obj1.rank - obj2.rank + elif obj1.role > obj2.role: + return 1 + else: + return -1 + + @staticmethod + def find_role_boundaries(roles_infos: List, role: str) -> Tuple[int, int]: + start_idx, end_idx = -1, -1 + for idx, role_info in enumerate(roles_infos): + if role_info.role == role: + if start_idx == -1: + start_idx = idx + end_idx = idx + return (start_idx, end_idx) + + +@dataclass +class RunResult: + """Return results of the worker executions. + + Run results follow an "all-or-nothing" policy where the run is successful if and + only if ALL local workers managed by this agent complete successfully. + + If the result is successful (e.g. ``is_failed() = False``) then the ``return_values`` + field contains the outputs (return values) of the workers managed by THIS agent mapped + by their GLOBAL ranks. That is ``result.return_values[0]`` is the return value of + global rank 0. + + .. note:: ``return_values`` are only meaningful for when the worker entrypoint + is a function. Workers specified as a binary entrypoint do not canonically + have a return value and the ``return_values`` field is meaningless and + may be empty. + + If ``is_failed()`` returns ``True`` then the ``failures`` field contains the + failure information, again, mapped by the GLOBAL rank of the worker that failed. + + The keys in ``return_values`` and ``failures`` are mutually exclusive, that is, + a worker's final state can only be one of: succeeded, failed. Workers intentionally + terminated by the agent according to the agent's restart policy, are not represented + in either ``return_values`` nor ``failures``. + """ + + state: WorkerState + return_values: Dict[int, Any] = field(default_factory=dict) + failures: Dict[int, ProcessFailure] = field(default_factory=dict) + + def is_failed(self) -> bool: + return self.state == WorkerState.FAILED + + +def _get_fq_hostname() -> str: + return socket.getfqdn(socket.gethostname()) + + +class ElasticAgent(abc.ABC): + """An agent process responsible for managing one or more worker processes. + + The worker processes are assumed to be regular distributed PyTorch scripts. + When the worker process is created by the agent, the agent provides the + necessary information for the worker processes to properly initialize + a torch process group. + + The exact deployment topology and ratio of agent-to-worker is dependent + on the specific implementation of the agent and the user's job placement + preferences. For instance, to run a distributed training job on GPU with + 8 trainers (one per GPU) one can: + + 1. Use 8 x single GPU instances, place an agent per instance, managing + 1 worker per agent. + 2. Use 4 x double GPU instances, place an agent per instance, managing + 2 workers per agent. + 3. Use 2 x quad GPU instances, place an agent per instance, managing + 4 workers per agent. + 4. Use 1 x 8 GPU instance, place an agent per instance, managing + 8 workers per agent. + + Usage + :: + + group_result = agent.run() + if group_result.is_failed(): + # workers failed + failure = group_result.failures[0] + logger.exception("worker 0 failed with exit code : %s", failure.exit_code) + else: + return group_result.return_values[0] # return rank 0's results + + """ + + @abc.abstractmethod + def run(self, role: str = DEFAULT_ROLE) -> RunResult: + """Run the agent. + + Supports retrying the worker group on failures up to ``max_restarts``. + + Returns: + The result of the execution, containing the return values or + failure details for each worker mapped by the worker's global rank. + + Raises: + Exception - any other failures NOT related to worker process + """ + raise NotImplementedError + + @abc.abstractmethod + def get_worker_group(self, role: str = DEFAULT_ROLE) -> WorkerGroup: + """Return the ``WorkerGroup`` for the given ``role``. + + Note that the worker group is a mutable object and hence in a + multi-threaded/process environment it may change state. + Implementors are encouraged (but not required) to return + a defensive read-only copy. + """ + raise NotImplementedError + + +class SimpleElasticAgent(ElasticAgent): + """An ``ElasticAgent`` that manages one particular type of worker role. + + An ``ElasticAgent`` that manages workers (``WorkerGroup``) for a single ``WorkerSpec`` + such as one particular type of worker role. + """ + + def __init__(self, spec: WorkerSpec, exit_barrier_timeout: float = 300): + self._worker_group = WorkerGroup(spec) + self._remaining_restarts = self._worker_group.spec.max_restarts + self._store = None + self._exit_barrier_timeout = exit_barrier_timeout + self._total_execution_time = 0 + + def get_worker_group(self, role: str = DEFAULT_ROLE) -> WorkerGroup: + return self._worker_group + + @abc.abstractmethod + def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]: + r"""Start ``worker_group.spec.local_world_size`` number of workers. + + This is according to worker spec for the worker group . + Returns a map of ``local_rank`` to worker ``id``. + """ + raise NotImplementedError + + @abc.abstractmethod + def _stop_workers( + self, worker_group: WorkerGroup, is_restart: bool = False + ) -> None: + r"""Stop all workers in the given worker group. + + Implementors must deal with workers in all states defined by + ``WorkerState``. That is, it must gracefully handle stopping + non-existent workers, unhealthy (stuck) workers, etc. + """ + raise NotImplementedError + + @abc.abstractmethod + def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult: + r"""Check on the workers for the ``worker_group``. + + This function also returns the new state of the worker group. + """ + raise NotImplementedError + + @abc.abstractmethod + def _shutdown( + self, death_sig: signal.Signals = signal.SIGTERM, is_restart: bool = False + ) -> None: + """Clean up any resources that were allocated during the agent's work. + + Args: + death_sig: Signal to send to the child process, SIGTERM is default + """ + raise NotImplementedError + + @prof + def _rendezvous(self, worker_group: WorkerGroup) -> None: + r"""Run rendezvous for the workers specified by the worker spec. + + Assigns workers a new global rank and world size. + Updates the rendezvous store for the worker group. + """ + spec = worker_group.spec + + with self.record_duration("RENDEZVOUS"): + rdzv_info = spec.rdzv_handler.next_rendezvous() + store = rdzv_info.store + group_rank = rdzv_info.rank + group_world_size = rdzv_info.world_size + + # master_addr/master_port could be explicitly overriden + # TODO: BC - specific to static rdzv and can be simplifed further + master_addr = spec.master_addr or rdzv_info.bootstrap_store_info.master_addr + master_port = spec.master_port or rdzv_info.bootstrap_store_info.master_port + + self._store = store + + with self.record_duration("ASSIGN_WORKER_RANKS"): + workers = self._assign_worker_ranks( + store, group_rank, group_world_size, spec + ) + worker_group.workers = workers + worker_group.store = store + worker_group.group_rank = group_rank + worker_group.group_world_size = group_world_size + worker_group.master_addr = master_addr + worker_group.master_port = master_port + + restart_count = spec.max_restarts - self._remaining_restarts + + logger.info( + "[%(role)s] Rendezvous complete for workers. Result:\n" + " restart_count=%(restart_count)s\n" + " master_addr=%(master_addr)s\n" + " master_port=%(master_port)s\n" + " group_rank=%(group_rank)s\n" + " group_world_size=%(group_world_size)s\n" + " local_ranks=%(local_ranks)s\n" + " role_ranks=%(role_ranks)s\n" + " global_ranks=%(global_ranks)s\n" + " role_world_sizes=%(role_world_sizes)s\n" + " global_world_sizes=%(global_world_sizes)s\n", + { + "role": spec.role, + "restart_count": restart_count, + "master_addr": master_addr, + "master_port": master_port, + "group_rank": group_rank, + "group_world_size": group_world_size, + "local_ranks": [worker.local_rank for worker in workers], + "role_ranks": [worker.role_rank for worker in workers], + "global_ranks": [worker.global_rank for worker in workers], + "role_world_sizes": [worker.role_world_size for worker in workers], + "global_world_sizes": [worker.world_size for worker in workers], + }, + ) + + # pyre-fixme[56]: Pyre was not able to infer the type of the decorator + # `core.distributed.elastic.metrics.prof`. + @prof + def _assign_worker_ranks( + self, store, group_rank: int, group_world_size: int, spec: WorkerSpec + ) -> List[Worker]: + """Determine proper ranks for worker processes. + + Fast Path: when all workers have the same role and world size. We calculate + the global rank to be group_rank * group_world_size + local_rank. And the + `role_world_size` is the same as `global_world_size`. No TCP store is used in + this case. This is only enabled when users set the environment variable + `TORCH_ELASTIC_WORKER_IDENTICAL` to 1. + + Time complexity: each worker O(1), overall O(1) + + Slow Path: when workers have different roles and world sizes. We use the + the following algorithm: + + 1. Each agent writes its configuration(group_rank, group_world_size + , num_workers) to the common store. + 2. The rank 0 agent reads all the role_info from the store and + determines each agents worker ranks. + 3. Determine the global rank: the global rank of the workers is computed + by cumulative sum of the local_world_size for all workers in front of it. + For efficiency reasons each worker is assigned a base global rank + such that it's workers are in the range [base_global_rank, + base_global_rank + local_world_size). + 4. Determine the role rank: The role rank is determined using the algorithms + in the point 3 with the exception that the ranks are calculated with + respect to the role name. + 5. The rank 0 agent writes the assigned ranks to the store. + 6. Each agent reads the assigned ranks from the store. + + Time complexity: each worker O(1), rank0 O(n), overall O(n) + """ + + if os.environ.get("TORCH_ELASTIC_WORKER_IDENTICAL", "0") == "1": + global_world_size = group_world_size * spec.local_world_size + base_global_rank = group_rank * spec.local_world_size + base_role_rank = base_global_rank + role_world_size = global_world_size + else: + ROLE_INFO_PREFIX = "torchelastic/role_info/" + ASSIGNED_RANKS_PREFIX = "torchelastic/assigned_ranks/" + + agent_role_info = _RoleInstanceInfo( + spec.role, group_rank, spec.local_world_size + ) + store.set(f"{ROLE_INFO_PREFIX}{group_rank}", agent_role_info.serialize()) + + # tcp store is collocated with rank 0 so we can use it to do extra compute to reduce overall # of operations. + if group_rank == 0: + role_infos_bytes = store.multi_get( + [f"torchelastic/role_info/{i}" for i in range(group_world_size)] + ) + role_infos = [ + _RoleInstanceInfo.deserialize(info_bytes) + for info_bytes in role_infos_bytes + ] + + role_sizes = defaultdict(lambda: 0) + global_size = 0 + for role_info in role_infos: + role_sizes[role_info.role] += role_info.local_world_size + global_size += role_info.local_world_size + + base_global_rank = 0 + role_ranks = defaultdict(lambda: 0) + + keys = [] + values = [] + for i, role_info in enumerate(role_infos): + keys.append(f"{ASSIGNED_RANKS_PREFIX}{i}") + values.append( + json.dumps( + [ + base_global_rank, + global_size, + role_ranks[role_info.role], + role_sizes[role_info.role], + ] + ) + ) + + base_global_rank += role_info.local_world_size + role_ranks[role_info.role] += role_info.local_world_size + + store.multi_set(keys, values) + + # get will block until the data is available in the store. + ( + base_global_rank, + global_world_size, + base_role_rank, + role_world_size, + ) = json.loads(store.get(f"{ASSIGNED_RANKS_PREFIX}{group_rank}")) + + workers = [] + for local_rank in range(spec.local_world_size): + worker = Worker( + local_rank=local_rank, + global_rank=base_global_rank + local_rank, + role_rank=base_role_rank + local_rank, + world_size=global_world_size, + role_world_size=role_world_size, + ) + workers.append(worker) + return workers + + # pyre-fixme[56]: Pyre was not able to infer the type of the decorator + # `core.distributed.elastic.metrics.prof`. + @prof + def _initialize_workers(self, worker_group: WorkerGroup) -> None: + r"""Start a fresh set of workers for the worker_group. + + Essentially, a rendezvous followed by a ``start_workers``. + The caller should first call ``_stop_workers()`` to stop running workers + prior to calling this method. + + Optimistically sets the state of the worker group that + just started as ``HEALTHY`` and delegates the actual monitoring + of state to ``_monitor_workers()`` method + """ + role = worker_group.spec.role + logger.info("[%s] Rendezvous'ing worker group", role) + + # TODO after stopping workers, wait at least monitor_interval*2 for + # workers on different nodes to fail on a collective op before waiting + # on the rdzv barrier, this way we ensure that nodes enter rdzv + # at around the same time and reduce false positive rdzv timeout errors + self._rendezvous(worker_group) + + logger.info("[%s] Starting worker group", role) + worker_ids = self._start_workers(worker_group) + for local_rank, w_id in worker_ids.items(): + worker = worker_group.workers[local_rank] + worker.id = w_id + + worker_group.state = WorkerState.HEALTHY + + # pyre-fixme[56]: Pyre was not able to infer the type of the decorator + # `core.distributed.elastic.metrics.prof`. + @prof + def _restart_workers(self, worker_group: WorkerGroup) -> None: + """Restart (stops, rendezvous, starts) all local workers in the group.""" + role = worker_group.spec.role + logger.info("[%s] Stopping worker group", role) + self._stop_workers(worker_group, is_restart=True) + worker_group.state = WorkerState.STOPPED + self._initialize_workers(worker_group) + + # pyre-fixme[56]: Pyre was not able to infer the type of the decorator + # `core.distributed.elastic.metrics.prof`. + @prof + def run(self, role: str = DEFAULT_ROLE) -> RunResult: + start_time = time.monotonic() + shutdown_called: bool = False + try: + result = self._invoke_run(role) + self._total_execution_time = int(time.monotonic() - start_time) + self._record_metrics(result) + self._record_worker_events(result) + return result + except RendezvousGracefulExitError as e: + logger.info("Rendezvous gracefully exited: %s", e) + except SignalException as e: + logger.warning("Received %s death signal, shutting down workers", e.sigval) + self._shutdown(e.sigval) + shutdown_called = True + raise + finally: + if not shutdown_called: + self._shutdown() + # record the execution time in case there were any exceptions during run. + self._total_execution_time = int(time.monotonic() - start_time) + + def get_event_failed(self) -> Event: + return self._construct_event( + state="FAILED", + source=EventSource.AGENT, + raw_error=traceback.format_exc(), + ) + + def get_event_succeeded(self) -> Event: + return self._construct_event( + state="SUCCEEDED", + source=EventSource.AGENT, + ) + + def _record_worker_events(self, result: RunResult) -> None: + for worker in self._worker_group.workers: + failure = result.failures.get(worker.global_rank) + state: str = self._get_worker_state(worker, result) + raw_error = json.dumps(failure.error_file_data) if failure else None + record(self._construct_event(state, EventSource.WORKER, worker, raw_error)) + + def _get_worker_state(self, worker: Worker, result: RunResult) -> str: + failure = result.failures.get(worker.global_rank) + if result.state in {WorkerState.UNHEALTHY, WorkerState.FAILED} and not failure: + # The worker got terminated by the torchelastic agent via SIGTERM signal + return "TERMINATED" + elif failure or worker.global_rank in result.return_values: + return result.state.value + else: + raise ValueError(f"Unknown worker: {worker.global_rank}") + + @contextmanager + def record_duration(self, state: str): + start_time = time.perf_counter() + try: + yield + finally: + end_time = time.perf_counter() + duration_ms = (end_time - start_time) * 1000 + record( + self._construct_event( + state=state, source=EventSource.AGENT, duration_ms=duration_ms + ) + ) + + def _construct_event( + self, + state: str, + source: EventSource, + worker: Optional[Worker] = None, + raw_error: Optional[str] = None, + duration_ms: Optional[float] = None, + ) -> Event: + wg = self._worker_group + spec = wg.spec + md = { + "group_world_size": wg.group_world_size, + "entry_point": spec.get_entrypoint_name(), + } + if worker: + md["local_rank"] = (worker.local_rank,) + md["role_rank"] = (worker.role_rank,) + md["role_world_size"] = (worker.role_world_size,) + global_rank = worker.global_rank + worker_id = str(worker.id) + else: + global_rank = None + worker_id = None + md_str = json.dumps(md) + metadata = { + "run_id": spec.rdzv_handler.get_run_id(), + "global_rank": global_rank, + "group_rank": wg.group_rank, + "worker_id": worker_id, + "role": spec.role, + "hostname": _get_fq_hostname(), + "state": state, + "total_run_time": self._total_execution_time, + "rdzv_backend": spec.rdzv_handler.get_backend(), + "raw_error": raw_error, + "metadata": md_str, + "agent_restarts": spec.max_restarts - self._remaining_restarts, + "duration_ms": duration_ms, + } + return Event( + f"torchelastic.worker.status.{state}", source=source, metadata=metadata + ) + + def _record_metrics(self, group_results: RunResult): + is_failed = group_results.is_failed() + self._record_flakiness_metric(is_failed) + spec = self._worker_group.spec + restarts_happened = self._remaining_restarts != spec.max_restarts + put_metric(f"workers.{spec.role}.run_total", 1) + self._record_metric_with_condition( + "run_success_with_retries", not is_failed and restarts_happened + ) + self._record_metric_with_condition( + "run_success_no_retries", not is_failed and not restarts_happened + ) + self._record_metric_with_condition( + "run_failed_with_retries", is_failed and restarts_happened + ) + self._record_metric_with_condition( + "run_failed_no_retries", is_failed and not restarts_happened + ) + + def _record_metric_with_condition(self, metric_name, condition): + spec = self._worker_group.spec + if condition: + put_metric(f"workers.{spec.role}.{metric_name}", 1) + else: + put_metric(f"workers.{spec.role}.{metric_name}", 0) + + def _record_flakiness_metric(self, is_failed: bool = False): + if is_failed: + flakiness = 100.0 + else: + spec = self._worker_group.spec + flakiness = 100.0 - 100.0 * (self._remaining_restarts + 1) / ( + spec.max_restarts + 1 + ) + spec = self._worker_group.spec + + put_metric(f"workers.{spec.role}.flakiness", int(flakiness)) + + def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult: + # NOTE: currently only works for a single role + + spec = self._worker_group.spec + role = spec.role + + logger.info( + "[%s] starting workers for entrypoint: %s", role, spec.get_entrypoint_name() + ) + + self._initialize_workers(self._worker_group) + monitor_interval = spec.monitor_interval + rdzv_handler = spec.rdzv_handler + + while True: + assert self._worker_group.state != WorkerState.INIT + time.sleep(monitor_interval) + run_result = self._monitor_workers(self._worker_group) + state = run_result.state + self._worker_group.state = state + + put_metric(f"workers.{role}.remaining_restarts", self._remaining_restarts) + put_metric(f"workers.{role}.{state.name.lower()}", 1) + + if state == WorkerState.SUCCEEDED: + logger.info( + "[%s] worker group successfully finished." + " Waiting %s seconds for other agents to finish.", + role, + self._exit_barrier_timeout, + ) + self._exit_barrier() + return run_result + elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}: + if self._remaining_restarts > 0: + logger.info( + "[%s] Worker group %s. " + "%s/%s attempts left;" + " will restart worker group", + role, + state.name, + self._remaining_restarts, + spec.max_restarts, + ) + self._remaining_restarts -= 1 + self._restart_workers(self._worker_group) + else: + self._stop_workers(self._worker_group) + self._worker_group.state = WorkerState.FAILED + return run_result + elif state == WorkerState.HEALTHY: + # membership changes do not count as retries + num_nodes_waiting = rdzv_handler.num_nodes_waiting() + group_rank = self._worker_group.group_rank + if num_nodes_waiting > 0: + logger.info( + "[%s] Detected %s " + "new nodes from group_rank=%s; " + "will restart worker group", + role, + num_nodes_waiting, + group_rank, + ) + self._restart_workers(self._worker_group) + else: + raise Exception( # noqa: TRY002 + f"[{role}] Worker group in {state.name} state" + ) + + def _exit_barrier(self): + """ + Define a barrier that keeps the agent process alive until all workers finish. + + Wait for ``exit_barrier_timeout`` seconds for all agents to finish + executing their local workers (either successfully or not). This + acts as a safety guard against user scripts that terminate at different + times. + """ + logger.info( + "Local worker group finished (%s). " + "Waiting %s seconds for other agents to finish", + self._worker_group.state, + self._exit_barrier_timeout, + ) + start = time.time() + try: + store_util.barrier( + store=self._store, + world_size=self._worker_group.group_world_size, + key_prefix=_TERMINAL_STATE_SYNC_ID, + barrier_timeout=self._exit_barrier_timeout, + ) + logger.info( + "Done waiting for other agents. Elapsed: %s seconds", + time.time() - start, + ) + except SignalException as e: + logger.warning("Got termination signal: %s", e.sigval) + raise + except Exception: + logger.exception( + "Error waiting on exit barrier. Elapsed: %s seconds", + time.time() - start, + ) diff --git a/mindnlp/core/distributed/elastic/agent/server/health_check_server.py b/mindnlp/core/distributed/elastic/agent/server/health_check_server.py new file mode 100644 index 000000000..6b80e855a --- /dev/null +++ b/mindnlp/core/distributed/elastic/agent/server/health_check_server.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable + +from core.distributed.elastic.utils.logging import get_logger + + +log = get_logger(__name__) + +__all__ = ["HealthCheckServer", "create_healthcheck_server"] + + +class HealthCheckServer: + """ + Interface for health check monitoring server, which can be extended + by starting tcp/http server on the specified port. + + Args: + + alive_callback: Callable[[], int], callback to last progress time of agent + + port: int, port number to start tcp/http server + + timeout: int, timeout seconds to decide agent is alive/dead + """ + + _alive_callback: Callable[[], int] + _port: int + _timeout: int + + def __init__( + self, alive_callback: Callable[[], int], port: int, timeout: int + ) -> None: + self._alive_callback = alive_callback + self._port = port + self._timeout = timeout + + def start(self) -> None: + """ + Unsupported functionality for Pytorch, doesn't start any health check server + """ + log.warning("No health check server started") + + def stop(self) -> None: + """ + Function to stop health check server + """ + log.info("Stopping noop health check server.") + + +def create_healthcheck_server( + alive_callback: Callable[[], int], + port: int, + timeout: int, +) -> HealthCheckServer: + """ + creates health check server object + """ + return HealthCheckServer(alive_callback, port, timeout) diff --git a/mindnlp/core/distributed/elastic/agent/server/local_elastic_agent.py b/mindnlp/core/distributed/elastic/agent/server/local_elastic_agent.py new file mode 100644 index 000000000..23773e507 --- /dev/null +++ b/mindnlp/core/distributed/elastic/agent/server/local_elastic_agent.py @@ -0,0 +1,417 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import json +import os +import signal +import socket +import time +import uuid +from string import Template +from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING + +from mindnlp import core.distributed.elastic.timer as timer +from core.distributed.elastic import events +from core.distributed.elastic.agent.server.api import ( + RunResult, + SimpleElasticAgent, + WorkerGroup, + WorkerSpec, + WorkerState, +) +from core.distributed.elastic.agent.server.health_check_server import ( + create_healthcheck_server, + HealthCheckServer, +) +from core.distributed.elastic.metrics.api import prof +from core.distributed.elastic.multiprocessing import ( + LogsSpecs, + PContext, + start_processes, +) +from core.distributed.elastic.utils import macros +from core.distributed.elastic.utils.logging import get_logger + + +if TYPE_CHECKING: + from core.distributed.elastic.events.api import EventMetadataValue + +logger = get_logger(__name__) + +__all__ = [ + "LocalElasticAgent", + "TORCHELASTIC_ENABLE_FILE_TIMER", + "TORCHELASTIC_TIMER_FILE", + "TORCHELASTIC_HEALTH_CHECK_PORT", +] + +TORCHELASTIC_ENABLE_FILE_TIMER = "TORCHELASTIC_ENABLE_FILE_TIMER" +TORCHELASTIC_HEALTH_CHECK_PORT = "TORCHELASTIC_HEALTH_CHECK_PORT" +TORCHELASTIC_TIMER_FILE = "TORCHELASTIC_TIMER_FILE" + + +class LocalElasticAgent(SimpleElasticAgent): + """An implementation of :py:class:`torchelastic.agent.server.ElasticAgent` that handles host-local workers. + + This agent is deployed per host and is configured to spawn ``n`` workers. + When using GPUs, ``n`` maps to the number of GPUs available on the host. + + The local agent does not communicate to other local agents deployed on + other hosts, even if the workers may communicate inter-host. The worker id + is interpreted to be a local process. The agent starts and stops all worker + processes as a single unit. + + + The worker function and argument passed to the worker function must be + python multiprocessing compatible. To pass multiprocessing data structures + to the workers you may create the data structure in the same multiprocessing + context as the specified ``start_method`` and pass it as a function argument. + + The ``exit_barrier_timeout`` specifies the amount of time (in seconds) to wait + for other agents to finish. This acts as a safety net to handle cases where + workers finish at different times, to prevent agents from viewing workers + that finished early as a scale-down event. It is strongly advised that the + user code deal with ensuring that workers are terminated in a synchronous + manner rather than relying on the exit_barrier_timeout. + + A named pipe based watchdog can be enabled in ```LocalElasticAgent``` if an + environment variable ``TORCHELASTIC_ENABLE_FILE_TIMER`` with value 1 has + been defined in the ```LocalElasticAgent``` process. + Optionally, another environment variable ```TORCHELASTIC_TIMER_FILE``` + can be set with a unique file name for the named pipe. If the environment + variable ```TORCHELASTIC_TIMER_FILE``` is not set, ```LocalElasticAgent``` + will internally create a unique file name and set it to the environment + variable ```TORCHELASTIC_TIMER_FILE```, and this environment variable will + be propagated to the worker processes to allow them to connect to the same + named pipe that ```LocalElasticAgent``` uses. + + Logs are written to the specified log directory. Each log line will be by default + prefixed by ``[${role_name}${local_rank}]:`` (e.g. ``[trainer0]: foobar``). + Log prefixes can be customized by passing a `template string + `_ as the + ``log_line_prefix_template`` argument. + The following macros (identifiers) are substituted at runtime: + ``${role_name}, ${local_rank}, ${rank}``. For example, to prefix each log line with + global rank instead of the local rank, set ``log_line_prefix_template = "[${rank}]:``. + + + Example launching function + + :: + + def trainer(args) -> str: + return "do train" + + def main(): + start_method="spawn" + shared_queue= multiprocessing.get_context(start_method).Queue() + spec = WorkerSpec( + role="trainer", + local_world_size=nproc_per_process, + entrypoint=trainer, + args=("foobar",), + ...) + agent = LocalElasticAgent(spec, start_method) + results = agent.run() + + if results.is_failed(): + print("trainer failed") + else: + print(f"rank 0 return value: {results.return_values[0]}") + # prints -> rank 0 return value: do train + + Example launching binary + + :: + + def main(): + spec = WorkerSpec( + role="trainer", + local_world_size=nproc_per_process, + entrypoint="/usr/local/bin/trainer", + args=("--trainer-args", "foobar"), + ...) + agent = LocalElasticAgent(spec) + results = agent.run() + + if not results.is_failed(): + print("binary launches do not have return values") + + """ + + def __init__( + self, + spec: WorkerSpec, + logs_specs: LogsSpecs, + start_method="spawn", + exit_barrier_timeout: float = 300, + log_line_prefix_template: Optional[str] = None, + ): + super().__init__(spec, exit_barrier_timeout) + self._start_method = start_method + self._pcontext: Optional[PContext] = None + self._rdzv_handler = spec.rdzv_handler + self._log_line_prefix_template = log_line_prefix_template + self._worker_watchdog: Optional[timer.FileTimerServer] = None + self._logs_specs = logs_specs + self._health_check_server: Optional[HealthCheckServer] = None + + def _setup_local_watchdog(self, envs: Dict[int, Dict[str, str]]) -> None: + enable_watchdog_env_name = TORCHELASTIC_ENABLE_FILE_TIMER + watchdog_enabled = os.getenv(enable_watchdog_env_name) + watchdog_file_env_name = TORCHELASTIC_TIMER_FILE + watchdog_file_path = os.getenv(watchdog_file_env_name) + if watchdog_enabled is not None and str(watchdog_enabled) == "1": + if watchdog_file_path is None: + watchdog_file_path = "/tmp/watchdog_timer_" + str(uuid.uuid4()) + logger.info("Starting a FileTimerServer with %s ...", watchdog_file_path) + if not envs: + logger.warning( + "Empty envs variables, using empty run_id for FileTimerServer" + ) + run_id = "" + else: + run_id = envs[0]["TORCHELASTIC_RUN_ID"] + self._worker_watchdog = timer.FileTimerServer( + file_path=watchdog_file_path, + run_id=run_id, + max_interval=0.1, + daemon=True, + log_event=self._log_watchdog_event, + ) + self._worker_watchdog.start() + logger.info("FileTimerServer started") + else: + logger.info( + "Environment variable '%s' not found. Do not start FileTimerServer.", + enable_watchdog_env_name, + ) + # Propagate the watchdog file env to worker processes + if watchdog_file_path is not None: + for worker_env in envs.values(): + worker_env[watchdog_file_env_name] = watchdog_file_path + + @staticmethod + def _get_current_time_secs() -> int: + return int(time.time()) + + def _setup_healthcheck(self) -> None: + healthcheck_port_env_name = TORCHELASTIC_HEALTH_CHECK_PORT + healthcheck_port = os.getenv(healthcheck_port_env_name) + if healthcheck_port is not None: + logger.info( + "Found healthcheck port %s: %s", + healthcheck_port_env_name, + healthcheck_port, + ) + if self._worker_watchdog is None: + logger.info( + "FileTimerServer doesn't exist, using current time as dummy callback" + ) + alive_callback = LocalElasticAgent._get_current_time_secs + else: + alive_callback = self._worker_watchdog.get_last_progress_time + + try: + healthcheck_port_as_int = int(healthcheck_port) + self._health_check_server = create_healthcheck_server( + alive_callback=alive_callback, + port=healthcheck_port_as_int, + timeout=60, + ) + self._health_check_server.start() + except ValueError: + logger.info( + "Invalid healthcheck port value: '%s', expecting integer. Not starting healthcheck server.", + healthcheck_port, + ) + else: + logger.info( + "Environment variable '%s' not found. Do not start health check.", + healthcheck_port_env_name, + ) + + def _get_fq_hostname(self) -> str: + return socket.getfqdn(socket.gethostname()) + + def _log_watchdog_event( + self, + name: str, + request: Optional[timer.FileTimerRequest], + ) -> None: + wg = self._worker_group + spec = wg.spec + md = {"watchdog_event": name} + if request is not None: + md["worker_pid"] = str(request.worker_pid) + md["scope_id"] = request.scope_id + md["expiration_time"] = str(request.expiration_time) + md["signal"] = str(request.signal) + md_str = json.dumps(md) + state = "RUNNING" + metadata: Dict[str, EventMetadataValue] = { + "run_id": spec.rdzv_handler.get_run_id(), + "global_rank": None, + "group_rank": wg.group_rank, + "worker_id": None, + "role": spec.role, + "hostname": self._get_fq_hostname(), + "state": state, + "total_run_time": self._total_execution_time, + "rdzv_backend": spec.rdzv_handler.get_backend(), + "raw_error": None, + "metadata": md_str, + "agent_restarts": spec.max_restarts - self._remaining_restarts, + } + # Note: The 'metadata' field of the Event is converted to a TorchelasticStatusLogEntry later. + # The 'name' field of the Event is NOT used in the TorchelasticStatusLogEntry. + event = events.Event( + name=name, source=events.EventSource.AGENT, metadata=metadata + ) + events.record(event) + + # pyre-fixme[56]: Pyre was not able to infer the type of the decorator + # `core.distributed.elastic.metrics.prof`. + @prof + def _stop_workers( + self, worker_group: WorkerGroup, is_restart: bool = False + ) -> None: + self._shutdown(is_restart=is_restart) + + # pyre-fixme[56]: Pyre was not able to infer the type of the decorator + # `core.distributed.elastic.metrics.prof`. + @prof + def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]: + spec = worker_group.spec + store = worker_group.store + assert store is not None + restart_count = spec.max_restarts - self._remaining_restarts + + use_agent_store: bool = spec.rdzv_handler.use_agent_store + logger.info("use_agent_store: %s", use_agent_store) + + args: Dict[int, Tuple] = {} + envs: Dict[int, Dict[str, str]] = {} + log_line_prefixes: Optional[Dict[int, str]] = ( + {} if self._log_line_prefix_template else None + ) + for worker in worker_group.workers: + local_rank = worker.local_rank + worker_env = { + "LOCAL_RANK": str(local_rank), + "RANK": str(worker.global_rank), + "GROUP_RANK": str(worker_group.group_rank), + "ROLE_RANK": str(worker.role_rank), + "ROLE_NAME": spec.role, + "LOCAL_WORLD_SIZE": str(spec.local_world_size), + "WORLD_SIZE": str(worker.world_size), + "GROUP_WORLD_SIZE": str(worker_group.group_world_size), + "ROLE_WORLD_SIZE": str(worker.role_world_size), + "MASTER_ADDR": worker_group.master_addr, + "MASTER_PORT": str(worker_group.master_port), + "TORCHELASTIC_RESTART_COUNT": str(restart_count), + "TORCHELASTIC_MAX_RESTARTS": str(spec.max_restarts), + "TORCHELASTIC_RUN_ID": spec.rdzv_handler.get_run_id(), + "TORCHELASTIC_USE_AGENT_STORE": str(use_agent_store), + "TORCH_NCCL_ASYNC_ERROR_HANDLING": os.getenv( + "TORCH_NCCL_ASYNC_ERROR_HANDLING", str(1) + ), + } + if "OMP_NUM_THREADS" in os.environ: + worker_env["OMP_NUM_THREADS"] = os.environ["OMP_NUM_THREADS"] + + if self._log_line_prefix_template: + log_line_prefix = Template( + self._log_line_prefix_template + ).safe_substitute( + role_name=spec.role, + rank=worker.global_rank, + local_rank=local_rank, + ) + log_line_prefixes[local_rank] = log_line_prefix + + envs[local_rank] = worker_env + worker_args = list(spec.args) + worker_args = macros.substitute(worker_args, str(local_rank)) + args[local_rank] = tuple(worker_args) + + self._setup_local_watchdog(envs=envs) + self._setup_healthcheck() + + assert spec.entrypoint is not None + assert self._logs_specs is not None + self._pcontext = start_processes( + name=spec.role, + entrypoint=spec.entrypoint, + args=args, + envs=envs, + logs_specs=self._logs_specs, + log_line_prefixes=log_line_prefixes, + start_method=self._start_method, + ) + + return self._pcontext.pids() + + def _shutdown( + self, death_sig: signal.Signals = signal.SIGTERM, is_restart: bool = False + ) -> None: + if self._worker_watchdog is not None: + self._worker_watchdog.stop() + self._worker_watchdog = None + if self._health_check_server is not None: + self._health_check_server.stop() + self._health_check_server = None + if self._pcontext: + self._pcontext.close(death_sig) + if not is_restart and self._rdzv_handler: + self._rdzv_handler.shutdown() + + # pyre-fixme[56]: Pyre was not able to infer the type of the decorator + # `core.distributed.elastic.metrics.prof`. + @prof + def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult: + role = worker_group.spec.role + worker_pids = {w.id for w in worker_group.workers} + assert self._pcontext is not None + pc_pids = set(self._pcontext.pids().values()) + if worker_pids != pc_pids: + logger.error( + "[%s] worker pids do not match process_context pids." + " Expected: %s, actual: %s", + role, + worker_pids, + pc_pids, + ) + return RunResult(state=WorkerState.UNKNOWN) + + result = self._pcontext.wait(0) + if result: + if result.is_failed(): + # map local rank failure to global rank + worker_failures = {} + for local_rank, failure in result.failures.items(): + worker = worker_group.workers[local_rank] + worker_failures[worker.global_rank] = failure + return RunResult( + state=WorkerState.FAILED, + failures=worker_failures, + ) + else: + # copy ret_val_queue into a map with a global ranks + workers_ret_vals = {} + for local_rank, ret_val in result.return_values.items(): + worker = worker_group.workers[local_rank] + workers_ret_vals[worker.global_rank] = ret_val + return RunResult( + state=WorkerState.SUCCEEDED, + return_values=workers_ret_vals, + ) + else: + return RunResult(state=WorkerState.HEALTHY) diff --git a/mindnlp/core/distributed/elastic/control_plane.py b/mindnlp/core/distributed/elastic/control_plane.py new file mode 100644 index 000000000..d17410c2d --- /dev/null +++ b/mindnlp/core/distributed/elastic/control_plane.py @@ -0,0 +1,52 @@ +import os +from contextlib import contextmanager, ExitStack +from typing import Generator + +from core.distributed.elastic.multiprocessing.errors import record + + +__all__ = [ + "worker_main", +] + +TORCH_WORKER_SERVER_SOCKET = "TORCH_WORKER_SERVER_SOCKET" + + +@contextmanager +def _worker_server(socket_path: str) -> Generator[None, None, None]: + from core._C._distributed_c10d import _WorkerServer + + server = _WorkerServer(socket_path) + try: + yield + finally: + server.shutdown() + + +@contextmanager +@record +def worker_main() -> Generator[None, None, None]: + """ + This is a context manager that wraps your main entry function. This combines + the existing ``errors.record`` logic as well as a new ``_WorkerServer`` that + exposes handlers via a unix socket specified by + ``Torch_WORKER_SERVER_SOCKET``. + + Example + + :: + + @worker_main() + def main(): + pass + + if __name__=="__main__": + main() + + """ + with ExitStack() as stack: + socket_path = os.environ.get(TORCH_WORKER_SERVER_SOCKET) + if socket_path is not None: + stack.enter_context(_worker_server(socket_path)) + + yield diff --git a/mindnlp/core/distributed/elastic/events/__init__.py b/mindnlp/core/distributed/elastic/events/__init__.py new file mode 100644 index 000000000..f0231079c --- /dev/null +++ b/mindnlp/core/distributed/elastic/events/__init__.py @@ -0,0 +1,170 @@ +#!/usr/bin/env/python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Module contains events processing mechanisms that are integrated with the standard python logging. + +Example of usage: + +:: + + from core.distributed.elastic import events + event = events.Event(name="test_event", source=events.EventSource.WORKER, metadata={...}) + events.get_logging_handler(destination="console").info(event) + +""" + +import inspect +import logging +import os +import socket +import traceback +from typing import Dict, Optional + +from core.distributed.elastic.events.handlers import get_logging_handler + +from .api import ( # noqa: F401 + Event, + EventMetadataValue, + EventSource, + NodeState, + RdzvEvent, +) + + +_events_loggers: Dict[str, logging.Logger] = {} + + +def _get_or_create_logger(destination: str = "null") -> logging.Logger: + """ + Construct python logger based on the destination type or extends if provided. + + Available destination could be found in ``handlers.py`` file. + The constructed logger does not propagate messages to the upper level loggers, + e.g. root logger. This makes sure that a single event can be processed once. + + Args: + destination: The string representation of the event handler. + Available handlers found in ``handlers`` module + """ + global _events_loggers + + if destination not in _events_loggers: + _events_logger = logging.getLogger(f"torchelastic-events-{destination}") + _events_logger.setLevel(os.environ.get("LOGLEVEL", "INFO")) + # Do not propagate message to the root logger + _events_logger.propagate = False + + logging_handler = get_logging_handler(destination) + _events_logger.addHandler(logging_handler) + + # Add the logger to the global dictionary + _events_loggers[destination] = _events_logger + + return _events_loggers[destination] + + +def record(event: Event, destination: str = "null") -> None: + _get_or_create_logger(destination).info(event.serialize()) + + +def record_rdzv_event(event: RdzvEvent) -> None: + _get_or_create_logger("dynamic_rendezvous").info(event.serialize()) + + +def construct_and_record_rdzv_event( + run_id: str, + message: str, + node_state: NodeState, + name: str = "", + hostname: str = "", + pid: Optional[int] = None, + master_endpoint: str = "", + local_id: Optional[int] = None, + rank: Optional[int] = None, +) -> None: + """ + Initialize rendezvous event object and record its operations. + + Args: + run_id (str): The run id of the rendezvous. + message (str): The message describing the event. + node_state (NodeState): The state of the node (INIT, RUNNING, SUCCEEDED, FAILED). + name (str): Event name. (E.g. Current action being performed). + hostname (str): Hostname of the node. + pid (Optional[int]): The process id of the node. + master_endpoint (str): The master endpoint for the rendezvous store, if known. + local_id (Optional[int]): The local_id of the node, if defined in dynamic_rendezvous.py + rank (Optional[int]): The rank of the node, if known. + Returns: + None + Example: + >>> # See DynamicRendezvousHandler class + >>> def _record( + ... self, + ... message: str, + ... node_state: NodeState = NodeState.RUNNING, + ... rank: Optional[int] = None, + ... ) -> None: + ... construct_and_record_rdzv_event( + ... name=f"{self.__class__.__name__}.{get_method_name()}", + ... run_id=self._settings.run_id, + ... message=message, + ... node_state=node_state, + ... hostname=self._this_node.addr, + ... pid=self._this_node.pid, + ... local_id=self._this_node.local_id, + ... rank=rank, + ... ) + """ + # We don't want to perform an extra computation if not needed. + if isinstance(get_logging_handler("dynamic_rendezvous"), logging.NullHandler): + return + + # Set up parameters. + if not hostname: + hostname = socket.getfqdn() + if not pid: + pid = os.getpid() + + # Determines which file called this function. + callstack = inspect.stack() + filename = "no_file" + if len(callstack) > 1: + stack_depth_1 = callstack[1] + filename = os.path.basename(stack_depth_1.filename) + if not name: + name = stack_depth_1.function + + # Delete the callstack variable. If kept, this can mess with python's + # garbage collector as we are holding on to stack frame information in + # the inspect module. + del callstack + + # Set up error trace if this is an exception + if node_state == NodeState.FAILED: + error_trace = traceback.format_exc() + else: + error_trace = "" + + # Initialize event object + event = RdzvEvent( + name=f"{filename}:{name}", + run_id=run_id, + message=message, + hostname=hostname, + pid=pid, + node_state=node_state, + master_endpoint=master_endpoint, + rank=rank, + local_id=local_id, + error_trace=error_trace, + ) + + # Finally, record the event. + record_rdzv_event(event) diff --git a/mindnlp/core/distributed/elastic/events/api.py b/mindnlp/core/distributed/elastic/events/api.py new file mode 100644 index 000000000..f85fdd835 --- /dev/null +++ b/mindnlp/core/distributed/elastic/events/api.py @@ -0,0 +1,114 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import json +from dataclasses import asdict, dataclass, field +from enum import Enum +from typing import Dict, Optional, Union + + +__all__ = ["EventSource", "Event", "NodeState", "RdzvEvent"] + +EventMetadataValue = Union[str, int, float, bool, None] + + +class EventSource(str, Enum): + """Known identifiers of the event producers.""" + + AGENT = "AGENT" + WORKER = "WORKER" + + +@dataclass +class Event: + """ + The class represents the generic event that occurs during the torchelastic job execution. + + The event can be any kind of meaningful action. + + Args: + name: event name. + source: the event producer, e.g. agent or worker + timestamp: timestamp in milliseconds when event occurred. + metadata: additional data that is associated with the event. + """ + + name: str + source: EventSource + timestamp: int = 0 + metadata: Dict[str, EventMetadataValue] = field(default_factory=dict) + + def __str__(self): + return self.serialize() + + @staticmethod + def deserialize(data: Union[str, "Event"]) -> "Event": + if isinstance(data, Event): + return data + if isinstance(data, str): + data_dict = json.loads(data) + data_dict["source"] = EventSource[data_dict["source"]] # type: ignore[possibly-undefined] + return Event(**data_dict) + + def serialize(self) -> str: + return json.dumps(asdict(self)) + + +class NodeState(str, Enum): + """The states that a node can be in rendezvous.""" + + INIT = "INIT" + RUNNING = "RUNNING" + SUCCEEDED = "SUCCEEDED" + FAILED = "FAILED" + + +@dataclass +class RdzvEvent: + """ + Dataclass to represent any rendezvous event. + + Args: + name: Event name. (E.g. Current action being performed) + run_id: The run id of the rendezvous + message: The message describing the event + hostname: Hostname of the node + pid: The process id of the node + node_state: The state of the node (INIT, RUNNING, SUCCEEDED, FAILED) + master_endpoint: The master endpoint for the rendezvous store, if known + rank: The rank of the node, if known + local_id: The local_id of the node, if defined in dynamic_rendezvous.py + error_trace: Error stack trace, if this is an error event. + """ + + name: str + run_id: str + message: str + hostname: str + pid: int + node_state: NodeState + master_endpoint: str = "" + rank: Optional[int] = None + local_id: Optional[int] = None + error_trace: str = "" + + def __str__(self): + return self.serialize() + + @staticmethod + def deserialize(data: Union[str, "RdzvEvent"]) -> "RdzvEvent": + if isinstance(data, RdzvEvent): + return data + if isinstance(data, str): + data_dict = json.loads(data) + data_dict["node_state"] = NodeState[data_dict["node_state"]] # type: ignore[possibly-undefined] + return RdzvEvent(**data_dict) + + def serialize(self) -> str: + return json.dumps(asdict(self)) diff --git a/mindnlp/core/distributed/elastic/events/handlers.py b/mindnlp/core/distributed/elastic/events/handlers.py new file mode 100644 index 000000000..51dd14280 --- /dev/null +++ b/mindnlp/core/distributed/elastic/events/handlers.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Dict + + +_log_handlers: Dict[str, logging.Handler] = { + "console": logging.StreamHandler(), + "dynamic_rendezvous": logging.NullHandler(), + "null": logging.NullHandler(), +} + + +def get_logging_handler(destination: str = "null") -> logging.Handler: + global _log_handlers + return _log_handlers[destination] diff --git a/mindnlp/core/distributed/elastic/metrics/__init__.py b/mindnlp/core/distributed/elastic/metrics/__init__.py new file mode 100644 index 000000000..8f15c7373 --- /dev/null +++ b/mindnlp/core/distributed/elastic/metrics/__init__.py @@ -0,0 +1,164 @@ +#!/usr/bin/env/python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Metrics API. + +**Overview**: + +The metrics API in torchelastic is used to publish telemetry metrics. +It is designed to be used by torchelastic's internal modules to +publish metrics for the end user with the goal of increasing visibility +and helping with debugging. However you may use the same API in your +jobs to publish metrics to the same metrics ``sink``. + +A ``metric`` can be thought of as timeseries data +and is uniquely identified by the string-valued tuple +``(metric_group, metric_name)``. + +torchelastic makes no assumptions about what a ``metric_group`` is +and what relationship it has with ``metric_name``. It is totally up +to the user to use these two fields to uniquely identify a metric. + +.. note:: The metric group ``torchelastic`` is reserved by torchelastic for + platform level metrics that it produces. + For instance torchelastic may output the latency (in milliseconds) + of a re-rendezvous operation from the agent as + ``(torchelastic, agent.rendezvous.duration.ms)`` + +A sensible way to use metric groups is to map them to a stage or module +in your job. You may also encode certain high level properties +the job such as the region or stage (dev vs prod). + +**Publish Metrics**: + +Using torchelastic's metrics API is similar to using python's logging +framework. You first have to configure a metrics handler before +trying to add metric data. + +The example below measures the latency for the ``calculate()`` function. + +:: + + import time + from mindnlp import core.distributed.elastic.metrics as metrics + + # makes all metrics other than the one from "my_module" to go /dev/null + metrics.configure(metrics.NullMetricsHandler()) + metrics.configure(metrics.ConsoleMetricsHandler(), "my_module") + + def my_method(): + start = time.time() + calculate() + end = time.time() + metrics.put_metric("calculate_latency", int(end-start), "my_module") + +You may also use the core.distributed.elastic.metrics.prof` decorator +to conveniently and succinctly profile functions + +:: + + # -- in module examples.foobar -- + + from mindnlp import core.distributed.elastic.metrics as metrics + + metrics.configure(metrics.ConsoleMetricsHandler(), "foobar") + metrics.configure(metrics.ConsoleMetricsHandler(), "Bar") + + @metrics.prof + def foo(): + pass + + class Bar(): + + @metrics.prof + def baz(): + pass + +``@metrics.prof`` will publish the following metrics +:: + + .success - 1 if the function finished successfully + .failure - 1 if the function threw an exception + .duration.ms - function duration in milliseconds + +**Configuring Metrics Handler**: + +`core.distributed.elastic.metrics.MetricHandler` is responsible for emitting +the added metric values to a particular destination. Metric groups can be +configured with different metric handlers. + +By default torchelastic emits all metrics to ``/dev/null``. +By adding the following configuration metrics, +``torchelastic`` and ``my_app`` metric groups will be printed out to +console. + +:: + + from mindnlp import core.distributed.elastic.metrics as metrics + + metrics.configure(metrics.ConsoleMetricHandler(), group = "torchelastic") + metrics.configure(metrics.ConsoleMetricHandler(), group = "my_app") + +**Writing a Custom Metric Handler**: + +If you want your metrics to be emitted to a custom location, implement +the `core.distributed.elastic.metrics.MetricHandler` interface +and configure your job to use your custom metric handler. + +Below is a toy example that prints the metrics to ``stdout`` + +:: + + from mindnlp import core.distributed.elastic.metrics as metrics + + class StdoutMetricHandler(metrics.MetricHandler): + def emit(self, metric_data): + ts = metric_data.timestamp + group = metric_data.group_name + name = metric_data.name + value = metric_data.value + print(f"[{ts}][{group}]: {name}={value}") + + metrics.configure(StdoutMetricHandler(), group="my_app") + +Now all metrics in the group ``my_app`` will be printed to stdout as: + +:: + + [1574213883.4182858][my_app]: my_metric= + [1574213940.5237644][my_app]: my_metric= + +""" + +from typing import Optional + +from .api import ( # noqa: F401 + configure, + ConsoleMetricHandler, + get_elapsed_time_ms, + getStream, + MetricData, + MetricHandler, + MetricsConfig, + NullMetricHandler, + prof, + profile, + publish_metric, + put_metric, +) + + +def initialize_metrics(cfg: Optional[MetricsConfig] = None): + pass + + +try: + from core.distributed.elastic.metrics.static_init import * # type: ignore[import] # noqa: F401 F403 +except ModuleNotFoundError: + pass diff --git a/mindnlp/core/distributed/elastic/metrics/api.py b/mindnlp/core/distributed/elastic/metrics/api.py new file mode 100644 index 000000000..81a62f66f --- /dev/null +++ b/mindnlp/core/distributed/elastic/metrics/api.py @@ -0,0 +1,216 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import abc +import time +from collections import namedtuple +from functools import wraps +from typing import Dict, Optional +from typing_extensions import deprecated + + +__all__ = [ + "MetricsConfig", + "MetricHandler", + "ConsoleMetricHandler", + "NullMetricHandler", + "MetricStream", + "configure", + "getStream", + "prof", + "profile", + "put_metric", + "publish_metric", + "get_elapsed_time_ms", + "MetricData", +] + +MetricData = namedtuple("MetricData", ["timestamp", "group_name", "name", "value"]) + + +class MetricsConfig: + __slots__ = ["params"] + + def __init__(self, params: Optional[Dict[str, str]] = None): + self.params = params + if self.params is None: + self.params = {} + + +class MetricHandler(abc.ABC): + @abc.abstractmethod + def emit(self, metric_data: MetricData): + pass + + +class ConsoleMetricHandler(MetricHandler): + def emit(self, metric_data: MetricData): + print( + f"[{metric_data.timestamp}][{metric_data.group_name}]: {metric_data.name}={metric_data.value}" + ) + + +class NullMetricHandler(MetricHandler): + def emit(self, metric_data: MetricData): + pass + + +class MetricStream: + def __init__(self, group_name: str, handler: MetricHandler): + self.group_name = group_name + self.handler = handler + + def add_value(self, metric_name: str, metric_value: int): + self.handler.emit( + MetricData(time.time(), self.group_name, metric_name, metric_value) + ) + + +_metrics_map: Dict[str, MetricHandler] = {} +_default_metrics_handler: MetricHandler = NullMetricHandler() + + +# pyre-fixme[9]: group has type `str`; used as `None`. +def configure(handler: MetricHandler, group: Optional[str] = None): + if group is None: + global _default_metrics_handler + # pyre-fixme[9]: _default_metrics_handler has type `NullMetricHandler`; used + # as `MetricHandler`. + _default_metrics_handler = handler + else: + _metrics_map[group] = handler + + +def getStream(group: str): + if group in _metrics_map: + handler = _metrics_map[group] + else: + handler = _default_metrics_handler + return MetricStream(group, handler) + + +def _get_metric_name(fn): + qualname = fn.__qualname__ + split = qualname.split(".") + if len(split) == 1: + module = fn.__module__ + if module: + return module.split(".")[-1] + "." + split[0] + else: + return split[0] + else: + return qualname + + +def prof(fn=None, group: str = "torchelastic"): + r""" + @profile decorator publishes duration.ms, count, success, failure metrics for the function that it decorates. + + The metric name defaults to the qualified name (``class_name.def_name``) of the function. + If the function does not belong to a class, it uses the leaf module name instead. + + Usage + + :: + + @metrics.prof + def x(): + pass + + @metrics.prof(group="agent") + def y(): + pass + """ + + def wrap(f): + @wraps(f) + def wrapper(*args, **kwargs): + key = _get_metric_name(f) + try: + start = time.time() + result = f(*args, **kwargs) + put_metric(f"{key}.success", 1, group) + except Exception: + put_metric(f"{key}.failure", 1, group) + raise + finally: + put_metric(f"{key}.duration.ms", get_elapsed_time_ms(start), group) # type: ignore[possibly-undefined] + return result + + return wrapper + + if fn: + return wrap(fn) + else: + return wrap + + +@deprecated("Deprecated, use `@prof` instead", category=FutureWarning) +def profile(group=None): + """ + @profile decorator adds latency and success/failure metrics to any given function. + + Usage + + :: + + @metrics.profile("my_metric_group") + def some_function(): + """ + + def wrap(func): + @wraps(func) + def wrapper(*args, **kwargs): + try: + start_time = time.time() + result = func(*args, **kwargs) + publish_metric(group, f"{func.__name__}.success", 1) + except Exception: + publish_metric(group, f"{func.__name__}.failure", 1) + raise + finally: + publish_metric( + group, + f"{func.__name__}.duration.ms", + get_elapsed_time_ms(start_time), # type: ignore[possibly-undefined] + ) + return result + + return wrapper + + return wrap + + +def put_metric(metric_name: str, metric_value: int, metric_group: str = "torchelastic"): + """ + Publish a metric data point. + + Usage + + :: + + put_metric("metric_name", 1) + put_metric("metric_name", 1, "metric_group_name") + """ + getStream(metric_group).add_value(metric_name, metric_value) + + +@deprecated( + "Deprecated, use `put_metric(metric_group)(metric_name, metric_value)` instead", + category=FutureWarning, +) +def publish_metric(metric_group: str, metric_name: str, metric_value: int): + metric_stream = getStream(metric_group) + metric_stream.add_value(metric_name, metric_value) + + +def get_elapsed_time_ms(start_time_in_seconds: float): + """Return the elapsed time in millis from the given start time.""" + end_time = time.time() + return int((end_time - start_time_in_seconds) * 1000) diff --git a/mindnlp/core/distributed/elastic/multiprocessing/__init__.py b/mindnlp/core/distributed/elastic/multiprocessing/__init__.py new file mode 100644 index 000000000..7809f3c6f --- /dev/null +++ b/mindnlp/core/distributed/elastic/multiprocessing/__init__.py @@ -0,0 +1,233 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Library that launches and manages ``n`` copies of worker subprocesses either specified by a function or a binary. + +For functions, it uses ``core.multiprocessing`` (and therefore python +``multiprocessing``) to spawn/fork worker processes. For binaries it uses python +``subprocessing.Popen`` to create worker processes. + + +Usage 1: Launching two trainers as a function + +:: + + from core.distributed.elastic.multiprocessing import Std, start_processes + + def trainer(a, b, c): + pass # train + + + # runs two trainers + # LOCAL_RANK=0 trainer(1,2,3) + # LOCAL_RANK=1 trainer(4,5,6) + ctx = start_processes( + name="trainer", + entrypoint=trainer, + args={0: (1,2,3), 1: (4,5,6)}, + envs={0: {"LOCAL_RANK": 0}, 1: {"LOCAL_RANK": 1}}, + log_dir="/tmp/foobar", + redirects=Std.ALL, # write all worker stdout/stderr to a log file + tee={0: Std.ERR}, # tee only local rank 0's stderr to console + ) + + # waits for all copies of trainer to finish + ctx.wait() + +Usage 2: Launching 2 echo workers as a binary + +:: + + # same as invoking + # echo hello + # echo world > stdout.log + ctx = start_processes( + name="echo" + entrypoint="echo", + log_dir="/tmp/foobar", + args={0: "hello", 1: "world"}, + redirects={1: Std.OUT}, + ) + +Just like ``core.multiprocessing``, the return value of the function +:func:`start_processes` is a process context (:class:`api.PContext`). If a function +was launched, a :class:`api.MultiprocessContext` is returned and if a binary +was launched a :class:`api.SubprocessContext` is returned. Both are specific +implementations of the parent :class:`api.PContext` class. +""" + +from typing import Callable, Dict, Optional, Tuple, Union + +from core.distributed.elastic.multiprocessing.api import ( # noqa: F401 + _validate_full_rank, + DefaultLogsSpecs, + LogsDest, + LogsSpecs, + MultiprocessContext, + PContext, + ProcessFailure, + RunProcsResult, + SignalException, + Std, + SubprocessContext, + to_map, +) +from core.distributed.elastic.utils.logging import get_logger + + +__all__ = [ + "start_processes", + "MultiprocessContext", + "PContext", + "ProcessFailure", + "RunProcsResult", + "SignalException", + "Std", + "LogsDest", + "LogsSpecs", + "DefaultLogsSpecs", + "SubprocessContext", + "to_map", +] + + +def start_processes( + name: str, + entrypoint: Union[Callable, str], + args: Dict[int, Tuple], + envs: Dict[int, Dict[str, str]], + logs_specs: LogsSpecs, + log_line_prefixes: Optional[Dict[int, str]] = None, + start_method: str = "spawn", +) -> PContext: + """ + Start ``n`` copies of ``entrypoint`` processes with the provided options. + + ``entrypoint`` is either a ``Callable`` (function) or a ``str`` (binary). + The number of copies is determined by the number of entries for ``args`` and + ``envs`` arguments, which need to have the same key set. + + ``args`` and ``env`` parameters are the arguments and environment variables + to pass down to the entrypoint mapped by the replica index (local rank). + All local ranks must be accounted for. + That is, the keyset should be ``{0,1,...,(nprocs-1)}``. + + .. note:: When the ``entrypoint`` is a binary (``str``), ``args`` can only be strings. + If any other type is given, then it is casted to a string representation + (e.g. ``str(arg1)``). Furthermore, a binary failure will only write + an ``error.json`` error file if the main function is annotated with + ``core.distributed.elastic.multiprocessing.errors.record``. For function launches, + this is done by default and there is no need to manually annotate + with the ``@record`` annotation. + + ``redirects`` and ``tee`` are bitmasks specifying which std stream(s) to redirect + to a log file in the ``log_dir``. Valid mask values are defined in ``Std``. + To redirect/tee only certain local ranks, pass ``redirects`` as a map with the key as + the local rank to specify the redirect behavior for. + Any missing local ranks will default to ``Std.NONE``. + + ``tee`` acts like the unix "tee" command in that it redirects + prints to console. + To avoid worker stdout/stderr from printing to console, use the ``redirects`` parameter. + + For each process, the ``log_dir`` will contain: + + #. ``{local_rank}/error.json``: if the process failed, a file with the error info + #. ``{local_rank}/stdout.json``: if ``redirect & STDOUT == STDOUT`` + #. ``{local_rank}/stderr.json``: if ``redirect & STDERR == STDERR`` + + .. note:: It is expected that the ``log_dir`` exists, is empty, and is a directory. + + Example: + :: + + log_dir = "/tmp/test" + + # ok; two copies of foo: foo("bar0"), foo("bar1") + start_processes( + name="trainer", + entrypoint=foo, + args:{0:("bar0",), 1:("bar1",), + envs:{0:{}, 1:{}}, + log_dir=log_dir + ) + + # invalid; envs missing for local rank 1 + start_processes( + name="trainer", + entrypoint=foo, + args:{0:("bar0",), 1:("bar1",), + envs:{0:{}}, + log_dir=log_dir + ) + + # ok; two copies of /usr/bin/touch: touch file1, touch file2 + start_processes( + name="trainer", + entrypoint="/usr/bin/touch", + args:{0:("file1",), 1:("file2",), + envs:{0:{}, 1:{}}, + log_dir=log_dir + ) + + # caution; arguments casted to string, runs: + # echo "1" "2" "3" and echo "[1, 2, 3]" + start_processes( + name="trainer", + entrypoint="/usr/bin/echo", + args:{0:(1,2,3), 1:([1,2,3],), + envs:{0:{}, 1:{}}, + log_dir=log_dir + ) + + Args: + name: a human readable short name that describes what the processes are + (used as header when tee'ing stdout/stderr outputs) + entrypoint: either a ``Callable`` (function) or ``cmd`` (binary) + args: arguments to each replica + envs: env vars to each replica + log_dir: directory used to write log files + start_method: multiprocessing start method (spawn, fork, forkserver) + ignored for binaries + redirects: which std streams to redirect to a log file + tee: which std streams to redirect + print to console + local_ranks_filter: which ranks' logs to print to console + + """ + + nprocs = len(args) + _validate_full_rank(args, nprocs, "args") + _validate_full_rank(envs, nprocs, "envs") + + context: PContext + if isinstance(entrypoint, str): + context = SubprocessContext( + name=name, + entrypoint=entrypoint, + args=args, + envs=envs, + logs_specs=logs_specs, + log_line_prefixes=log_line_prefixes, + ) + else: + context = MultiprocessContext( + name=name, + entrypoint=entrypoint, + args=args, + envs=envs, + log_line_prefixes=log_line_prefixes, + start_method=start_method, + logs_specs=logs_specs, + ) + + try: + context.start() + return context + except Exception: + context.close() + raise diff --git a/mindnlp/core/distributed/elastic/multiprocessing/api.py b/mindnlp/core/distributed/elastic/multiprocessing/api.py new file mode 100644 index 000000000..198902cfe --- /dev/null +++ b/mindnlp/core/distributed/elastic/multiprocessing/api.py @@ -0,0 +1,923 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import abc +import logging +import os +import re +import shutil +import signal +import subprocess +import sys +import tempfile +import threading +import time +from abc import ABC, abstractmethod +from contextlib import nullcontext +from dataclasses import dataclass, field +from enum import IntFlag +from multiprocessing import synchronize +from types import FrameType +from typing import Any, Callable, Dict, Optional, Set, Tuple, Union + +from mindnlp import core.multiprocessing as mp +from core.distributed.elastic.multiprocessing.errors import ProcessFailure, record +from core.distributed.elastic.multiprocessing.redirects import ( + redirect_stderr, + redirect_stdout, +) +from core.distributed.elastic.multiprocessing.subprocess_handler import ( + get_subprocess_handler, + SubprocessHandler, +) +from core.distributed.elastic.multiprocessing.tail_log import TailLog + + +IS_WINDOWS = sys.platform == "win32" +IS_MACOS = sys.platform == "darwin" + + +logger = logging.getLogger(__name__) + +__all__ = [ + "DefaultLogsSpecs", + "SignalException", + "Std", + "to_map", + "RunProcsResult", + "PContext", + "get_std_cm", + "MultiprocessContext", + "SubprocessContext", + "LogsDest", + "LogsSpecs", +] + + +class SignalException(Exception): + """ + Exception is raised inside the torchelastic agent process by the termination handler + if the death signal got received by the process. + """ + + def __init__(self, msg: str, sigval: signal.Signals) -> None: + super().__init__(msg) + self.sigval = sigval + + +def _terminate_process_handler(signum: int, frame: Optional[FrameType]) -> None: + """Termination handler that raises exceptions on the main process. + + When the process receives death signal(SIGTERM, SIGINT), this termination handler will + be invoked. It raises the ``SignalException`` exception that should be processed by the + user code. Python does not terminate process after the termination handler is finished, + so the exception should not be silently ignored, otherwise the process will never + be terminated. + """ + sigval = signal.Signals(signum) + raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval) + + +def _get_kill_signal() -> signal.Signals: + """Get the kill signal. SIGKILL for unix, CTRL_C_EVENT for windows.""" + if IS_WINDOWS: + return signal.CTRL_C_EVENT # type: ignore[attr-defined] # noqa: F821 + else: + return signal.SIGKILL + + +def _get_default_signal() -> signal.Signals: + """Get the default termination signal. SIGTERM for unix, CTRL_C_EVENT for windows.""" + if IS_WINDOWS: + return signal.CTRL_C_EVENT # type: ignore[attr-defined] # noqa: F821 + else: + return signal.SIGTERM + + +def _validate_full_rank(d: Dict[int, Any], nprocs: int, what: str): + actual_keys = set(d.keys()) + expected_keys = set(range(nprocs)) + + if actual_keys != expected_keys: + raise RuntimeError( + f"{what}, local rank mapping mismatch," + f" expected: {expected_keys}, actual: {actual_keys}" + ) + + +_MAPPING_REGEX = r"^(\d:[0123],)*(\d:[0123])$" +_VALUE_REGEX = r"^[0123]$" + + +class Std(IntFlag): + NONE = 0 + OUT = 1 + ERR = 2 + ALL = OUT | ERR + + @classmethod + def from_str(cls, vm: str) -> Union["Std", Dict[int, "Std"]]: + """ + Example: + :: + + from_str("0") -> Std.NONE + from_str("1") -> Std.OUT + from_str("0:3,1:0,2:1,3:2") -> {0: Std.ALL, 1: Std.NONE, 2: Std.OUT, 3: Std.ERR} + + Any other input raises an exception + """ + + def to_std(v: str) -> Std: # type: ignore[return] + s = Std(int(v)) + if s in Std: + return s + # return None -> should NEVER reach here since we regex check input + + if re.match(_VALUE_REGEX, vm): # vm is a number (e.g. 0) + return to_std(vm) + elif re.match(_MAPPING_REGEX, vm): # vm is a mapping (e.g. 0:1,1:2) + d: Dict[int, Std] = {} + for m in vm.split(","): + i, v = m.split(":") + d[int(i)] = to_std(v) + return d + else: + raise ValueError( + f"{vm} does not match: <{_VALUE_REGEX}> or <{_MAPPING_REGEX}>" + ) + + +def to_map( + val_or_map: Union[Std, Dict[int, Std]], local_world_size: int +) -> Dict[int, Std]: + """ + Certain APIs take redirect settings either as a single value (e.g. apply to all + local ranks) or as an explicit user-provided mapping. This method is a convenience + method that converts a value or mapping into a mapping. + + Example: + :: + + to_map(Std.OUT, local_world_size=2) # returns: {0: Std.OUT, 1: Std.OUT} + to_map({1: Std.OUT}, local_world_size=2) # returns: {0: Std.NONE, 1: Std.OUT} + to_map({0: Std.OUT, 1: Std.OUT}, local_world_size=2) # returns: {0: Std.OUT, 1: Std.OUT} + """ + if isinstance(val_or_map, Std): + return dict.fromkeys(range(local_world_size), val_or_map) + else: + map = {} + for i in range(local_world_size): + map[i] = val_or_map.get(i, Std.NONE) + return map + + +@dataclass +class LogsDest: + """ + For each log type, holds mapping of local rank ids to file paths. + """ + + stdouts: Dict[int, str] = field(default_factory=dict) + stderrs: Dict[int, str] = field(default_factory=dict) + tee_stdouts: Dict[int, str] = field(default_factory=dict) + tee_stderrs: Dict[int, str] = field(default_factory=dict) + error_files: Dict[int, str] = field(default_factory=dict) + + +class LogsSpecs(ABC): + """ + Defines logs processing and redirection for each worker process. + + Args: + log_dir: + Base directory where logs will be written. + redirects: + Streams to redirect to files. Pass a single ``Std`` + enum to redirect for all workers, or a mapping keyed + by local_rank to selectively redirect. + tee: + Streams to duplicate to stdout/stderr. + Pass a single ``Std`` enum to duplicate streams for all workers, + or a mapping keyed by local_rank to selectively duplicate. + """ + + def __init__( + self, + log_dir: Optional[str] = None, + redirects: Union[Std, Dict[int, Std]] = Std.NONE, + tee: Union[Std, Dict[int, Std]] = Std.NONE, + local_ranks_filter: Optional[Set[int]] = None, + ) -> None: + self._root_log_dir = log_dir + self._redirects = redirects + self._tee = tee + self._local_ranks_filter = local_ranks_filter + + @abstractmethod + def reify( + self, + envs: Dict[int, Dict[str, str]], + ) -> LogsDest: + """ + Given the environment variables, builds destination of log files for each of the local ranks. + + Envs parameter contains env variables dict for each of the local ranks, where entries are defined in: + :func:`~torchelastic.distributed.elastic.agent.server.local_elastic_agent.LocalElasticAgent._start_workers`. + """ + + @property + @abstractmethod + def root_log_dir(self) -> str: + pass + + +class DefaultLogsSpecs(LogsSpecs): + """ + Default LogsSpecs implementation: + + - `log_dir` will be created if it doesn't exist + - Generates nested folders for each attempt and rank. + """ + + def __init__( + self, + log_dir: Optional[str] = None, + redirects: Union[Std, Dict[int, Std]] = Std.NONE, + tee: Union[Std, Dict[int, Std]] = Std.NONE, + local_ranks_filter: Optional[Set[int]] = None, + ) -> None: + if log_dir != os.devnull: + if not log_dir: + log_dir = tempfile.mkdtemp(prefix="torchelastic_") + elif not os.path.exists(log_dir): + os.makedirs(log_dir, exist_ok=True) + else: + if os.path.isfile(log_dir): + raise NotADirectoryError(f"log_dir: {log_dir} is a file") + super().__init__(log_dir, redirects, tee, local_ranks_filter) + # initialized only once + self._run_log_dir = None + + @property + def root_log_dir(self) -> str: + return str(self._root_log_dir) + + def _make_log_dir(self, log_dir: Optional[str], rdzv_run_id: str): + base_log_dir = log_dir or tempfile.mkdtemp(prefix="torchelastic_") + os.makedirs(base_log_dir, exist_ok=True) + dir = tempfile.mkdtemp(prefix=f"{rdzv_run_id}_", dir=base_log_dir) + logger.info("log directory set to: %s", dir) + return dir + + def reify( + self, + envs: Dict[int, Dict[str, str]], + ) -> LogsDest: + """ + Uses following scheme to build log destination paths: + + - `//attempt_//stdout.log` + - `//attempt_//stderr.log` + - `//attempt_//error.json` + """ + nprocs = len(envs) + global_env = {} # use only to query properies that are not dependent on a rank + if nprocs > 0: + global_env = envs[0] + else: + logger.warning( + "Empty envs map provided when defining logging destinations." + ) + # Keys are always defined, but values can be missing in unit tests + run_id = global_env.get("TORCHELASTIC_RUN_ID", "test_run_id") + restart_count = global_env.get("TORCHELASTIC_RESTART_COUNT", "0") + + attempt_log_dir: str = "" + if self._root_log_dir != os.devnull: + if not self._run_log_dir: + self._run_log_dir = self._make_log_dir(self._root_log_dir, run_id) + + attempt_log_dir = os.path.join(self._run_log_dir, f"attempt_{restart_count}") # type: ignore[call-overload] + shutil.rmtree(attempt_log_dir, ignore_errors=True) + os.makedirs(attempt_log_dir) + + if self._root_log_dir == os.devnull: + attempt_log_dir = os.devnull + + # create subdirs for each local rank in the logs_dir + # logs_dir + # |- 0 + # |- error.json + # |- stdout.log + # |- stderr.log + # |- ... + # |- (nprocs-1) + redirs = to_map(self._redirects, nprocs) + ts = to_map(self._tee, nprocs) + + # to tee stdout/stderr we first redirect into a file + # then tail -f stdout.log/stderr.log so add tee settings to redirects + for local_rank, tee_std in ts.items(): + redirect_std = redirs[local_rank] + redirs[local_rank] = redirect_std | tee_std + + SYS_STREAM = "" # special case to indicate to output to console + stdouts = dict.fromkeys(range(nprocs), SYS_STREAM) + stderrs = dict.fromkeys(range(nprocs), SYS_STREAM) + tee_stdouts: Dict[int, str] = {} + tee_stderrs: Dict[int, str] = {} + error_files = {} + + for local_rank in range(nprocs): + if attempt_log_dir == os.devnull: + tee_stdouts[local_rank] = os.devnull + tee_stderrs[local_rank] = os.devnull + error_files[local_rank] = os.devnull + envs[local_rank]["TORCHELASTIC_ERROR_FILE"] = "" + else: + clogdir = os.path.join(attempt_log_dir, str(local_rank)) + os.mkdir(clogdir) + + rd = redirs[local_rank] + if (rd & Std.OUT) == Std.OUT: + stdouts[local_rank] = os.path.join(clogdir, "stdout.log") + if (rd & Std.ERR) == Std.ERR: + stderrs[local_rank] = os.path.join(clogdir, "stderr.log") + + t = ts[local_rank] + if t & Std.OUT == Std.OUT: + tee_stdouts[local_rank] = stdouts[local_rank] + if t & Std.ERR == Std.ERR: + tee_stderrs[local_rank] = stderrs[local_rank] + + if ( + self._local_ranks_filter + and local_rank not in self._local_ranks_filter + ): + # If stream is tee'd, only write to file, but don't tail + if local_rank in tee_stdouts: + tee_stdouts.pop(local_rank, None) + if local_rank in tee_stderrs: + tee_stderrs.pop(local_rank, None) + + # If stream is not redirected, don't print + if stdouts[local_rank] == SYS_STREAM: + stdouts[local_rank] = os.devnull + if stderrs[local_rank] == SYS_STREAM: + stderrs[local_rank] = os.devnull + + error_file = os.path.join(clogdir, "error.json") + error_files[local_rank] = error_file + logger.info( + "Setting worker%s reply file to: %s", local_rank, error_file + ) + envs[local_rank]["TORCHELASTIC_ERROR_FILE"] = error_file + + return LogsDest(stdouts, stderrs, tee_stdouts, tee_stderrs, error_files) + + def __repr__(self) -> str: + return ( + f"DefaultLogsSpecs(root_log_dir={self._root_log_dir}, redirects={self._redirects}, " + f"tee={self._tee}, local_ranks_filter={self._local_ranks_filter})" + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, DefaultLogsSpecs): + return False + + return ( + self._root_log_dir == other._root_log_dir + and self._redirects == other._redirects + and self._tee == other._tee + and self._local_ranks_filter == other._local_ranks_filter + ) + + +@dataclass +class RunProcsResult: + """ + Results of a completed run of processes started with ``start_processes()``. Returned by ``PContext``. + + Note the following: + + 1. All fields are mapped by local rank + 2. ``return_values`` - only populated for functions (not the binaries). + 3. ``stdouts`` - path to stdout.log (empty string if no redirect) + 4. ``stderrs`` - path to stderr.log (empty string if no redirect) + + """ + + return_values: Dict[int, Any] = field(default_factory=dict) + failures: Dict[int, ProcessFailure] = field(default_factory=dict) + stdouts: Dict[int, str] = field(default_factory=dict) + stderrs: Dict[int, str] = field(default_factory=dict) + + def is_failed(self) -> bool: + return len(self.failures) > 0 + + +class PContext(abc.ABC): + """ + The base class that standardizes operations over a set of processes that are launched via different mechanisms. + + The name ``PContext`` is intentional to disambiguate with ``core.multiprocessing.ProcessContext``. + + .. warning:: stdouts and stderrs should ALWAYS be a superset of + tee_stdouts and tee_stderrs (respectively) this is b/c + tee is implemented as a redirect + tail -f + """ + + def __init__( + self, + name: str, + entrypoint: Union[Callable, str], + args: Dict[int, Tuple], + envs: Dict[int, Dict[str, str]], + logs_specs: LogsSpecs, + log_line_prefixes: Optional[Dict[int, str]] = None, + ): + self.name = name + # validate that all mappings have the same number of keys and + # all local ranks are accounted for + nprocs = len(args) + + # TODO log_line_prefixes can be exanded too + logs_dest = logs_specs.reify(envs) + + _validate_full_rank(logs_dest.stdouts, nprocs, "stdouts") + _validate_full_rank(logs_dest.stderrs, nprocs, "stderrs") + + self.entrypoint = entrypoint + self.args = args + self.envs = envs + self.stdouts = logs_dest.stdouts + self.stderrs = logs_dest.stderrs + self.error_files = logs_dest.error_files + self.nprocs = nprocs + + self._stdout_tail = TailLog( + name, logs_dest.tee_stdouts, sys.stdout, log_line_prefixes + ) + self._stderr_tail = TailLog( + name, logs_dest.tee_stderrs, sys.stderr, log_line_prefixes + ) + + def start(self) -> None: + """Start processes using parameters defined in the constructor.""" + if threading.current_thread() is threading.main_thread(): + signal.signal(signal.SIGTERM, _terminate_process_handler) + signal.signal(signal.SIGINT, _terminate_process_handler) + if not IS_WINDOWS: + signal.signal(signal.SIGHUP, _terminate_process_handler) + signal.signal(signal.SIGQUIT, _terminate_process_handler) + else: + logger.warning( + "Failed to register signal handlers since torchelastic is running on a child thread. " + "This could lead to orphaned worker processes if the torchrun is terminated." + ) + self._start() + self._stdout_tail.start() + self._stderr_tail.start() + + @abc.abstractmethod + def _start(self) -> None: + """Start processes using strategy defined in a particular context.""" + raise NotImplementedError + + @abc.abstractmethod + def _poll(self) -> Optional[RunProcsResult]: + """ + Poll the run status of the processes running under this context. + This method follows an "all-or-nothing" policy and returns + a ``RunProcessResults`` object if either all processes complete + successfully or any process fails. Returns ``None`` if + all processes are still running. + """ + raise NotImplementedError + + def wait(self, timeout: float = -1, period: float = 1) -> Optional[RunProcsResult]: + """ + Wait for the specified ``timeout`` seconds, polling every ``period`` seconds + for the processes to be done. Returns ``None`` if the processes are still running + on timeout expiry. Negative timeout values are interpreted as "wait-forever". + A timeout value of zero simply queries the status of the processes (e.g. equivalent + to a poll). + + ..note: Multiprocessing library registers SIGTERM and SIGINT signal handlers that raise + ``SignalException`` when the signals received. It is up to the consumer of the code + to properly handle the exception. It is important not to swallow the exception otherwise + the process would not terminate. Example of the typical workflow can be: + + .. code-block:: python + pc = start_processes(...) + try: + pc.wait(1) + .. do some other work + except SignalException as e: + pc.shutdown(e.sigval, timeout=30) + + If SIGTERM or SIGINT occurs, the code above will try to shutdown child processes by propagating + received signal. If child processes will not terminate in the timeout time, the process will send + the SIGKILL. + """ + if timeout == 0: + return self._poll() + + if timeout < 0: + timeout = sys.maxsize + + expiry = time.time() + timeout + while time.time() < expiry: + pr = self._poll() + if pr: + return pr + time.sleep(period) + + return None + + @abc.abstractmethod + def pids(self) -> Dict[int, int]: + """Return pids of processes mapped by their respective local_ranks.""" + raise NotImplementedError + + @abc.abstractmethod + def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None: + r""" + Terminates all processes managed by this context and cleans up any + meta resources (e.g. redirect, error_file files). + """ + raise NotImplementedError + + def close( + self, death_sig: Optional[signal.Signals] = None, timeout: int = 30 + ) -> None: + r""" + Terminates all processes managed by this context and cleans up any + meta resources (e.g. redirect, error_file files). + + Args: + death_sig: Death signal to terminate processes. + timeout: Time to wait for processes to finish, if process is + still alive after this time, it will be terminated via SIGKILL. + """ + if not death_sig: + death_sig = _get_default_signal() + self._close(death_sig=death_sig, timeout=timeout) + if self._stdout_tail: + self._stdout_tail.stop() + if self._stderr_tail: + self._stderr_tail.stop() + + +def get_std_cm(std_rd: str, redirect_fn): + if IS_WINDOWS or IS_MACOS or not std_rd: + return nullcontext() + else: + return redirect_fn(std_rd) + + +def _wrap( + local_rank: int, + fn: Callable, + args: Dict[int, Tuple], + envs: Dict[int, Dict[str, str]], + stdout_redirects: Dict[int, str], # redirect file for stdout (to console if None) + stderr_redirects: Dict[int, str], # redirect file for stderr (to console if None) + ret_vals: Dict[int, mp.SimpleQueue], + queue_finished_reading_event: synchronize.Event, +) -> None: + # get the per-rank params up front so we fail fast if no mapping is found + args_ = args[local_rank] + env_ = envs[local_rank] + ret_val_ = ret_vals[local_rank] + + stdout_rd = stdout_redirects[local_rank] + stderr_rd = stderr_redirects[local_rank] + + stdout_cm = get_std_cm(stdout_rd, redirect_stdout) + stderr_cm = get_std_cm(stderr_rd, redirect_stderr) + + for k, v in env_.items(): + os.environ[k] = v + + with stdout_cm, stderr_cm: + ret = record(fn)(*args_) + ret_val_.put(ret) + queue_finished_reading_event.wait() + + +class MultiprocessContext(PContext): + """``PContext`` holding worker processes invoked as a function.""" + + def __init__( + self, + name: str, + entrypoint: Callable, + args: Dict[int, Tuple], + envs: Dict[int, Dict[str, str]], + start_method: str, + logs_specs: LogsSpecs, + log_line_prefixes: Optional[Dict[int, str]] = None, + ): + super().__init__( + name, + entrypoint, + args, + envs, + logs_specs, + log_line_prefixes, + ) + + self.start_method = start_method + # each ret_val queue will always contain a single element. + self._ret_vals = { + local_rank: mp.get_context(self.start_method).SimpleQueue() + for local_rank in range(self.nprocs) + } + + # see comments in ``join()`` for what this is + self._return_values: Dict[int, Any] = {} + self._pc: Optional[mp.ProcessContext] = None + # Note: set method should ONLY be invoked for the use case when all processes finished + # successfully. If any process died on event.wait() calling set() method will deadlock. + self._worker_finished_event = mp.get_context(self.start_method).Event() + + def _start(self): + if self._pc: + raise ValueError( + "The process context already initialized." + " Most likely the start method got called twice." + ) + self._pc = mp.start_processes( + fn=_wrap, + args=( + self.entrypoint, + self.args, + self.envs, + self.stdouts, + self.stderrs, + self._ret_vals, + self._worker_finished_event, + ), + nprocs=self.nprocs, + join=False, + daemon=False, + start_method=self.start_method, + ) + + def _is_done(self) -> bool: + return len(self._return_values) == self.nprocs + + def _poll(self) -> Optional[RunProcsResult]: + assert self._pc is not None # assertion for mypy type checker + + try: + # core.mp.ProcessContext Throws an Exception if some/all of + # worker processes failed + # timeout < 0 checks worker status and return immediately + # Join will never return success since we use synchronize.Event to wait + # for all processes to finish. + self._pc.join(-1) + + # IMPORTANT: we use multiprocessing.Queue to carry worker return values + # back to the parent, the worker process will wait before terminating + # until all the buffered items are fed by the feeder thread to the underlying + # pipe. Hence to prevent deadlocks on large return values, + # we opportunistically try queue.get on each join call + # See: https://docs.python.org/2/library/multiprocessing.html#all-platforms + for local_rank in range(0, self.nprocs): + return_queue = self._ret_vals[local_rank] + if not return_queue.empty(): + # save the return values temporarily into a member var + self._return_values[local_rank] = return_queue.get() + + if self._is_done(): + # we should ALWAYS have ALL the return values when all the processes are done + self._worker_finished_event.set() + + # At this point workers finished running the user function + # But the child process might still have not exited. Wait for them. + # pc.join() blocks [forever] until "a" proc exits. Loop until all of them exits. + while not self._pc.join(): + logger.debug( + "entrypoint fn finished, waiting for all child procs to exit..." + ) + + _validate_full_rank( + self._return_values, self.nprocs, "return_value queue" + ) + self.close() + return RunProcsResult( + return_values=self._return_values, + stdouts=self.stdouts, + stderrs=self.stderrs, + ) + else: + return None + except (mp.ProcessRaisedException, mp.ProcessExitedException) as e: + failed_local_rank = e.error_index + + # entrypoint for MultiprocessContext will always be a Callable + fn_name = self.entrypoint.__qualname__ # type: ignore[union-attr] + failed_proc = self._pc.processes[failed_local_rank] + error_filepath = self.error_files[failed_local_rank] + + logger.exception( + "failed (exitcode: %s)" + " local_rank: %s (pid: %s)" + " of fn: %s (start_method: %s)", + failed_proc.exitcode, + failed_local_rank, + e.pid, + fn_name, + self.start_method, + ) + + self.close() + return RunProcsResult( + failures={ + failed_local_rank: ProcessFailure( + local_rank=failed_local_rank, + pid=e.pid, + exitcode=failed_proc.exitcode, + error_file=error_filepath, + ) + }, + stdouts=self.stdouts, + stderrs=self.stderrs, + ) + + def pids(self) -> Dict[int, int]: + assert self._pc is not None # assertion for mypy type checking + return dict(enumerate(self._pc.pids())) + + def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None: + if not self._pc: + return + for proc in self._pc.processes: + if proc.is_alive(): + logger.warning( + "Closing process %s via signal %s", proc.pid, death_sig.name + ) + try: + os.kill(proc.pid, death_sig) + except ProcessLookupError: + # If the process exited because of some reason, + # `ProcessLookupError` will be raised, it is safe to ignore it. + pass + end = time.monotonic() + timeout + for proc in self._pc.processes: + time_to_wait = end - time.monotonic() + if time_to_wait <= 0: + break + proc.join(time_to_wait) + for proc in self._pc.processes: + if proc.is_alive(): + logger.warning( + "Unable to shutdown process %s via %s, forcefully exiting via %s", + proc.pid, + death_sig, + _get_kill_signal(), + ) + try: + os.kill(proc.pid, _get_kill_signal()) + except ProcessLookupError: + # If the process exited because of some reason, + # `ProcessLookupError` will be raised, it is safe to ignore it. + pass + proc.join() + + +class SubprocessContext(PContext): + """``PContext`` holding worker processes invoked as a binary.""" + + def __init__( + self, + name: str, + entrypoint: str, + args: Dict[int, Tuple], + envs: Dict[int, Dict[str, str]], + logs_specs: LogsSpecs, + log_line_prefixes: Optional[Dict[int, str]] = None, + ): + super().__init__( + name, + entrypoint, + args, + envs, + logs_specs, + log_line_prefixes, + ) + + # state vector; _vdone[local_rank] -> is local_rank finished or not + self._running_local_ranks: Set[int] = set(range(self.nprocs)) + self._failures: Dict[int, ProcessFailure] = {} + self.subprocess_handlers: Dict[int, SubprocessHandler] = {} + + def _start(self): + if self.subprocess_handlers: + raise ValueError( + "The subprocess handlers already initialized. Most likely the start method got called twice." + ) + self.subprocess_handlers = { + local_rank: get_subprocess_handler( + entrypoint=self.entrypoint, # type: ignore[arg-type] # entrypoint is always a str + args=self.args[local_rank], + env=self.envs[local_rank], + stdout=self.stdouts[local_rank], + stderr=self.stderrs[local_rank], + local_rank_id=local_rank, + ) + for local_rank in range(self.nprocs) + } + + def _poll(self) -> Optional[RunProcsResult]: + done_local_ranks = set() + for local_rank in self._running_local_ranks: + handler = self.subprocess_handlers[local_rank] + exitcode = handler.proc.poll() + if exitcode is not None: + done_local_ranks.add(local_rank) + if exitcode != 0: # failed or signaled + self._failures[local_rank] = ProcessFailure( + local_rank=local_rank, + pid=handler.proc.pid, + exitcode=exitcode, + error_file=self.error_files[local_rank], + ) + # else: --> succeeded; nothing to do + + self._running_local_ranks.difference_update(done_local_ranks) + + # if ALL procs are finished or ANY have failed + if not self._running_local_ranks or self._failures: + self.close() # terminate all running procs + result = RunProcsResult( + failures=self._failures, + stdouts=self.stdouts, + stderrs=self.stderrs, + ) + if result.is_failed(): + first_failure = min(result.failures.values(), key=lambda f: f.timestamp) + logger.error( + "failed (exitcode: %s)" + " local_rank: %s (pid: %s)" + " of binary: %s", + first_failure.exitcode, + first_failure.local_rank, + first_failure.pid, + self.entrypoint, + ) + else: + # Populate return with dummy values. This provides consistency with MultiprocessingHandler + result.return_values = dict.fromkeys(range(self.nprocs)) + + return result + else: # there are no failures and procs still running + return None + + def pids(self) -> Dict[int, int]: + return { + local_rank: sh.proc.pid + for local_rank, sh in self.subprocess_handlers.items() + } + + def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None: + if not self.subprocess_handlers: + return + for handler in self.subprocess_handlers.values(): + if handler.proc.poll() is None: + logger.warning( + "Sending process %s closing signal %s", + handler.proc.pid, + death_sig.name, + ) + handler.close(death_sig=death_sig) + end = time.monotonic() + timeout + for handler in self.subprocess_handlers.values(): + time_to_wait = end - time.monotonic() + if time_to_wait <= 0: + break + try: + handler.proc.wait(time_to_wait) + except subprocess.TimeoutExpired: + # Ignore the timeout expired exception, since + # the child process will be forcefully terminated via SIGKILL + pass + for handler in self.subprocess_handlers.values(): + if handler.proc.poll() is None: + logger.warning( + "Unable to shutdown process %s via %s, forcefully exiting via %s", + handler.proc.pid, + death_sig, + _get_kill_signal(), + ) + handler.close(death_sig=_get_kill_signal()) + handler.proc.wait() diff --git a/mindnlp/core/distributed/elastic/multiprocessing/errors/__init__.py b/mindnlp/core/distributed/elastic/multiprocessing/errors/__init__.py new file mode 100644 index 000000000..905716855 --- /dev/null +++ b/mindnlp/core/distributed/elastic/multiprocessing/errors/__init__.py @@ -0,0 +1,383 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Each host in a distributed PyTorch job runs with a single TorchElastic agent, +and multiple workers (as children processes of the TorchElastic agent). +Since the workers are user-provided (your PyTorch script/job), TorchElastic +has a way to propagate errors on the trainers through the agent and up to the +scheduler, which ultimately informs the end-user about the state of the job +and applies any retry policies. + +TorchElastic categorizes errors into 3 categories: + ++----------------+----------------+--------------------------------------------------------------+ +| Category | Sub-Category | Description | ++================+================+==============================================================+ +| User Error | Input Error | invalid inputs to TorchElastic APIs (e.g. min > max nodes) | +| +----------------+--------------------------------------------------------------+ +| | Worker Failure | any failures on the worker child process | ++----------------+----------------+--------------------------------------------------------------+ +| Platform Error | n/a | failures caused by the agent | ++----------------+----------------+--------------------------------------------------------------+ +| Infra Error | n/a | failures outside the domain of the agent and workers | +| | | (e.g. host failures) | ++----------------+----------------+--------------------------------------------------------------+ + +All errors other than "Worker Failure" are either raised canonically from the +agent process or implicitly or explicitly crash the agent process. So the +standard language (python) provided exception handling strategies apply. + +Worker Failures are special because the exception/failure originates on a different +process from the agent so the error needs to be propagated inter-process +(e.g. the agent cannot simply ``try-catch`` an exception raised on the worker process). + +TorchElastic agents use :func:`core.distributed.elastic.multiprocessing.start_processes` +to launch the workers which has a simple file based inter-process error propagation +built-in. + +Any function or binary entrypoint decorated with :func:`record` +will write uncaught exceptions (with the trace information) to a file specified by the +environment variable ``TORCHELASTIC_ERROR_FILE``. The parent process (e.g. agent) +sets this env var on each child it launches, then aggregates the error files for all +children, and propagates the one with the **smallest** timestamp (e.g. the **first** error). +""" + +import json +import os +import signal +import socket +import time +import warnings +from dataclasses import dataclass, field +from datetime import datetime +from functools import wraps +from string import Template +from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar + +from core.distributed.elastic.utils.logging import get_logger + +from .error_handler import ErrorHandler # noqa: F401 +from .handlers import get_error_handler # noqa: F401 + + +__all__ = [ + "ProcessFailure", + "ChildFailedError", + "record", + "ErrorHandler", + "get_error_handler", +] + +logger = get_logger(__name__) + + +JSON = Dict + +_EMPTY_ERROR_DATA = {"message": ""} +_NOT_AVAILABLE = "" + +T = TypeVar("T") + + +@dataclass +class ProcessFailure: + """ + Represent the failed process result. When the worker process fails, it may record failure root cause into the file. + + Tries to read the failure timestamp from the provided ``error_file``, + if the ``error_file`` does not exist, the timestamp is the current + timestamp (seconds since epoch). + + The ``message`` field is a concise explanation of the failure. If + the error file exists then the message is obtained from the error file. + Otherwise one is generated based on the failure signature. + + .. note:: It is assumed that the ``error_file`` is written by + ``core.distributed.elastic.multiprocessing.errors.error_handler.ErrorHandler``. + Otherwise the behavior is undefined. + + """ + + local_rank: int + pid: int + exitcode: int + error_file: str + error_file_data: JSON = field(init=False) + message: str = field(init=False) + timestamp: int = field(init=False) + + def __post_init__(self): + self.error_file_data = _EMPTY_ERROR_DATA + if os.path.isfile(self.error_file): + try: + with open(self.error_file) as fp: + self.error_file_data = json.load(fp) + logger.debug( + "User process failed with error data: %s", + json.dumps(self.error_file_data, indent=2), + ) + self.message, self.timestamp = self._get_error_data( + self.error_file_data + ) + except Exception: + logger.exception("Failed to parse reply file: %s", self.error_file) + raise + else: + self._set_no_reply_file() + + # make up an informative message if not already present + if not self.message: + # signals typically do not generate an error file message + if self.exitcode < 0: + self.message = ( + f"Signal {-self.exitcode} ({self.signal_name()})" + f" received by PID {self.pid}" + ) + else: + self.message = "To enable traceback see: https://pycore.org/docs/stable/elastic/errors.html" + + def _get_error_data(self, error_file_data: Dict[str, Any]) -> Tuple[str, int]: + message = error_file_data["message"] + if isinstance(message, str): + timestamp = int(error_file_data.get("timestamp", 0)) + else: + timestamp = int(message["extraInfo"]["timestamp"]) + return (message, timestamp) + + def _set_no_reply_file(self): + self.error_file = _NOT_AVAILABLE + self.error_file_data = _EMPTY_ERROR_DATA + self.message = "" + self.timestamp = int(time.time()) + + def signal_name(self) -> str: + if self.exitcode < 0: + # We don't want to kill the parent process trying to find the signal name. + # if the signal doesn't map to a known name, use not available. + try: + return signal.Signals(-self.exitcode).name + except Exception: + return _NOT_AVAILABLE + else: + return _NOT_AVAILABLE + + def timestamp_isoformat(self): + """Return timestamp in ISO format (YYYY-MM-DD_HH:MM:SS).""" + return datetime.fromtimestamp(self.timestamp).isoformat(sep="_") + + +GlobalRank = int + +_FAILURE_FORMAT_TEMPLATE = """[${idx}]: + time : ${time} + host : ${hostname} + rank : ${rank} (local_rank: ${local_rank}) + exitcode : ${exitcode} (pid: ${pid}) + error_file: ${error_file} + traceback : ${message}""" + +# extra new lines before and after are intentional +_MSG_FORMAT_TEMPLATE = """ +${boarder} +${title} +${section} +Failures: +${other_failures} +${section} +Root Cause (first observed failure): +${root_failure} +${boarder}""" + + +class ChildFailedError(Exception): + """ + Special exception type that can be raised from a function annotated with the + ``@record`` decorator to have the child process' (root exception) propagate + up the stack as-is (e.g. without being wrapped in the parent's traceback). + + Useful in cases where the parent is a simple nanny process + and the child (worker) processes are actually doing meaningful compute. + In this case, errors typically occur on the child process as the parent + is not doing anything non-trivial, and child errors should be propagated + to the scheduler for accurate root cause diagnostics. + + .. note:: The propagation relies on error files rather than exception handling to + support both function and binary launches. + + Example: + :: + + # process tree on a host (container) + 0: scheduler-init-process: + |- 1: torchelastic_agent: + |- 2: trainer_0 (ok) + |- 3: trainer_1 (fail) -> error.json + |- ... + |- n+2: trainer_n (ok) + |- n+3: other processes + |- ... + + In the example above, trainer 1's failure (written into error.json) is + the root cause and should be reported to the scheduler's init process. + The torchelastic agent raises a ``ChildFailedError("trainer", {1: "trainer_1/error.json"})`` + upon detecting trainer 1's failure which would propagate the contents + of trainer 1's error file to the scheduler's init process. + """ + + def __init__(self, name: str, failures: Dict[GlobalRank, ProcessFailure]): + self.name = name + self.failures = failures + assert ( + self.failures + ) # does not make sense to create a ChildFaileError with no failures + super().__init__(self.format_msg()) + + def get_first_failure(self) -> Tuple[GlobalRank, ProcessFailure]: + rank = min(self.failures.keys(), key=lambda r: self.failures[r].timestamp) + return rank, self.failures[rank] + + def format_msg(self, boarder_delim="=", section_delim="-"): + title = f"{self.name} FAILED" + root_rank, _root_failure = self.get_first_failure() + + root_failure_fmt: str = "" + other_failures_fmt: List[str] = [] + width = len(title) + for idx, (rank, failure) in enumerate(self.failures.items()): + fmt, w = self._format_failure(idx, rank, failure) + width = max(width, w) + if rank == root_rank: + root_failure_fmt = fmt + else: + other_failures_fmt.append(fmt) + + # upper boundary on width + width = min(width, 60) + + return Template(_MSG_FORMAT_TEMPLATE).substitute( + boarder=boarder_delim * width, + title=title, + section=section_delim * width, + root_failure=root_failure_fmt, + other_failures="\n".join(other_failures_fmt or [" "]), + ) + + def _format_failure( + self, idx: int, rank: int, failure: ProcessFailure + ) -> Tuple[str, int]: + # failure.message is either a str (when the failure does not generate a traceback - e.g. signals) + # or a dict (json) of the form + # {"message": $ERROR_MSG, "extraInfo": {"py_callstack": $TRACEBACK, timestamp: $TS}} + # so the display logic is: + # 1. if failure.message is not a dict (it is a str) just show it as is + # 2. else try to get the traceback (py_callstack) + # 3. if the traceback is not there, use the message + # 4. if the message is not there show + msg = failure.message + if isinstance(failure.message, dict): + msg = ( + failure.message.get("extraInfo", {}) + .get("py_callstack", failure.message.get("message", "")) + .replace("\n", "\n ") # to properly indent the traceback + ) + + fmt = Template(_FAILURE_FORMAT_TEMPLATE).substitute( + idx=idx, + time=failure.timestamp_isoformat(), + hostname=socket.getfqdn(), + rank=rank, + local_rank=failure.local_rank, + exitcode=failure.exitcode, + pid=failure.pid, + error_file=failure.error_file, + message=msg, + ) + width = 0 + for line in fmt.split("\n"): + width = max(width, len(line)) + return fmt, width + + +def record( + fn: Callable[..., T], error_handler: Optional[ErrorHandler] = None +) -> Callable[..., T]: + """ + Syntactic sugar to record errors/exceptions that happened in the decorated + function using the provided ``error_handler``. + + Using this decorator is equivalent to: + + :: + + error_handler = get_error_handler() + error_handler.initialize() + try: + foobar() + except ChildFailedError as e: + _, failure = e.get_first_failure() + error_handler.dump_error_file(failure.error_file, failure.exitcode) + raise + except Exception as e: + error_handler.record(e) + raise + + .. important:: use this decorator once per process at the top level method, + typically this is the main method. + + Example + + :: + + @record + def main(): + pass + + if __name__=="__main__": + main() + + """ + if not error_handler: + error_handler = get_error_handler() + + def wrap(f): + @wraps(f) + def wrapper(*args, **kwargs): + assert error_handler is not None # assertion for mypy type checker + error_handler.initialize() + try: + return f(*args, **kwargs) + except SystemExit as se: + # For run_path based entrypoints, SystemExit with code = 0 will never exit. + # Handling it here by returning a value: + if se.code == 0: + return None + else: + raise + except ChildFailedError as e: + rank, failure = e.get_first_failure() + if failure.error_file != _NOT_AVAILABLE: + error_handler.dump_error_file(failure.error_file, failure.exitcode) + else: + logger.info( + ( + "local_rank %s FAILED with no error file." + " Decorate your entrypoint fn with @record for traceback info." + " See: https://pycore.org/docs/stable/elastic/errors.html", + rank, + ) + ) + raise + except Exception as e: + error_handler.record_exception(e) + raise + + return wrapper + + return wrap(fn) diff --git a/mindnlp/core/distributed/elastic/multiprocessing/errors/error_handler.py b/mindnlp/core/distributed/elastic/multiprocessing/errors/error_handler.py new file mode 100644 index 000000000..f0fb72ac9 --- /dev/null +++ b/mindnlp/core/distributed/elastic/multiprocessing/errors/error_handler.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import faulthandler +import json +import logging +import os +import time +import traceback +import warnings +from typing import Any, Dict, Optional + + +__all__ = ["ErrorHandler"] + +logger = logging.getLogger(__name__) + + +class ErrorHandler: + """ + Write the provided exception object along with some other metadata about + the error in a structured way in JSON format to an error file specified by the + environment variable: ``TORCHELASTIC_ERROR_FILE``. If this environment + variable is not set, then simply logs the contents of what would have been + written to the error file. + + This handler may be subclassed to customize the handling of the error. + Subclasses should override ``initialize()`` and ``record_exception()``. + """ + + def _get_error_file_path(self) -> Optional[str]: + """ + Return the error file path. + + May return ``None`` to have the structured error be logged only. + """ + return os.environ.get("TORCHELASTIC_ERROR_FILE", None) + + def initialize(self) -> None: + """ + Call prior to running code that we wish to capture errors/exceptions. + + Typically registers signal/fault handlers. Users can override this + function to add custom initialization/registrations that aid in + propagation/information of errors/signals/exceptions/faults. + """ + try: + faulthandler.enable(all_threads=True) + except Exception as e: + warnings.warn(f"Unable to enable fault handler. {type(e).__name__}: {e}") + + def _write_error_file(self, file_path: str, error_msg: str) -> None: + """Write error message to the file.""" + try: + with open(file_path, "w") as fp: + fp.write(error_msg) + except Exception as e: + warnings.warn(f"Unable to write error to file. {type(e).__name__}: {e}") + + def record_exception(self, e: BaseException) -> None: + """ + Write a structured information about the exception into an error file in JSON format. + + If the error file cannot be determined, then logs the content + that would have been written to the error file. + """ + file = self._get_error_file_path() + if file: + data = { + "message": { + "message": f"{type(e).__name__}: {e}", + "extraInfo": { + "py_callstack": traceback.format_exc(), + "timestamp": str(int(time.time())), + }, + } + } + with open(file, "w") as fp: + json.dump(data, fp) + + def override_error_code_in_rootcause_data( + self, + rootcause_error_file: str, + rootcause_error: Dict[str, Any], + error_code: int = 0, + ): + """Modify the rootcause_error read from the file, to correctly set the exit code.""" + if "message" not in rootcause_error: + logger.warning( + "child error file (%s) does not have field `message`. \n" + "cannot override error code: %s", + rootcause_error_file, + error_code, + ) + elif isinstance(rootcause_error["message"], str): + logger.warning( + "child error file (%s) has a new message format. \n" + "skipping error code override", + rootcause_error_file, + ) + else: + rootcause_error["message"]["errorCode"] = error_code + + def dump_error_file(self, rootcause_error_file: str, error_code: int = 0): + """Dump parent error file from child process's root cause error and error code.""" + with open(rootcause_error_file) as fp: + rootcause_error = json.load(fp) + # Override error code since the child process cannot capture the error code if it + # is terminated by signals like SIGSEGV. + if error_code: + self.override_error_code_in_rootcause_data( + rootcause_error_file, rootcause_error, error_code + ) + logger.debug( + "child error file (%s) contents:\n" "%s", + rootcause_error_file, + json.dumps(rootcause_error, indent=2), + ) + + my_error_file = self._get_error_file_path() + if my_error_file: + # Guard against existing error files + # This can happen when the child is created using multiprocessing + # and the same env var (TORCHELASTIC_ERROR_FILE) is used on the + # parent and child to specify the error files (respectively) + # because the env vars on the child is set in the wrapper function + # and by default the child inherits the parent's env vars, if the child + # process receives a signal before the wrapper function kicks in + # and the signal handler writes to the error file, then the child + # will write to the parent's error file. In this case just log the + # original error file contents and overwrite the error file. + self._rm(my_error_file) + self._write_error_file(my_error_file, json.dumps(rootcause_error)) + logger.info("dumped error file to parent's %s", my_error_file) + else: + logger.error( + "no error file defined for parent, to copy child error file (%s)", + rootcause_error_file, + ) + + def _rm(self, my_error_file): + if os.path.isfile(my_error_file): + # Log the contents of the original file. + with open(my_error_file) as fp: + try: + original = json.dumps(json.load(fp), indent=2) + logger.warning( + "%s already exists" + " and will be overwritten." + " Original contents:\n%s", + my_error_file, + original, + ) + except json.decoder.JSONDecodeError: + logger.warning( + "%s already exists" + " and will be overwritten." + " Unable to load original contents:\n", + my_error_file, + ) + os.remove(my_error_file) diff --git a/mindnlp/core/distributed/elastic/multiprocessing/errors/handlers.py b/mindnlp/core/distributed/elastic/multiprocessing/errors/handlers.py new file mode 100644 index 000000000..688fcca8a --- /dev/null +++ b/mindnlp/core/distributed/elastic/multiprocessing/errors/handlers.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# Multiprocessing error-reporting module + + +from core.distributed.elastic.multiprocessing.errors.error_handler import ErrorHandler + + +__all__ = ["get_error_handler"] + + +def get_error_handler(): + return ErrorHandler() diff --git a/mindnlp/core/distributed/elastic/multiprocessing/redirects.py b/mindnlp/core/distributed/elastic/multiprocessing/redirects.py new file mode 100644 index 000000000..4553fbebd --- /dev/null +++ b/mindnlp/core/distributed/elastic/multiprocessing/redirects.py @@ -0,0 +1,104 @@ +# mypy: allow-untyped-defs +# !/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Taken and modified from original source: +# https://eli.thegreenplace.net/2015/redirecting-all-kinds-of-stdout-in-python/ +import ctypes +import logging +import os +import sys +from contextlib import contextmanager +from functools import partial + + +IS_WINDOWS = sys.platform == "win32" +IS_MACOS = sys.platform == "darwin" + + +logger = logging.getLogger(__name__) + + +def get_libc(): + if IS_WINDOWS or IS_MACOS: + logger.warning( + "NOTE: Redirects are currently not supported in Windows or MacOs." + ) + return None + else: + return ctypes.CDLL("libc.so.6") + + +libc = get_libc() + + +def _c_std(stream: str): + return ctypes.c_void_p.in_dll(libc, stream) + + +def _python_std(stream: str): + return {"stdout": sys.stdout, "stderr": sys.stderr}[stream] + + +_VALID_STD = {"stdout", "stderr"} + + +@contextmanager +def redirect(std: str, to_file: str): + """ + Redirect ``std`` (one of ``"stdout"`` or ``"stderr"``) to a file in the path specified by ``to_file``. + + This method redirects the underlying std file descriptor (not just python's ``sys.stdout|stderr``). + See usage for details. + + Directory of ``dst_filename`` is assumed to exist and the destination file + is overwritten if it already exists. + + .. note:: Due to buffering cross source writes are not guaranteed to + appear in wall-clock order. For instance in the example below + it is possible for the C-outputs to appear before the python + outputs in the log file. + + Usage: + + :: + + # syntactic-sugar for redirect("stdout", "tmp/stdout.log") + with redirect_stdout("/tmp/stdout.log"): + print("python stdouts are redirected") + libc = ctypes.CDLL("libc.so.6") + libc.printf(b"c stdouts are also redirected" + os.system("echo system stdouts are also redirected") + + print("stdout restored") + + """ + if std not in _VALID_STD: + raise ValueError( + f"unknown standard stream <{std}>, must be one of {_VALID_STD}" + ) + + c_std = _c_std(std) + python_std = _python_std(std) + std_fd = python_std.fileno() + + def _redirect(dst): + libc.fflush(c_std) + python_std.flush() + os.dup2(dst.fileno(), std_fd) + + with os.fdopen(os.dup(std_fd)) as orig_std, open(to_file, mode="w+b") as dst: + _redirect(dst) + try: + yield + finally: + _redirect(orig_std) + + +redirect_stdout = partial(redirect, "stdout") +redirect_stderr = partial(redirect, "stderr") diff --git a/mindnlp/core/distributed/elastic/multiprocessing/subprocess_handler/__init__.py b/mindnlp/core/distributed/elastic/multiprocessing/subprocess_handler/__init__.py new file mode 100644 index 000000000..c0a4b06ea --- /dev/null +++ b/mindnlp/core/distributed/elastic/multiprocessing/subprocess_handler/__init__.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from core.distributed.elastic.multiprocessing.subprocess_handler.handlers import ( + get_subprocess_handler, +) +from core.distributed.elastic.multiprocessing.subprocess_handler.subprocess_handler import ( + SubprocessHandler, +) + + +__all__ = ["SubprocessHandler", "get_subprocess_handler"] diff --git a/mindnlp/core/distributed/elastic/multiprocessing/subprocess_handler/handlers.py b/mindnlp/core/distributed/elastic/multiprocessing/subprocess_handler/handlers.py new file mode 100644 index 000000000..6cc073c8a --- /dev/null +++ b/mindnlp/core/distributed/elastic/multiprocessing/subprocess_handler/handlers.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import Dict, Tuple + +from core.distributed.elastic.multiprocessing.subprocess_handler.subprocess_handler import ( + SubprocessHandler, +) + + +__all__ = ["get_subprocess_handler"] + + +def get_subprocess_handler( + entrypoint: str, + args: Tuple, + env: Dict[str, str], + stdout: str, + stderr: str, + local_rank_id: int, +): + return SubprocessHandler( + entrypoint=entrypoint, + args=args, + env=env, + stdout=stdout, + stderr=stderr, + local_rank_id=local_rank_id, + ) diff --git a/mindnlp/core/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py b/mindnlp/core/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py new file mode 100644 index 000000000..a00905af4 --- /dev/null +++ b/mindnlp/core/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import os +import signal +import subprocess +import sys +from typing import Any, Dict, Optional, Tuple + + +__all__ = ["SubprocessHandler"] + +IS_WINDOWS = sys.platform == "win32" + + +def _get_default_signal() -> signal.Signals: + """Get the default termination signal. SIGTERM for unix, CTRL_C_EVENT for windows.""" + if IS_WINDOWS: + return signal.CTRL_C_EVENT # type: ignore[attr-defined] # noqa: F821 + else: + return signal.SIGTERM + + +class SubprocessHandler: + """ + Convenience wrapper around python's ``subprocess.Popen``. Keeps track of + meta-objects associated to the process (e.g. stdout and stderr redirect fds). + """ + + def __init__( + self, + entrypoint: str, + args: Tuple, + env: Dict[str, str], + stdout: Optional[str], + stderr: Optional[str], + local_rank_id: int, + ): + self._stdout = open(stdout, "w") if stdout else None + self._stderr = open(stderr, "w") if stderr else None + # inherit parent environment vars + env_vars = os.environ.copy() + env_vars.update(env) + + args_str = (entrypoint, *[str(e) for e in args]) + self.local_rank_id = local_rank_id + self.proc: subprocess.Popen = self._popen(args_str, env_vars) + + def _popen(self, args: Tuple, env: Dict[str, str]) -> subprocess.Popen: + kwargs: Dict[str, Any] = {} + if not IS_WINDOWS: + kwargs["start_new_session"] = True + return subprocess.Popen( + # pyre-fixme[6]: Expected `Union[typing.Sequence[Union[_PathLike[bytes], + # _PathLike[str], bytes, str]], bytes, str]` for 1st param but got + # `Tuple[str, *Tuple[Any, ...]]`. + args=args, + env=env, + stdout=self._stdout, + stderr=self._stderr, + **kwargs, + ) + + def close(self, death_sig: Optional[signal.Signals] = None) -> None: + if not death_sig: + death_sig = _get_default_signal() + if IS_WINDOWS: + self.proc.send_signal(death_sig) + else: + os.killpg(self.proc.pid, death_sig) + if self._stdout: + self._stdout.close() + if self._stderr: + self._stderr.close() diff --git a/mindnlp/core/distributed/elastic/multiprocessing/tail_log.py b/mindnlp/core/distributed/elastic/multiprocessing/tail_log.py new file mode 100644 index 000000000..9d4e649c3 --- /dev/null +++ b/mindnlp/core/distributed/elastic/multiprocessing/tail_log.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import time +from concurrent.futures.thread import ThreadPoolExecutor +from threading import Event +from typing import Dict, List, Optional, TextIO, TYPE_CHECKING + + +if TYPE_CHECKING: + from concurrent.futures._base import Future + +__all__ = ["tail_logfile", "TailLog"] + +logger = logging.getLogger(__name__) + + +def tail_logfile( + header: str, file: str, dst: TextIO, finished: Event, interval_sec: float +): + while not os.path.exists(file): + if finished.is_set(): + return + time.sleep(interval_sec) + + with open(file, errors="replace") as fp: + while True: + line = fp.readline() + + if line: + dst.write(f"{header}{line}") + else: # reached EOF + if finished.is_set(): + # log line producer is finished + break + else: + # log line producer is still going + # wait for a bit before looping again + time.sleep(interval_sec) + + +class TailLog: + """ + Tail the given log files. + + The log files do not have to exist when the ``start()`` method is called. The tail-er will gracefully wait until + the log files are created by the producer and will tail the contents of the + log files until the ``stop()`` method is called. + + .. warning:: ``TailLog`` will wait indefinitely for the log file to be created! + + Each log file's line will be suffixed with a header of the form: ``[{name}{idx}]:``, + where the ``name`` is user-provided and ``idx`` is the index of the log file + in the ``log_files`` mapping. ``log_line_prefixes`` can be used to override the + header for each log file. + + Usage: + + :: + + log_files = {0: "/tmp/0_stdout.log", 1: "/tmp/1_stdout.log"} + tailer = TailLog("trainer", log_files, sys.stdout).start() + # actually run the trainers to produce 0_stdout.log and 1_stdout.log + run_trainers() + tailer.stop() + + # once run_trainers() start writing the ##_stdout.log files + # the tailer will print to sys.stdout: + # >>> [trainer0]:log_line1 + # >>> [trainer1]:log_line1 + # >>> [trainer0]:log_line2 + # >>> [trainer0]:log_line3 + # >>> [trainer1]:log_line2 + + .. note:: Due to buffering log lines between files may not necessarily + be printed out in order. You should configure your application's + logger to suffix each log line with a proper timestamp. + + """ + + def __init__( + self, + name: str, + log_files: Dict[int, str], + dst: TextIO, + log_line_prefixes: Optional[Dict[int, str]] = None, + interval_sec: float = 0.1, + ): + n = len(log_files) + self._threadpool = None + if n > 0: + self._threadpool = ThreadPoolExecutor( + max_workers=n, + thread_name_prefix=f"{self.__class__.__qualname__}_{name}", + ) + + self._name = name + self._dst = dst + self._log_files = log_files + self._log_line_prefixes = log_line_prefixes + self._finished_events: Dict[int, Event] = { + local_rank: Event() for local_rank in log_files.keys() + } + self._futs: List[Future] = [] + self._interval_sec = interval_sec + self._stopped = False + + def start(self) -> "TailLog": + if not self._threadpool: + return self + + for local_rank, file in self._log_files.items(): + header = f"[{self._name}{local_rank}]:" + if self._log_line_prefixes and local_rank in self._log_line_prefixes: + header = self._log_line_prefixes[local_rank] + self._futs.append( + self._threadpool.submit( + tail_logfile, + header=header, + file=file, + dst=self._dst, + finished=self._finished_events[local_rank], + interval_sec=self._interval_sec, + ) + ) + return self + + def stop(self) -> None: + for finished in self._finished_events.values(): + finished.set() + + for local_rank, f in enumerate(self._futs): + try: + f.result() + except Exception as e: + logger.error( + "error in log tailor for %s%s. %s: %s", + self._name, + local_rank, + e.__class__.__qualname__, + e, + ) + + if self._threadpool: + self._threadpool.shutdown(wait=True) + + self._stopped = True + + def stopped(self) -> bool: + return self._stopped diff --git a/mindnlp/core/distributed/elastic/rendezvous/__init__.py b/mindnlp/core/distributed/elastic/rendezvous/__init__.py new file mode 100644 index 000000000..b4b62cbb0 --- /dev/null +++ b/mindnlp/core/distributed/elastic/rendezvous/__init__.py @@ -0,0 +1,167 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +In the context of Torch Distributed Elastic we use the term *rendezvous* to +refer to a particular functionality that combines a **distributed +synchronization** primitive with **peer discovery**. + +It is used by Torch Distributed Elastic to gather participants of a training +job (i.e. nodes) such that they all agree on the same list of participants and +everyone's roles, as well as make a consistent collective decision on when +training can begin/resume. + +Torch Distributed Elastic rendezvous provides the following critical +functionalities: + +**Barrier**: + +Nodes performing rendezvous will all block until the rendezvous is considered +complete - this happens when at least ``min`` total number of nodes have joined +the rendezvous barrier (for the same job). This also implies the barrier is not +necessarily of fixed size. + +There's an additional small waiting time after reaching ``min`` number of +nodes - this is used to ensure the rendezvous is not completed "too quickly" +(which could potentially exclude additional nodes attempting to join at +approximately the same time). + +If ``max`` number of nodes is gathered at the barrier, the rendezvous is +completed immediately. + +There's also an overall timeout which causes the rendezvous to fail if ``min`` +number of nodes is never reached - this is meant to be a simple fail-safe to +help release partially allocated job resources, in case there's a problem with +the resource manager, and is meant to be interpreted as non-retryable. + +**Exclusivity**: + +A simple distributed barrier would not be sufficient, as we also need to ensure +that only one group of nodes exists at any given time (for a given job). In +other words, new nodes (i.e. joining late) should not be able to form a parallel +independent group of workers for the same job. + +Torch Distributed Elastic rendezvous ensures that if a group of nodes has +already completed a rendezvous (and hence might already be training), then +additional "late" nodes attempting to rendezvous will only announce themselves +as waiting, and will have to wait until the (previously completed) existing +rendezvous is destroyed first. + +**Consistency**: + +When a rendezvous is completed, all its members will agree on the job membership +and everyone's role in it. This role is represented using an integer, called +rank, that is between between 0 and world size. + +Note that ranks are *not stable*, in the sense that the same node can be +assigned a different rank in the next (re-)rendezvous. + +**Fault-tolerance**: + +Torch Distributed Elastic rendezvous is designed to tolerate node failures +during the rendezvous process. Should a process crash (or lose network +connectivity, etc), between joining the rendezvous and it being completed, then +a re-rendezvous with remaining healthy nodes will happen automatically. + +A node can also fail *after* it has completed (or *has been observered* by other +nodes to have completed) the rendezvous - this scenario will be handled by the +Torch Distributed Elastic ``train_loop`` instead (where it will also trigger a +re-rendezvous). + +**Shared key-value store**: + +When the rendezvous is completed, a shared key-value store is created and +returned. This store implements a ``core.distributed.Store`` API (see +`distributed communication docs +`__). + +This store is only shared by the members of the completed rendezvous. It +is intended to be used by Torch Distributed Elastic to exchange information +necessary to initialize job control and data-planes. + +**Waiting workers and rendezvous closing**: + +Torch Distributed Elastic rendezvous handler object provides additional +functionalities, which are technically not part of the rendezvous process: + +1. Querying how many workers arrived late at the barrier, who can participate in + *next* rendezvous. + +2. Setting the rendezvous *closed* to signal all nodes not to participate in + next rendezvous. + +**DynamicRendezvousHandler**: + +Torch Distributed Elastic comes with the :py:class:`.DynamicRendezvousHandler` +class that implements the rendezvous mechanism described above. It is a backend- +agnostic type that expects a particular :py:class:`.RendezvousBackend` instance +to be specified during construction. + +Torch distributed users can either implement their own backend type or use one +of the following implementations that come with PyTorch: + +- :py:class:`.C10dRendezvousBackend`: Uses a C10d store (by default + ``TCPStore``) as the rendezvous backend. The main advantage of using a C10d + store is that it requires no 3rd-party dependency (such as etcd) to establish + a rendezvous. +- :py:class:`.EtcdRendezvousBackend`: Supersedes the legacy + :py:class:`.EtcdRendezvousHandler` class. Passing an + :py:class:`.EtcdRendezvousBackend` instance to + :py:class:`.DynamicRendezvousHandler` is functionally equivalent to + instantiating an :py:class:`.EtcdRendezvousHandler`. + + :: + + store = TCPStore("localhost") + + backend = C10dRendezvousBackend(store, "my_run_id") + + rdzv_handler = DynamicRendezvousHandler.from_backend( + run_id="my_run_id", + store=store, + backend=backend, + min_nodes=2, + max_nodes=4 + ) +""" + +from .api import ( + rendezvous_handler_registry, + RendezvousClosedError, + RendezvousConnectionError, + RendezvousError, + RendezvousGracefulExitError, + RendezvousHandler, + RendezvousHandlerCreator, + RendezvousHandlerRegistry, + RendezvousInfo, + RendezvousParameters, + RendezvousStateError, + RendezvousStoreInfo, + RendezvousTimeoutError, +) +from .registry import _register_default_handlers, _register_out_of_tree_handlers + + +_register_default_handlers() +_register_out_of_tree_handlers() + + +__all__ = [ + "RendezvousClosedError", + "RendezvousConnectionError", + "RendezvousError", + "RendezvousGracefulExitError", + "RendezvousHandler", + "RendezvousHandlerCreator", + "RendezvousHandlerRegistry", + "RendezvousInfo", + "RendezvousParameters", + "RendezvousStateError", + "RendezvousStoreInfo", + "RendezvousTimeoutError", + "rendezvous_handler_registry", +] diff --git a/mindnlp/core/distributed/elastic/rendezvous/api.py b/mindnlp/core/distributed/elastic/rendezvous/api.py new file mode 100644 index 000000000..e94042168 --- /dev/null +++ b/mindnlp/core/distributed/elastic/rendezvous/api.py @@ -0,0 +1,384 @@ +# mypy: allow-untyped-defs +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import socket +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Callable, ClassVar, Dict, Optional + +from core.distributed import Store +from core.distributed.elastic.utils.distributed import get_free_port + + +__all__ = [ + "RendezvousClosedError", + "RendezvousConnectionError", + "RendezvousError", + "RendezvousGracefulExitError", + "RendezvousHandler", + "RendezvousHandlerCreator", + "RendezvousHandlerRegistry", + "RendezvousInfo", + "RendezvousParameters", + "RendezvousStateError", + "RendezvousStoreInfo", + "RendezvousTimeoutError", + "rendezvous_handler_registry", +] + + +class RendezvousError(Exception): + """Represents the base type for rendezvous errors.""" + + +class RendezvousClosedError(RendezvousError): + """Raised when a rendezvous is closed.""" + + +class RendezvousTimeoutError(RendezvousError): + """Raised when a rendezvous did not complete on time.""" + + +class RendezvousConnectionError(RendezvousError): + """Raised when the connection to a rendezvous backend has failed.""" + + +class RendezvousStateError(RendezvousError): + """Raised when the state of a rendezvous is corrupt.""" + + +class RendezvousGracefulExitError(RendezvousError): + """Raised when node wasn't not included in rendezvous and gracefully exits. + + Exception is a mechanism to exit the stack, however does not mean a failure. + """ + + +@dataclass +class RendezvousStoreInfo: + """Store address and port that can be used to bootstrap trainer distributed comms""" + + MASTER_ADDR_KEY: ClassVar[str] = "MASTER_ADDR" + MASTER_PORT_KEY: ClassVar[str] = "MASTER_PORT" + master_addr: str + master_port: int + + @staticmethod + def build( + rank: int, + store: Store, + local_addr: Optional[str], + server_port: Optional[int] = None, + ) -> "RendezvousStoreInfo": + """Factory method, finds unused new port on rank0 host and addr/port info with all ranks. + + If master_addr/master_port is knowns (useful when sharing existing tcp store server) use the constructor. + + Args: + rank: rank of the current node + store: store to use for rendezvous + local_addr: address of the current node, if not provided will be resolved from hostname + server_port: port of the TCPStore server, when the TCPStore is shared. + """ + # TODO swap to collectives comms API + if rank == 0: + addr = local_addr or socket.getfqdn() + # When TCPStore is not shared, we fallback to get_free_port. + port = server_port or get_free_port() + store.set(RendezvousStoreInfo.MASTER_ADDR_KEY, addr.encode(encoding="UTF-8")) # type: ignore[arg-type] + store.set(RendezvousStoreInfo.MASTER_PORT_KEY, str(port).encode(encoding="UTF-8")) # type: ignore[arg-type] + + addr = store.get(RendezvousStoreInfo.MASTER_ADDR_KEY).decode(encoding="UTF-8") + port = int( + store.get(RendezvousStoreInfo.MASTER_PORT_KEY).decode(encoding="UTF-8") + ) + return RendezvousStoreInfo(master_addr=addr, master_port=port) + + +class RendezvousInfo: + """Holds the information about the rendezvous.""" + + def __init__( + self, + store: Store, + rank: int, + world_size: int, + bootstrap_store_info: RendezvousStoreInfo, + ): + self._store = store + self._rank = rank + self._world_size = world_size + self._bootstrap_store_info = bootstrap_store_info + + @property + def store(self) -> Store: + """Store used by torchelastic control plane""" + return self._store + + @property + def rank(self) -> int: + """Rank within a group""" + return self._rank + + @property + def world_size(self) -> int: + """Global group size""" + return self._world_size + + @property + def bootstrap_store_info(self) -> Optional[RendezvousStoreInfo]: + """Store information that can used by trainer code to bootstrap distributed comms.""" + return self._bootstrap_store_info + + +class RendezvousHandler(ABC): + """Main rendezvous interface. + + Note: + Distributed Torch users normally **do not** need to implement their own + ``RendezvousHandler``. An implementation based on C10d Store is already + provided, and is recommended for most users. + """ + + @abstractmethod + def get_backend(self) -> str: + """Return the name of the rendezvous backend.""" + + @property + def use_agent_store(self) -> bool: + """Indicates that store reference returned by :py:meth:`next_rendezvous` can be shared with user + applications and will be available during application lifecyle. + + Rendezous handler impl will share store details as instance of :py:class:`RendezvousStoreInfo`. + Applications as a convention use `MASTER_ADDR`/`MASTER_PORT` env variables to lookup the store. + """ + return False + + @abstractmethod + def next_rendezvous(self) -> RendezvousInfo: + """Main entry-point into the rendezvous barrier. + + Blocks until the rendezvous is complete and the current process is + included in the formed worker group, or a timeout occurs, or the + rendezvous was marked closed. + + Returns: + Instance of :py:class:`RendezvousInfo`. + + Raises: + RendezvousClosedError: + The rendezvous is closed. + RendezvousConnectionError: + The connection to the rendezvous backend has failed. + RendezvousStateError: + The rendezvous state is corrupt. + RendezvousTimeoutError: + The rendezvous did not complete on time. + """ + + @abstractmethod + def is_closed(self) -> bool: + """Check whether the rendezvous has been closed. + + A closed rendezvous means all future attempts to re-rendezvous within + same job will fail. + + ``is_closed()`` and :py:meth:`set_closed` have semantics of eventual + propagation and should not be used for synchronization. The intention is + that if at least one node decides the job is finished, it will close the + rendezvous, and other nodes will soon observe this and stop running as + well. + """ + + @abstractmethod + def set_closed(self): + """Mark the rendezvous as closed.""" + + @abstractmethod + def num_nodes_waiting(self) -> int: + """Return the number of nodes who arrived late at the rendezvous + barrier, hence were not included in the current worker group. + + Callers should periodically call this method to check whether new + nodes are waiting to join the job and if so admit them by calling + :py:meth:`next_rendezvous()` (re-rendezvous). + """ + + @abstractmethod + def get_run_id(self) -> str: + """Return the run id of the rendezvous. + + The run id is a user-defined id that uniquely identifies an instance of + a distributed application. It typically maps to a job id and is used to + allow nodes to join the correct distributed application. + """ + + @abstractmethod + def shutdown(self) -> bool: + """Close all resources that were open for the rendezvous. + + Example:: + + rdzv_handler = ... + try: + store, rank, world_size = rdzv_handler.next_rendezvous() + finally: + rdzv_handler.shutdown() + """ + + +class RendezvousParameters: + """Hold the parameters to construct a :py:class:`RendezvousHandler`. + + Args: + backend: + The name of the backend to use to handle the rendezvous. + endpoint: + The endpoint of the rendezvous, usually in form [:]. + run_id: + The id of the rendezvous. + min_nodes: + The minimum number of nodes to admit to the rendezvous. + max_nodes: + The maximum number of nodes to admit to the rendezvous. + local_addr: + The address of the local node. + **kwargs: + Additional parameters for the specified backend. + """ + + def __init__( + self, + backend: str, + endpoint: str, + run_id: str, + min_nodes: int, + max_nodes: int, + local_addr: Optional[str] = None, + **kwargs, + ): + if not backend: + raise ValueError("The rendezvous backend name must be a non-empty string.") + + if min_nodes < 1: + raise ValueError( + f"The minimum number of rendezvous nodes ({min_nodes}) must be greater than zero." + ) + if max_nodes < min_nodes: + raise ValueError( + f"The maximum number of rendezvous nodes ({max_nodes}) must be greater than or " + f"equal to the minimum number of rendezvous nodes ({min_nodes})." + ) + + self.backend = backend + self.endpoint = endpoint + self.run_id = run_id + self.min_nodes = min_nodes + self.max_nodes = max_nodes + self.config = kwargs + self.local_addr = local_addr + + def get(self, key: str, default: Any = None) -> Any: + """Return the value for ``key`` if ``key`` exists, else ``default``.""" + return self.config.get(key, default) + + def get_as_bool(self, key: str, default: Optional[bool] = None) -> Optional[bool]: + """Return the value for ``key`` as a ``bool``.""" + value = self.get(key, default) + if value is None or isinstance(value, bool): + return value + if isinstance(value, int): + if value == 1: + return True + if value == 0: + return False + elif isinstance(value, str): + if value.lower() in ["1", "true", "t", "yes", "y"]: + return True + if value.lower() in ["0", "false", "f", "no", "n"]: + return False + raise ValueError( + f"The rendezvous configuration option '{key}' does not represent a valid boolean value." + ) + + def get_as_int(self, key: str, default: Optional[int] = None) -> Optional[int]: + """Return the value for ``key`` as an ``int``.""" + value = self.get(key, default) + if value is None: + return value + try: + return int(value) + except ValueError as e: + raise ValueError( + f"The rendezvous configuration option '{key}' does not represent a valid integer " + "value." + ) from e + + +RendezvousHandlerCreator = Callable[[RendezvousParameters], RendezvousHandler] + + +class RendezvousHandlerRegistry: + """Represent a registry of :py:class:`RendezvousHandler` backends.""" + + _registry: Dict[str, RendezvousHandlerCreator] + + def __init__(self) -> None: + self._registry = {} + + def register(self, backend: str, creator: RendezvousHandlerCreator) -> None: + """Register a new rendezvous backend. + + Args: + backend: + The name of the backend. + creator: + The callback to invoke to construct the + :py:class:`RendezvousHandler`. + """ + if not backend: + raise ValueError("The rendezvous backend name must be a non-empty string.") + + current_creator: Optional[RendezvousHandlerCreator] + try: + current_creator = self._registry[backend] + except KeyError: + current_creator = None + + if current_creator is not None and current_creator != creator: + raise ValueError( + f"The rendezvous backend '{backend}' cannot be registered with '{creator}' as it " + f"is already registered with '{current_creator}'." + ) + + self._registry[backend] = creator + + def create_handler(self, params: RendezvousParameters) -> RendezvousHandler: + """Create a new :py:class:`RendezvousHandler`.""" + try: + creator = self._registry[params.backend] + except KeyError as e: + raise ValueError( + f"The rendezvous backend '{params.backend}' is not registered. Did you forget " + f"to call `{self.register.__name__}`?" + ) from e + + handler = creator(params) + + # Do some sanity check. + if handler.get_backend() != params.backend: + raise RuntimeError( + f"The rendezvous backend '{handler.get_backend()}' does not match the requested " + f"backend '{params.backend}'." + ) + + return handler + + +# The default global registry instance used by launcher scripts to instantiate +# rendezvous handlers. +rendezvous_handler_registry = RendezvousHandlerRegistry() diff --git a/mindnlp/core/distributed/elastic/rendezvous/c10d_rendezvous_backend.py b/mindnlp/core/distributed/elastic/rendezvous/c10d_rendezvous_backend.py new file mode 100644 index 000000000..aa20d2bf3 --- /dev/null +++ b/mindnlp/core/distributed/elastic/rendezvous/c10d_rendezvous_backend.py @@ -0,0 +1,273 @@ +# mypy: allow-untyped-defs +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import binascii +import logging +import os +import tempfile +from base64 import b64decode, b64encode +from datetime import timedelta +from typing import Any, cast, Optional, Tuple + +from core.distributed import FileStore, Store, TCPStore +from core.distributed.elastic.events import construct_and_record_rdzv_event, NodeState + +from .api import ( + RendezvousConnectionError, + RendezvousError, + RendezvousParameters, + RendezvousStateError, +) +from .dynamic_rendezvous import RendezvousBackend, Token +from .utils import _matches_machine_hostname, parse_rendezvous_endpoint + + +logger = logging.getLogger(__name__) + +# default port for the TCP store +DEFAULT_PORT = 29400 + + +class C10dRendezvousBackend(RendezvousBackend): + """Represents a C10d-backed rendezvous backend. + + Args: + store: + The :py:class:`core.distributed.Store` instance to use to + communicate with the C10d store. + run_id: + The run id of the rendezvous. + """ + + # See the explanation in the __init__ method. + _NULL_SENTINEL = "Y2FuaW1hZGFt" + + _store: Store + _key: str + + def __init__(self, store: Store, run_id: str) -> None: + if not run_id: + raise ValueError("The run id must be a non-empty string.") + + self._store = store + + self._key = "core.rendezvous." + run_id + + # The read operation of a store blocks the caller until the specified + # key becomes available. This behavior makes it tricky to use a store + # as a regular key-value dictionary. + # + # As a workaround we initially set a sentinel value as the rendezvous + # state. Whenever this value gets returned we treat it as a None. + self._call_store("compare_set", self._key, "", self._NULL_SENTINEL) + + @property + def name(self) -> str: + """See base class.""" + return "c10d" + + def get_state(self) -> Optional[Tuple[bytes, Token]]: + """See base class.""" + base64_state: bytes = self._call_store("get", self._key) + + return self._decode_state(base64_state) + + def set_state( + self, state: bytes, token: Optional[Token] = None + ) -> Optional[Tuple[bytes, Token, bool]]: + """See base class.""" + base64_state_str: str = b64encode(state).decode() + + if token: + # Shortcut if we know for sure that the token is not valid. + if not isinstance(token, bytes): + result = self.get_state() + if result is not None: + tmp = *result, False + # Python 3.6 does not support tuple unpacking in return + # statements. + return tmp + return None + + token = token.decode() + else: + token = self._NULL_SENTINEL + + base64_state: bytes = self._call_store( + "compare_set", self._key, token, base64_state_str + ) + + state_token_pair = self._decode_state(base64_state) + if state_token_pair is None: + return None + + new_state, new_token = state_token_pair + + # C10d Store's compare_set method does not offer an easy way to find out + # whether our write attempt was successful. As a brute-force solution we + # perform a bitwise comparison of our local state and the remote state. + return new_state, new_token, new_state == state + + def _call_store(self, store_op: str, *args, **kwargs) -> Any: + try: + return getattr(self._store, store_op)(*args, **kwargs) + except (ValueError, RuntimeError, TimeoutError) as exc: + raise RendezvousConnectionError( + "The connection to the C10d store has failed. See inner exception for details." + ) from exc + + def _decode_state(self, base64_state: bytes) -> Optional[Tuple[bytes, Token]]: + if base64_state == self._NULL_SENTINEL.encode(): + return None + + try: + state = b64decode(base64_state) + except binascii.Error as exc: + raise RendezvousStateError( + "The state object is corrupt. See inner exception for details." + ) from exc + + return state, base64_state + + +def _create_tcp_store(params: RendezvousParameters) -> TCPStore: + host, port = parse_rendezvous_endpoint(params.endpoint, default_port=DEFAULT_PORT) + + cfg_is_host = params.get_as_bool("is_host") + # If the user has explicitly specified whether our process should host the + # the store, respect it. + if cfg_is_host is not None: + is_host = cfg_is_host + # Otherwise try to determine whether we are the host based on our hostname + # and IP address. + else: + is_host = _matches_machine_hostname(host) + + # The timeout + read_timeout = cast(int, params.get_as_int("read_timeout", 60)) + if read_timeout <= 0: + raise ValueError("The read timeout must be a positive integer.") + + # In specific cases we attempt to instantiate the store twice. For details + # see the explanation in the except clause below. + for is_server in [is_host, False]: + try: + store = TCPStore( + host, + port, + is_master=is_server, + multi_tenant=True, + timeout=timedelta(seconds=read_timeout), + ) + + if is_server: + msg = f"Process {os.getpid()} hosts the TCP store for the C10d rendezvous backend." + construct_and_record_rdzv_event( + run_id=params.run_id, message=msg, node_state=NodeState.INIT + ) + logger.info(msg) + + break + except (ValueError, RuntimeError, TimeoutError) as exc: + # If we heuristically inferred the value of is_host as True and our + # first attempt to instantiate the TCP store has failed, try it one + # more time with is_host set to False. As an edge case there can be + # more than one process that is part of the same rendezvous on this + # machine and only one of them will eventually host the store. + + if not is_server or cfg_is_host is not None: + raise RendezvousConnectionError( + "The connection to the C10d store has failed. See inner exception for details." + ) from exc + + return store # type: ignore[possibly-undefined] + + +def _create_file_store(params: RendezvousParameters) -> FileStore: + # If a user specifies an endpoint, we treat it as a path to a file. + if params.endpoint: + path = params.endpoint + else: + try: + # The temporary file is readable and writable only by the user of + # this process. + _, path = tempfile.mkstemp() + except OSError as exc: + raise RendezvousError( + "The file creation for C10d store has failed. See inner exception for details." + ) from exc + + try: + store = FileStore(path) + except (ValueError, RuntimeError) as exc: + raise RendezvousConnectionError( + "The connection to the C10d store has failed. See inner exception for details." + ) from exc + + return store + + +def create_backend(params: RendezvousParameters) -> Tuple[C10dRendezvousBackend, Store]: + """Create a new :py:class:`C10dRendezvousBackend` from the specified parameters. + + +--------------+-----------------------------------------------------------+ + | Parameter | Description | + +==============+===========================================================+ + | store_type | The type of the C10d store. The currently supported types | + | | are "tcp" and "file" which correspond to | + | | :py:class:`core.distributed.TCPStore` and | + | | :py:class:`core.distributed.FileStore`, respectively. | + | | Defaults to "tcp". | + +--------------+-----------------------------------------------------------+ + | read_timeout | The read timeout, in seconds, for store operations. | + | | Defaults to 60 seconds. | + | | | + | | Note this only applies to | + | | :py:class:`core.distributed.TCPStore`. It is not relevant| + | | to :py:class:`core.distributed.FileStore` which does not | + | | take in timeout as a parameter. | + +--------------+-----------------------------------------------------------+ + | is_host | A boolean value indicating whether this backend instance | + | | will host the C10d store. If not specified it will be | + | | inferred heuristically by matching the hostname or the IP | + | | address of this machine against the specified rendezvous | + | | endpoint. Defaults to ``None``. | + | | | + | | Note that this configuration option only applies to | + | | :py:class:`core.distributed.TCPStore`. In normal | + | | circumstances you can safely skip it; the only time when | + | | it is needed is if its value cannot be correctly | + | | determined (e.g. the rendezvous endpoint has a CNAME as | + | | the hostname or does not match the FQDN of the machine). | + +--------------+-----------------------------------------------------------+ + """ + # As of today we only support TCPStore and FileStore. Other store types do + # not have the required functionality (e.g. compare_set) yet. + store_type = params.get("store_type", "tcp").strip().lower() + store: Store + + try: + if store_type == "file": + store = _create_file_store(params) + elif store_type == "tcp": + store = _create_tcp_store(params) + else: + raise ValueError( + "Invalid store type given. Currently only supports file and tcp." + ) + + backend = C10dRendezvousBackend(store, params.run_id) + + except Exception as e: + construct_and_record_rdzv_event( + message=f"{type(e).__name__}: {str(e)}", + run_id=params.run_id, + node_state=NodeState.FAILED, + ) + raise + + return backend, store diff --git a/mindnlp/core/distributed/elastic/rendezvous/dynamic_rendezvous.py b/mindnlp/core/distributed/elastic/rendezvous/dynamic_rendezvous.py new file mode 100644 index 000000000..4b44561b4 --- /dev/null +++ b/mindnlp/core/distributed/elastic/rendezvous/dynamic_rendezvous.py @@ -0,0 +1,1431 @@ +# mypy: allow-untyped-defs +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import inspect +import logging +import os +import pickle +import socket +import threading +import time +import weakref +from abc import ABC, abstractmethod +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from enum import Enum +from typing import Any, Callable, Dict, List, Optional, Set, Tuple + +from mindnlp import core.distributed as dist +from core.distributed import Store +from core.distributed.elastic.events import construct_and_record_rdzv_event, NodeState + +from .api import ( + RendezvousClosedError, + RendezvousError, + RendezvousGracefulExitError, + RendezvousHandler, + RendezvousInfo, + RendezvousParameters, + RendezvousStateError, + RendezvousStoreInfo, + RendezvousTimeoutError, +) +from .utils import _delay, _PeriodicTimer + + +__all__ = [ + "RendezvousBackend", + "RendezvousTimeout", + "RendezvousSettings", + "DynamicRendezvousHandler", + "create_handler", +] + +logger = logging.getLogger(__name__) + + +def get_method_name(depth=2): + if len(inspect.stack()) > depth: + return inspect.stack()[depth].function + return "no_method_name" + + +Token = Any +"""Represent an opaque fencing token used by the rendezvous backend.""" + + +class RendezvousBackend(ABC): + """Represent a backend that holds the rendezvous state.""" + + @property + @abstractmethod + def name(self) -> str: + """Get the name of the backend.""" + + @abstractmethod + def get_state(self) -> Optional[Tuple[bytes, Token]]: + """Get the rendezvous state. + + Returns: + A tuple of the encoded rendezvous state and its fencing token or + ``None`` if no state is found in the backend. + + Raises: + RendezvousConnectionError: + The connection to the backend has failed. + RendezvousStateError: + The rendezvous state is corrupt. + """ + + @abstractmethod + def set_state( + self, state: bytes, token: Optional[Token] = None + ) -> Optional[Tuple[bytes, Token, bool]]: + """Set the rendezvous state. + + The new rendezvous state is set conditionally: + + - If the specified ``token`` matches the fencing token stored in the + backend, the state will be updated. The new state will be returned + to the caller along with its fencing token. + - If the specified ``token`` does not match the fencing token stored + in the backend, the state won't be updated; instead the existing + state along with its fencing token will be returned to the caller. + - If the specified ``token`` is ``None``, the new state will be set + only if there is no existing state in the backend. Either the new + state or the existing state along with its fencing token will be + returned to the caller. + + Args: + state: + The encoded rendezvous state. + token: + An optional fencing token that was retrieved by a previous call + to :py:meth:`get_state` or ``set_state()``. + + Returns: + A tuple of the serialized rendezvous state, its fencing token, and + a boolean value indicating whether our set attempt succeeded. + + Raises: + RendezvousConnectionError: + The connection to the backend has failed. + RendezvousStateError: + The rendezvous state is corrupt. + """ + + +class RendezvousTimeout: + """Hold the timeout configuration of a rendezvous. + + Args: + join: + The time within which the rendezvous is expected to complete. + last_call: + An additional wait amount before completing the rendezvous once the + rendezvous has the minimum number of required participants. + close: + The time within which the rendezvous is expected to close after a + call to :py:meth:`RendezvousHandler.set_closed` or + :py:meth:`RendezvousHandler.shutdown`. + keep_alive: + The time within which a keep-alive heartbeat is expected to + complete. + """ + + _ZERO = timedelta(0) + + _DEFAULT_TIMEOUTS = { + "join": timedelta(seconds=600), + "last_call": timedelta(seconds=30), + "close": timedelta(seconds=30), + "heartbeat": timedelta(seconds=5), + } + + _join: timedelta + _last_call: timedelta + _close: timedelta + _heartbeat: timedelta + + def __init__( + self, + join: Optional[timedelta] = None, + last_call: Optional[timedelta] = None, + close: Optional[timedelta] = None, + heartbeat: Optional[timedelta] = None, + ) -> None: + self._set_timeouts( + join=join, last_call=last_call, close=close, heartbeat=heartbeat + ) + + @property + def join(self) -> timedelta: + """Get the join timeout.""" + return self._join + + @property + def last_call(self) -> timedelta: + """Get the last call timeout.""" + return self._last_call + + @property + def close(self) -> timedelta: + """Get the close timeout.""" + return self._close + + @property + def heartbeat(self) -> timedelta: + """Get the keep-alive heartbeat timeout.""" + return self._heartbeat + + def _set_timeouts(self, **timeouts: Optional[timedelta]): + for name, timeout in timeouts.items(): + if timeout is None: + timeout = self._DEFAULT_TIMEOUTS[name] + if timeout <= self._ZERO: + raise ValueError(f"The {name} timeout ({timeout}) must be positive.") + setattr(self, "_" + name, timeout) + + +@dataclass(repr=False, eq=False, frozen=True) +class RendezvousSettings: + """Hold the settings of the rendezvous. + + Attributes: + run_id: + The run id of the rendezvous. + min_nodes: + The minimum number of nodes to admit to the rendezvous. + max_nodes: + The maximum number of nodes to admit to the rendezvous. + timeout: + The timeout configuration of the rendezvous. + keep_alive_interval: + The amount of time a node waits before sending a heartbeat to keep + it alive in the rendezvous. + keep_alive_max_attempt: + The maximum number of failed heartbeat attempts after which a node + is considered dead. + """ + + run_id: str + min_nodes: int + max_nodes: int + timeout: RendezvousTimeout + keep_alive_interval: timedelta + keep_alive_max_attempt: int + + +@dataclass(eq=True, order=True, frozen=True) +class _NodeDesc: + """Describe a node in the rendezvous. + + Attributes: + addr: + The FQDN of the node or user specified local node address. + pid: + The id of the process in which the rendezvous handler runs. + local_id: + A process-wide unique id. + """ + + addr: str + pid: int + local_id: int + + def __repr__(self) -> str: + return f"{self.addr}_{self.pid}_{self.local_id}" + + +class _NodeDescGenerator: + """Generate node descriptors. + + A node descriptor is a combination of an FQDN, a process id, and an auto- + incremented integer that uniquely identifies a node in the rendezvous. + """ + + _lock: threading.Lock + _local_id: int + + def __init__(self) -> None: + self._lock = threading.Lock() + + # An integer that is incremented with each call to generate(). + self._local_id = 0 + + def generate(self, local_addr: Optional[str] = None) -> _NodeDesc: + # This method can be called by multiple threads concurrently; therefore, + # we must increment the integer atomically. + with self._lock: + local_id = self._local_id + + self._local_id += 1 + + return _NodeDesc(local_addr or socket.getfqdn(), os.getpid(), local_id) + + +class _RendezvousState: + """Hold the state of a rendezvous. + + Attributes: + round: + The current round of the rendezvous. + complete: + A boolean value indicating whether the current round of the + rendezvous is complete. + deadline: + The time at which the current round of the rendezvous will be + considered complete if it is still waiting for nodes to join. + closed: + A boolean value indicating whether the rendezvous is closed. + participants: + A dictionary of the participants and their corresponding ranks. + wait_list: + A set of nodes that are waiting to participate in the next round of + the rendezvous. + redundancy_list: + A set of nodes that are redundant in the current round and can join + the next rendezvous without triggering re-rendezvous. + last_heartbeats: + A dictionary containing each node's last heartbeat time. + """ + + round: int + complete: bool + deadline: Optional[datetime] + closed: bool + participants: Dict[_NodeDesc, int] + wait_list: Set[_NodeDesc] + redundancy_list: Set[_NodeDesc] + last_heartbeats: Dict[_NodeDesc, datetime] + + def __init__(self) -> None: + self.round = 0 + self.complete = False + self.deadline = None + self.closed = False + self.participants = {} + self.wait_list = set() + self.redundancy_list = set() + self.last_heartbeats = {} + + +def _remove_participant_epilogue( + state: _RendezvousState, settings: RendezvousSettings +) -> None: + if state.complete: + # If we do not have any participants left, move to the next round. + if not state.participants: + msg = "No participants left in the rendezvous, marking rendezvous as incomplete" + logger.debug(msg) + state.complete = False + + state.round += 1 + else: + if len(state.participants) < settings.min_nodes: + msg = ( + f"Number of participants {len(state.participants)}) less than" + f"min_nodes {settings.min_nodes}, clearning deadline in state" + ) + logger.debug(msg) + state.deadline = None + + +class _RendezvousStateHolder(ABC): + """Hold the shared rendezvous state synced with other nodes.""" + + @property + @abstractmethod + def state(self) -> _RendezvousState: + """Get the local state.""" + + @abstractmethod + def sync(self) -> Optional[bool]: + """Read or writes the latest state. + + Returns: + A boolean value indicating whether the local state, in case marked + as dirty, was successfully synced with other nodes. + """ + + @abstractmethod + def mark_dirty(self) -> None: + """Mark the local state as dirty.""" + + +class _BackendRendezvousStateHolder(_RendezvousStateHolder): + """Hold the rendezvous state synced with other nodes via a backend. + + Args: + backend: + The rendezvous backend to use. + settings: + The rendezvous settings. + cache_duration: + The amount of time, in seconds, to cache the last rendezvous state + before requesting it from the backend again. + """ + + _backend: RendezvousBackend + _state: _RendezvousState + _settings: RendezvousSettings + _cache_duration: int + _token: Token + _dirty: bool + _last_sync_time: float + _dead_nodes: List[_NodeDesc] + + def __init__( + self, + backend: RendezvousBackend, + settings: RendezvousSettings, + cache_duration: int = 1, + ) -> None: + self._backend = backend + self._state = _RendezvousState() + self._settings = settings + self._cache_duration = cache_duration + self._token = None + self._dirty = False + self._last_sync_time = -1 + self._dead_nodes = [] + + def _record(self, message: str, node_state: NodeState = NodeState.RUNNING): + construct_and_record_rdzv_event( + name=f"{self.__class__.__name__}.{get_method_name()}", + run_id=self._settings.run_id, + message=message, + node_state=node_state, + ) + + @property + def state(self) -> _RendezvousState: + """See base class.""" + return self._state + + def sync(self) -> Optional[bool]: + """See base class.""" + state_bits: Optional[bytes] = None + + token = None + + has_set: Optional[bool] + + if self._dirty: + has_set = False + + state_bits = pickle.dumps(self._state) + + set_response = self._backend.set_state(state_bits, self._token) + if set_response is not None: + state_bits, token, has_set = set_response + else: + has_set = None + + if self._cache_duration > 0: + # Avoid overloading the backend if we are asked to retrieve the + # state repeatedly. Try to serve the cached state. + if self._last_sync_time >= max( + time.monotonic() - self._cache_duration, 0 + ): + return None + + get_response = self._backend.get_state() + if get_response is not None: + state_bits, token = get_response + + if state_bits is not None: + try: + self._state = pickle.loads(state_bits) + except pickle.PickleError as exc: + raise RendezvousStateError( + "The rendezvous state is corrupt. See inner exception for details." + ) from exc + else: + self._state = _RendezvousState() + + if has_set and self._dead_nodes and logger.isEnabledFor(logging.DEBUG): + node_list = ", ".join(f"'{dead_node}'" for dead_node in self._dead_nodes) + + msg = ( + f"As part of the sync operation the node(s) {node_list} have been removed from the " + f"rendezvous '{self._settings.run_id}' since they had no heartbeat." + ) + self._record(message=msg) + logger.debug(msg) + + self._token = token + + self._dirty = False + + self._last_sync_time = time.monotonic() + + self._sanitize() + + return has_set + + def _sanitize(self) -> None: + state = self._state + + expire_time = datetime.now(timezone.utc) - ( + self._settings.keep_alive_interval * self._settings.keep_alive_max_attempt + ) + + # Filter out the dead nodes. + self._dead_nodes = [ + node + for node, last_heartbeat in state.last_heartbeats.items() + if last_heartbeat < expire_time + ] + + participant_removed = False + + for dead_node in self._dead_nodes: + msg = f"Detected dead node '{dead_node}', removing it from the rendezvous" + logger.debug(msg) + del state.last_heartbeats[dead_node] + + try: + del state.participants[dead_node] + + participant_removed = True + except KeyError: + pass + + try: + state.wait_list.remove(dead_node) + except KeyError: + pass + + try: + state.redundancy_list.remove(dead_node) + except KeyError: + pass + + if participant_removed: + # Common epilogue shared with the _remove_from_participants() + # function of _DistributedRendezvousOpExecutor. + _remove_participant_epilogue(state, self._settings) + + def mark_dirty(self) -> None: + """See base class. + + If the local rendezvous state is dirty, the next sync call will try to + write the changes back to the backend. However this attempt might fail + if another node, which had the same state, also made changes and wrote + them before us. + """ + self._dirty = True + + +class _Action(Enum): + """Specifies the possible actions based on the state of the rendezvous.""" + + KEEP_ALIVE = 1 + ADD_TO_PARTICIPANTS = 2 + ADD_TO_WAIT_LIST = 3 + ADD_TO_REDUNDANCY_LIST = 4 + REMOVE_FROM_PARTICIPANTS = 5 + REMOVE_FROM_WAIT_LIST = 6 + REMOVE_FROM_REDUNDANCY_LIST = 7 + MARK_RENDEZVOUS_COMPLETE = 8 + MARK_RENDEZVOUS_CLOSED = 9 + SYNC = 10 + ERROR_CLOSED = 11 + ERROR_TIMEOUT = 12 + FINISH = 13 + + +class _RendezvousContext: + """Holds the context of the rendezvous. + + Attributes: + node: + The node descriptor associated with the current rendezvous handler + instance. + state: + The current state of the rendezvous. + settings: + The rendezvous settings. + """ + + node: _NodeDesc + state: _RendezvousState + settings: RendezvousSettings + + def __init__( + self, node: _NodeDesc, state: _RendezvousState, settings: RendezvousSettings + ) -> None: + self.node = node + self.state = state + self.settings = settings + + +class _RendezvousOpExecutor(ABC): + """Execute rendezvous operations.""" + + @abstractmethod + def run( + self, + state_handler: Callable[[_RendezvousContext, float], _Action], + deadline: float, + update_deadline: Optional[Callable[[timedelta], float]] = None, + ) -> None: + """Execute a rendezvous operation. + + An operation is run inside a state machine and is expected to transition + the rendezvous from one state to another. + + Args: + state_handler: + A callable that is expected to return the next state transition + action based on the current state of the rendezvous. + deadline: + The time, in seconds, at which the operation will be considered + timed-out. + update_deadline: + Function to generate a new operation deadline if the current + node may participate in the next rendezvous. + """ + + +class _DistributedRendezvousOpExecutor(_RendezvousOpExecutor): + """Execute rendezvous operations using a shared state. + + Args: + node: + The node descriptor associated with the current rendezvous handler + instance. + state_holder: + The ``RendezvousStateHolder`` to use to sync the rendezvous state + with other nodes. + settings: + The rendezvous settings. + """ + + _node: _NodeDesc + _state: _RendezvousState + _state_holder: _RendezvousStateHolder + _settings: RendezvousSettings + + def __init__( + self, + node: _NodeDesc, + state_holder: _RendezvousStateHolder, + settings: RendezvousSettings, + ) -> None: + self._node = node + self._state_holder = state_holder + self._settings = settings + + def _record(self, message: str, node_state: NodeState = NodeState.RUNNING) -> None: + construct_and_record_rdzv_event( + name=f"{self.__class__.__name__}.{get_method_name()}", + run_id=self._settings.run_id, + message=message, + node_state=node_state, + hostname=self._node.addr, + pid=self._node.pid, + local_id=self._node.local_id, + ) + + def run( + self, + state_handler: Callable[[_RendezvousContext, float], _Action], + deadline: float, + update_deadline: Optional[Callable[[timedelta], float]] = None, + ) -> None: + """See base class.""" + action = None + while action != _Action.FINISH: + # Reads or writes the latest rendezvous state shared by all nodes in + # the rendezvous. Note that our local changes might get overridden + # by another node if that node synced its changes before us. + has_set = self._state_holder.sync() + if has_set is not None: + if has_set: + msg = ( + f"The node '{self._node}' has successfully synced its local changes with " + f"other nodes in the rendezvous '{self._settings.run_id}'." + ) + else: + msg = ( + f"The node '{self._node}' has a stale state and failed to sync its local " + f"changes with other nodes in the rendezvous '{self._settings.run_id}'." + ) + + self._record(message=msg) + logger.debug(msg) + + self._state = self._state_holder.state + + ctx = _RendezvousContext(self._node, self._state, self._settings) + + # Determine the next action to take based on the current state of + # the rendezvous. + action = state_handler(ctx, deadline) + + if action == _Action.FINISH: + continue + + if action == _Action.ERROR_CLOSED: + raise RendezvousClosedError + + if action == _Action.ERROR_TIMEOUT: + raise RendezvousTimeoutError + + if action == _Action.SYNC: + # Delay the execution by one second to avoid overloading the + # backend if we are asked to poll for state changes. + _delay(seconds=1) + else: + if action == _Action.KEEP_ALIVE: + self._keep_alive() + elif action == _Action.ADD_TO_PARTICIPANTS: + self._add_to_participants() + elif action == _Action.ADD_TO_WAIT_LIST: + self._add_to_wait_list() + elif action == _Action.ADD_TO_REDUNDANCY_LIST: + self._add_to_redundancy_list() + elif action == _Action.REMOVE_FROM_PARTICIPANTS: + self._remove_from_participants() + elif action == _Action.REMOVE_FROM_WAIT_LIST: + self._remove_from_wait_list() + elif action == _Action.REMOVE_FROM_REDUNDANCY_LIST: + self._remove_from_redundancy_list() + # update deadline since the node may participate in rendezvous process + if update_deadline: + deadline = update_deadline(self._settings.timeout.join) + elif action == _Action.MARK_RENDEZVOUS_COMPLETE: + self._mark_rendezvous_complete() + elif action == _Action.MARK_RENDEZVOUS_CLOSED: + self._mark_rendezvous_closed() + + # Attempt to sync our changes back to other nodes. + self._state_holder.mark_dirty() + + def _keep_alive(self) -> None: + msg = ( + f"The node '{self._node}' updated its keep-alive heartbeat time for the rendezvous " + f"'{self._settings.run_id}'. Pending sync." + ) + self._record(message=msg) + logger.debug(msg) + + self._state.last_heartbeats[self._node] = datetime.now(timezone.utc) + + def _add_to_participants(self) -> None: + msg = ( + f"The node '{self._node}' added itself to the participants of round " + f"{self._state.round} of the rendezvous '{self._settings.run_id}'. Pending sync." + ) + self._record(message=msg) + logger.debug(msg) + + state = self._state + + try: + state.wait_list.remove(self._node) + except KeyError: + pass + + # The ranks of the participants will be set once the rendezvous is + # complete. + state.participants[self._node] = 0 + + self._keep_alive() + + if len(state.participants) == self._settings.min_nodes: + state.deadline = ( + datetime.now(timezone.utc) + self._settings.timeout.last_call + ) + + if len(state.participants) == self._settings.max_nodes: + self._mark_rendezvous_complete() + + def _add_to_wait_list(self) -> None: + msg = ( + f"The node '{self._node}' added itself to the wait list of round " + f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync." + ) + self._record(message=msg) + logger.debug(msg) + + if self._node in self._state.redundancy_list: + self._state.redundancy_list.remove(self._node) + self._state.wait_list.add(self._node) + + self._keep_alive() + + def _add_to_redundancy_list(self) -> None: + msg = ( + f"The node '{self._node}' added itself to the redundancy list of round " + f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync." + ) + self._record(message=msg) + logger.debug(msg) + + self._state.redundancy_list.add(self._node) + + self._keep_alive() + + def _remove_from_participants(self) -> None: + msg = ( + f"The node '{self._node}' removed itself from the participants of round " + f"{self._state.round} of the rendezvous '{self._settings.run_id}'. Pending sync." + ) + self._record(message=msg) + logger.debug(msg) + + state = self._state + + del state.participants[self._node] + + del state.last_heartbeats[self._node] + + # Common epilogue shared with the sanitizer() function of + # _BackendRendezvousStateHolder. + _remove_participant_epilogue(state, self._settings) + + def _remove_from_wait_list(self) -> None: + msg = ( + f"The node '{self._node}' removed itself from the wait list of round " + f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync." + ) + self._record(message=msg) + logger.debug(msg) + + self._state.wait_list.remove(self._node) + + del self._state.last_heartbeats[self._node] + + def _remove_from_redundancy_list(self) -> None: + msg = ( + f"The node '{self._node}' removed itself from the redunant list of round " + f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync." + ) + self._record(message=msg) + logger.debug(msg) + + self._state.redundancy_list.remove(self._node) + + del self._state.last_heartbeats[self._node] + + def _mark_rendezvous_complete(self) -> None: + msg = ( + f"The node '{self._node}' marked round {self._state.round} of the rendezvous " + f"'{self._settings.run_id}' as complete. Pending sync." + ) + self._record(message=msg, node_state=NodeState.SUCCEEDED) + logger.debug(msg) + + state = self._state + + state.complete = True + state.deadline = None + + # Assign the ranks. + for rank, node in enumerate(sorted(state.participants)): + state.participants[node] = rank + + def _mark_rendezvous_closed(self) -> None: + msg = ( + f"The node '{self._node}' marked the rendezvous '{self._settings.run_id}' as closed. " + "Pending sync." + ) + self._record(message=msg, node_state=NodeState.SUCCEEDED) + logger.debug(msg) + + self._state.closed = True + + +def _should_keep_alive(ctx: _RendezvousContext) -> bool: + """Determine whether a keep-alive heartbeat should be sent.""" + try: + last_heartbeat = ctx.state.last_heartbeats[ctx.node] + except KeyError: + return False + + return ( + last_heartbeat <= datetime.now(timezone.utc) - ctx.settings.keep_alive_interval + ) + + +class _RendezvousExitOp: + """Represent a rendezvous exit operation.""" + + def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action: + if ctx.node in ctx.state.participants: + if time.monotonic() > deadline: + return _Action.ERROR_TIMEOUT + return _Action.REMOVE_FROM_PARTICIPANTS + return _Action.FINISH + + +class _RendezvousJoinOp: + """Represent a rendezvous join operation.""" + + def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action: + state = ctx.state + + # A closed rendezvous means that it no longer accepts new nodes. + if state.closed: + if ctx.node in state.redundancy_list: + msg = f"The rendezvous '{ctx.settings.run_id}' is closed, terminating pending rendezvous." + raise RendezvousGracefulExitError(msg) + return _Action.ERROR_CLOSED + + if ctx.node in state.redundancy_list: + msg = f"The node {ctx.node} is in redunancy list" + logger.debug(msg) + # don't apply the timeout logic here, since we want to allow the node to rejoin + if len(state.participants) == ctx.settings.max_nodes: + if _should_keep_alive(ctx): + return _Action.KEEP_ALIVE + else: + return _Action.SYNC + else: + # transition to waiting state that will respect timeouts. + msg = f"The node {ctx.node} is removed from redunancy list" + logger.debug(msg) + return _Action.REMOVE_FROM_REDUNDANCY_LIST + + is_participant = ctx.node in state.participants + + # If we are part of the rendezvous and it is already complete there is + # no further action to take. + if state.complete and is_participant: + return _Action.FINISH + + now = time.monotonic() + if now > deadline: + rollback_period = 5 # 5 seconds + + # If we still have time to rollback (a short period on top of the + # operation deadline), try to remove ourself from the rendezvous. + # It is okay if we can't though as our keep-alive will eventually + # expire. + if now <= deadline + rollback_period: + # If we are part of the rendezvous, it means we couldn't find + # enough participants to complete it on time. + if is_participant: + return _Action.REMOVE_FROM_PARTICIPANTS + # If we are in the wait list, it means we couldn't wait till the + # next round of the rendezvous. + if ctx.node in state.wait_list: + return _Action.REMOVE_FROM_WAIT_LIST + return _Action.ERROR_TIMEOUT + + if state.complete: + # If we are here, it means we are not part of the rendezvous. In + # case the rendezvous has capacity for additional participants add + # ourself to the wait list for the next round. + if len(state.participants) < ctx.settings.max_nodes: + if ctx.node not in state.wait_list: + return _Action.ADD_TO_WAIT_LIST + elif len(state.participants) >= ctx.settings.max_nodes: + if ( + ctx.node not in state.redundancy_list + and ctx.node not in state.wait_list + ): + return _Action.ADD_TO_REDUNDANCY_LIST + elif is_participant: + # If the rendezvous has enough number of participants including us, + # check whether we have passed the rendezvous deadline. If yes, + # complete it. + if ( + len(state.participants) >= ctx.settings.min_nodes + and len(state.participants) <= ctx.settings.max_nodes + and state.deadline is not None + ): + if state.deadline < datetime.now(timezone.utc): + msg = ( + f"The node '{ctx.node}' marking the rendezvous complete, " + f"quorum established within deadline" + ) + logger.debug(msg) + return _Action.MARK_RENDEZVOUS_COMPLETE + else: + msg = f"The node '{ctx.node}' can't complete rendezvous: deadline reached" + logger.debug(msg) + else: + msg = f"The node '{ctx.node}' can't complete rendezvous: not enough participants" + logger.debug(msg) + else: + # The rendezvous is not complete yet and we are not part of it. Try + # to join. + return _Action.ADD_TO_PARTICIPANTS + + if _should_keep_alive(ctx): + return _Action.KEEP_ALIVE + + # At this point either the rendezvous is not complete, but we are part + # of it, which means we have to wait for other participants to join; or + # the rendezvous is complete, but we are not part of it, which means we + # have to wait for the next round. + return _Action.SYNC + + +class _RendezvousCloseOp: + """Represent a rendezvous close operation.""" + + def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action: + if ctx.state.closed: + return _Action.FINISH + if time.monotonic() > deadline: + return _Action.ERROR_TIMEOUT + return _Action.MARK_RENDEZVOUS_CLOSED + + +class _RendezvousKeepAliveOp: + """Represent a rendezvous keep-alive update operation.""" + + def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action: + if _should_keep_alive(ctx): + if time.monotonic() > deadline: + return _Action.ERROR_TIMEOUT + return _Action.KEEP_ALIVE + return _Action.FINISH + + +class DynamicRendezvousHandler(RendezvousHandler): + """Represent a handler that sets up a rendezvous among a set of nodes.""" + + # Static + _node_desc_generator = _NodeDescGenerator() + + _this_node: _NodeDesc + _settings: RendezvousSettings + _backend_name: str + _store: Store + _state_holder: _RendezvousStateHolder + _op_executor: _RendezvousOpExecutor + _heartbeat_lock: threading.Lock + _keep_alive_timer: Optional[_PeriodicTimer] + + @classmethod + def from_backend( + cls, + run_id: str, + store: Store, + backend: RendezvousBackend, + min_nodes: int, + max_nodes: int, + local_addr: Optional[str] = None, + timeout: Optional[RendezvousTimeout] = None, + ): + """Create a new :py:class:`DynamicRendezvousHandler`. + + Args: + run_id: + The run id of the rendezvous. + store: + The C10d store to return as part of the rendezvous. + backend: + The backend to use to hold the rendezvous state. + min_nodes: + The minimum number of nodes to admit to the rendezvous. + max_nodes: + The maximum number of nodes to admit to the rendezvous. + local_addr: + The local node address. + timeout: + The timeout configuration of the rendezvous. + """ + # We associate each handler instance with a unique node descriptor. + node = cls._node_desc_generator.generate(local_addr) + + settings = RendezvousSettings( + run_id, + min_nodes, + max_nodes, + timeout or RendezvousTimeout(), + keep_alive_interval=timedelta(seconds=5), + keep_alive_max_attempt=3, + ) + + state_holder = _BackendRendezvousStateHolder(backend, settings) + + return cls(node, settings, backend.name, store, state_holder) + + def __init__( + self, + node: _NodeDesc, + settings: RendezvousSettings, + backend_name: str, + store: Store, + state_holder: _RendezvousStateHolder, + ) -> None: + if not settings.run_id: + raise ValueError("The run id must be a non-empty string.") + + if settings.min_nodes < 1: + raise ValueError( + f"The minimum number of nodes ({settings.min_nodes}) must be greater than zero." + ) + + if settings.max_nodes < settings.min_nodes: + raise ValueError( + f"The maximum number of nodes ({settings.max_nodes}) must be greater than or equal " + f"to the minimum number of nodes ({settings.min_nodes})." + ) + + self._this_node = node + + self._settings = settings + + self._backend_name = backend_name + + self._store = store + + self._state_holder = state_holder + + self._op_executor = _DistributedRendezvousOpExecutor( + self._this_node, self._state_holder, self._settings + ) + + self._heartbeat_lock = threading.Lock() + + self._keep_alive_timer = None + + # Cached shared store server reference + self._shared_tcp_store_server: Optional[dist.Store] = None + + self._bootstrap_store_info: Optional[RendezvousStoreInfo] = None + + def _record( + self, + message: str, + node_state: NodeState = NodeState.RUNNING, + rank: Optional[int] = None, + ) -> None: + construct_and_record_rdzv_event( + name=f"{self.__class__.__name__}.{get_method_name()}", + run_id=self._settings.run_id, + message=message, + node_state=node_state, + hostname=self._this_node.addr, + pid=self._this_node.pid, + local_id=self._this_node.local_id, + rank=rank, + ) + + def _create_tcp_store_server(self, master_addr, master_port) -> dist.TCPStore: + return dist.TCPStore( + host_name=master_addr, + port=master_port, + is_master=True, + multi_tenant=True, + ) + + @property + def settings(self) -> RendezvousSettings: + """Get the settings of the rendezvous.""" + return self._settings + + def get_backend(self) -> str: + """See base class.""" + return self._backend_name + + @property + def use_agent_store(self) -> bool: + """See base class.""" + return os.getenv("TORCH_DISABLE_SHARE_RDZV_TCP_STORE", "0") != "1" + + def next_rendezvous(self) -> RendezvousInfo: + """See base class.""" + msg = ( + f"The node '{self._this_node}' attempts to join the next round of the rendezvous " + f"'{self._settings.run_id}'." + ) + self._record(message=msg) + logger.info(msg) + + try: + self._stop_heartbeats() + + # Delay the execution for a small random amount of time if this is our + # first run. This will slightly skew the rendezvous attempts across the + # nodes and reduce the load on the backend. + if self._state_holder.state.round == 0: + _delay(seconds=(0, 0.3)) + + exit_op = _RendezvousExitOp() + join_op = _RendezvousJoinOp() + + deadline = self._get_deadline(self._settings.timeout.join) + self._op_executor.run(exit_op, deadline) + self._op_executor.run(join_op, deadline, self._get_deadline) + + self._start_heartbeats() + + rank, world_size = self._get_world() + store = self._get_store() + + except Exception as e: + self._record( + message=f"{type(e).__name__}: {str(e)}", + node_state=NodeState.FAILED, + ) + raise + + msg = ( + f"The node '{self._this_node}' has joined round {self._state_holder.state.round} of " + f"the rendezvous '{self._settings.run_id}' as rank {rank} in a world of size " + f"{world_size}." + ) + self._record(message=msg, rank=rank) + logger.info(msg) + + # opt-out option of TCPStore sharing + if os.getenv("TORCH_DISABLE_SHARE_RDZV_TCP_STORE", "0") == "1": + bootstrap_store_info = RendezvousStoreInfo.build( + rank, store, local_addr=self._this_node.addr + ) + return RendezvousInfo( + store, + rank, + world_size, + bootstrap_store_info, + ) + + # This will only be hit when TCPStore sharing is enabled. + if self._bootstrap_store_info is None: + # To avoid race in get_free_port because we release the port after the call, + # we want to create a TCPStore server soon afterwards. + server_port = 0 + if rank == 0: + self._shared_tcp_store_server = self._create_tcp_store_server( + self._this_node.addr, server_port + ) + server_port = self._shared_tcp_store_server.port + self._bootstrap_store_info = RendezvousStoreInfo.build( + rank, + store, + local_addr=self._this_node.addr, + server_port=server_port, # For non-0 rank, this is a no-op + ) + + assert self._bootstrap_store_info is not None + if rank == 0: + assert self._shared_tcp_store_server is not None + + return RendezvousInfo( + store, + rank, + world_size, + self._bootstrap_store_info, # type: ignore[assignment] + ) + + def is_closed(self) -> bool: + """See base class.""" + try: + with self._heartbeat_lock: + self._state_holder.sync() + + return self._state_holder.state.closed + + except Exception as e: + self._record( + message=f"{type(e).__name__}: {str(e)}", + node_state=NodeState.FAILED, + ) + raise + + def set_closed(self) -> None: + """See base class.""" + try: + with self._heartbeat_lock: + self._close() + except Exception as e: + self._record( + message=f"{type(e).__name__}: {str(e)}", + node_state=NodeState.FAILED, + ) + raise + + def num_nodes_waiting(self) -> int: + """See base class.""" + try: + with self._heartbeat_lock: + self._state_holder.sync() + + return len(self._state_holder.state.wait_list) + + except Exception as e: + self._record( + message=f"{type(e).__name__}: {str(e)}", + node_state=NodeState.FAILED, + ) + raise + + def get_run_id(self) -> str: + """See base class.""" + return self._settings.run_id + + def shutdown(self) -> bool: + """See base class.""" + self._stop_heartbeats() + + try: + self._close() + + return True + except RendezvousError as ex: + msg = ( + f"The node '{self._this_node}' has failed to shutdown the rendezvous " + f"'{self._settings.run_id}' due to an error of type {type(ex).__name__}." + ) + self._record(message=msg, node_state=NodeState.FAILED) + logger.warning(msg) + + return False + except Exception as e: + self._record( + message=f"{type(e).__name__}: {str(e)}", + node_state=NodeState.FAILED, + ) + raise + + def _close(self) -> None: + op = _RendezvousCloseOp() + + deadline = self._get_deadline(self._settings.timeout.close) + + self._op_executor.run(op, deadline) + + msg = f"The node '{self._this_node}' has closed the rendezvous '{self._settings.run_id}'." + self._record(message=msg, node_state=NodeState.SUCCEEDED) + logger.info(msg) + + @staticmethod + def _keep_alive_weak(weak_self) -> None: + self = weak_self() + if self is not None: + self._keep_alive() + + def _keep_alive(self) -> None: + self._heartbeat_lock.acquire() + + op = _RendezvousKeepAliveOp() + + deadline = self._get_deadline(self._settings.timeout.heartbeat) + + try: + self._op_executor.run(op, deadline) + + msg = ( + f"The node '{self._this_node}' has sent a keep-alive heartbeat to the rendezvous " + f"'{self._settings.run_id}'." + ) + self._record(message=msg) + logger.debug(msg) + except RendezvousError as ex: + msg = ( + f"The node '{self._this_node}' has failed to send a keep-alive heartbeat to the " + f"rendezvous '{self._settings.run_id}' due to an error of type {type(ex).__name__}." + ) + self._record(message=msg, node_state=NodeState.FAILED) + logger.warning(msg) + finally: + self._heartbeat_lock.release() + + def _start_heartbeats(self) -> None: + self._keep_alive_timer = _PeriodicTimer( + self._settings.keep_alive_interval, self._keep_alive_weak, weakref.ref(self) + ) + + self._keep_alive_timer.set_name( + f"RendezvousKeepAliveTimer_{self._this_node.local_id}" + ) + + self._keep_alive_timer.start() + + def _stop_heartbeats(self) -> None: + if self._keep_alive_timer is None: + return + + self._keep_alive_timer.cancel() + + def _get_world(self) -> Tuple[int, int]: + state = self._state_holder.state + + return state.participants[self._this_node], len(state.participants) + + def _wrap_store(self, store: Store) -> Store: + key_prefix = ( + f"core.rendezvous.{self._settings.run_id}.{self._state_holder.state.round}" + ) + + return dist.PrefixStore(key_prefix, store) + + def _get_store(self) -> Store: + return self._wrap_store(self._store) + + def _get_deadline(self, timeout: timedelta) -> float: + return time.monotonic() + timeout.total_seconds() + + +def _get_timeout(params: RendezvousParameters, key: str) -> Optional[timedelta]: + timeout = params.get_as_int(key + "_timeout") + if timeout is None: + return None + return timedelta(seconds=timeout) + + +def create_handler( + store: Store, backend: RendezvousBackend, params: RendezvousParameters +) -> DynamicRendezvousHandler: + """Create a new :py:class:`DynamicRendezvousHandler` from the specified parameters. + + Args: + store: + The C10d store to return as part of the rendezvous. + backend: + The backend to use to hold the rendezvous state. + + +-------------------+------------------------------------------------------+ + | Parameter | Description | + +===================+======================================================+ + | join_timeout | The total time, in seconds, within which the | + | | rendezvous is expected to complete. Defaults to 600 | + | | seconds. | + +-------------------+------------------------------------------------------+ + | last_call_timeout | An additional wait amount, in seconds, before | + | | completing the rendezvous once the minimum number of | + | | nodes has been reached. Defaults to 30 seconds. | + +-------------------+------------------------------------------------------+ + | close_timeout | The time, in seconds, within which the rendezvous is | + | | expected to close after a call to | + | | :py:meth:`RendezvousHandler.set_closed` or | + | | :py:meth:`RendezvousHandler.shutdown`. Defaults to | + | | 30 seconds. | + +-------------------+------------------------------------------------------+ + """ + try: + timeout = RendezvousTimeout( + _get_timeout(params, "join"), + _get_timeout(params, "last_call"), + _get_timeout(params, "close"), + ) + + return DynamicRendezvousHandler.from_backend( + params.run_id, + store, + backend, + params.min_nodes, + params.max_nodes, + params.local_addr, + timeout, + ) + except Exception as e: + construct_and_record_rdzv_event( + message=f"{type(e).__name__}: {str(e)}", + run_id=params.run_id, + node_state=NodeState.FAILED, + ) + raise diff --git a/mindnlp/core/distributed/elastic/rendezvous/etcd_rendezvous.py b/mindnlp/core/distributed/elastic/rendezvous/etcd_rendezvous.py new file mode 100644 index 000000000..d1f6b8f35 --- /dev/null +++ b/mindnlp/core/distributed/elastic/rendezvous/etcd_rendezvous.py @@ -0,0 +1,1077 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import json +import logging +import sys +import threading +import time +from typing import Optional + +import etcd # type: ignore[import] + +from core.distributed.elastic.rendezvous import ( + RendezvousClosedError, + RendezvousError, + RendezvousHandler, + RendezvousInfo, + RendezvousParameters, + RendezvousStoreInfo, + RendezvousTimeoutError, +) + +from .etcd_store import cas_delay, EtcdStore +from .utils import parse_rendezvous_endpoint + + +__all__ = [ + "EtcdRendezvousRetryableFailure", + "EtcdRendezvousRetryImmediately", + "EtcdRendezvousHandler", + "EtcdRendezvous", + "create_rdzv_handler", +] + +_log_fmt = logging.Formatter("%(levelname)s %(asctime)s %(message)s") +_log_handler = logging.StreamHandler(sys.stderr) +_log_handler.setFormatter(_log_fmt) + +logger = logging.getLogger(__name__) +logger.propagate = False +logger.setLevel(logging.INFO) +logger.addHandler(_log_handler) + + +# Retryable failure exception means the we were too late to make +# a desired state transition (e.g. because of a race condition), +# and should now restart from the beginning. +# A small delay is recommended to avoid spamming Etcd. +class EtcdRendezvousRetryableFailure(Exception): + pass + + +# Similar to retryable failure, but the new state we observed suggests we +# can re-try immediately, i.e. without a need for "safety delay". +class EtcdRendezvousRetryImmediately(Exception): + pass + + +# Default timeout for the rendezvous. +_DEFAULT_TIMEOUT: int = 600 # 10 minutes + +# Additional waiting time after reaching the minimum number of nodes +# in case the rendezvous is elastic (min != max). +_DEFAULT_LAST_CALL_TIMEOUT: int = 30 # 30 seconds + +# Various constants used internally in EtcdRendezvous +CONST_ETCD_SETUP_TTL = 5 +CONST_ETCD_FROZEN_TTL = 10 +CONST_ETCD_JOINABLE_EPHEMERAL_TTL = 10 + +# Ephemeral node TTL for worker's keep-alive key: +CONST_WORKER_KEEPALIVE_TTL = 10 + +# TTL for the ephemeral run_id-specific directory. All rendezvous state data +# for a specific run_id (job instance) is contained within directory. +# Its only role is to clean-up rendezvous data from old runs (for the case when +# etcd server is persistent), and has no affect on correctness, but should be +# larger than any timeouts that a worker process is expected to survive: +CONST_RUNID_SUBROOT_TTL = 7200 # 2 hours + + +class EtcdRendezvousHandler(RendezvousHandler): + """ + Implements a + :py:class:`core.distributed.elastic.rendezvous.RendezvousHandler` interface + backed by + :py:class:`core.distributed.elastic.rendezvous.etcd_rendezvous.EtcdRendezvous`. + ``EtcdRendezvousHandler`` uses a URL to configure the type of rendezvous to + use and to pass implementation specific configurations to the rendezvous + module. The basic etcd rendezvous configuration URL looks like the following + :: + + etcd://:/?min_workers=&max_workers= # noqa: W605 + + -- example -- + + etcd://localhost:2379/1234?min_workers=1&max_workers=3 + + The URL above is interpreted as follows: + + 1. Use the rendezvous handler that is registered with the ``etcd`` + scheme + 2. The ``etcd`` endpoint to use is ``localhost:2379`` + 3. ``job_id == 1234`` is used as the prefix in etcd (this allows one to + share a common etcd server for multiple jobs so long as the + ``job_ids`` are guaranteed to be unique). Note that the job id can be + any string (e.g. does not need to be a number) as long as it is + unique. + 4. ``min_workers=1`` and ``max_workers=3`` specifies a range for + membership size - Torch Distributed Elastic starts running the job as + long as the cluster size is greater than or equal to ``min_workers`` + and admits up to ``max_workers`` into the cluster. + + Below are a full list of the parameters that can be passed to etcd + rendezvous: + + +--------------------------------------------+--------------------------+ + | Parameter | Description | + +============================================+==========================+ + | min_workers | minimum number of | + | | workers for the | + | | rendezvous to be valid | + +--------------------------------------------+--------------------------+ + | max_workers | maximum number of | + | | workers to admit | + +--------------------------------------------+--------------------------+ + | timeout | total timeout within | + | | which next_rendezvous is | + | | expected to succeed | + | | (default 600s) | + +--------------------------------------------+--------------------------+ + | last_call_timeout | additional wait amount | + | | ("last call") after min | + | | number of workers has | + | | been reached (defaults | + | | to 30s) | + +--------------------------------------------+--------------------------+ + | etcd_prefix | path prefix (from etcd | + | | root), inside which all | + | | etcd nodes will be | + | | created (defaults to | + | | ``/torchelastic/p2p``) | + +--------------------------------------------+--------------------------+ + """ + + def __init__(self, rdzv_impl: "EtcdRendezvous", local_addr: Optional[str]): + """ + Args: + rdzv_impl: the implementation of the rendezvous + local_addr: the local address of the current node + """ + + self._rdzv_impl = rdzv_impl + self._local_addr = local_addr + + def __del__(self): + # TODO: look into using weakref here instead. + del self._rdzv_impl + + def get_backend(self) -> str: + return "etcd" + + def next_rendezvous(self): + rdzv_version, rank, world_size = self._rdzv_impl.rendezvous_barrier() + + logger.info("Creating EtcdStore as the c10d::Store implementation") + store = self._rdzv_impl.setup_kv_store(rdzv_version) + + bootstrap_store_info = RendezvousStoreInfo.build( + rank, store, local_addr=self._local_addr + ) + return RendezvousInfo(store, rank, world_size, bootstrap_store_info) + + def is_closed(self): + try: + _, state = self._rdzv_impl.get_rdzv_state() + return state["status"] == "closed" + except etcd.EtcdKeyNotFound: + # No rendezvous state, so it cannot be closed. + return False + + def set_closed(self): + self._rdzv_impl.set_closed() + + def num_nodes_waiting(self): + try: + _, state = self._rdzv_impl.get_rdzv_state() + if state["status"] == "final": + return state["num_workers_waiting"] + except etcd.EtcdKeyNotFound: + pass + return 0 + + def get_run_id(self) -> str: + return self._rdzv_impl._run_id + + def shutdown(self) -> bool: + try: + self.set_closed() + return True + except BaseException as e: + logger.warning("Shutdown failed. Error occurred: %s", str(e)) + return False + + +# TODO: we should probably handle a few additional errors, +# like EtcdLeaderElectionInProgress and EtcdWatcherCleared. These are +# only relevant for multi-node Etcd ensemble. A simple retry would work, +# but is verbose to add everywhere. Consider wrapping the client calls +# into auto-retry for these errors? +# +class EtcdRendezvous: + """A rendezvous implementation that uses `etcd `__ as the backend store.""" + + def __init__( + self, + client, + prefix, + run_id, + num_min_workers, + num_max_workers, + timeout, + last_call_timeout, + ): + self.client = client + logger.info("Etcd machines: %s", self.client.machines) + + self._prefix = prefix + self._run_id = run_id + self._num_min_workers = num_min_workers + self._num_max_workers = num_max_workers + self._timeout = timeout + self._last_call_timeout = last_call_timeout + + # For cleaning up TTL refresher threads (for ephemeral keys) + self._lease_run_id_stop = None + self._lease_this_rank_stop = None + + if not self._prefix.endswith("/"): + self._prefix += "/" + + # Setup a permanent prefix dir, if didn't exist + if self._prefix != "/": + self.create_path_if_not_exists(self._prefix) + + # Lease a "sub-root" node specific to this job instance (run_id) + self.create_path_if_not_exists(self.get_path(""), ttl=CONST_RUNID_SUBROOT_TTL) + self._lease_run_id_stop = self.setup_lease_renewal( + self.get_path(""), ttl=CONST_RUNID_SUBROOT_TTL + ) + + # Subdir for all rendezvous work + self.create_path_if_not_exists(self.get_path("/rdzv")) + + # Create a rendezvous version counter, if doesn't exist + try: + self.client.write( + key=self.get_path("/rdzv/version_counter"), value="0", prevExist=False + ) + except etcd.EtcdAlreadyExist: + pass + + def __del__(self): + # TODO: look into using weakref here instead. + if self._lease_run_id_stop is not None: + self._lease_run_id_stop.set() + + if self._lease_this_rank_stop is not None: + self._lease_this_rank_stop.set() + + def rendezvous_barrier(self): + """ + Main entry point for next rendezvous. + + This method is blocking until rendezvous succeeds or a timeout occurs. + + Returns: + ``(rdzv_version, rank, world_size)`` + + Raises: + RendezvousTimeoutError - timeout waiting for rendezvous + RendezvousClosedError - rendezvous is or was closed while waiting + RendezvousError - other persistent errors that + render the rendezvous non-retryable + """ + self._rendezvous_deadline = time.time() + self._timeout + while True: + if time.time() > self._rendezvous_deadline: + raise RendezvousTimeoutError + + logger.info("Attempting to join next rendezvous") + try: + # Dis-own our lease in the previous rendezvous, if exists + if self._lease_this_rank_stop is not None: + self._lease_this_rank_stop.set() + + return self.init_phase() + + except EtcdRendezvousRetryImmediately: + # The type of failure suggests we can retry without delay + pass + + except EtcdRendezvousRetryableFailure: + # In case of retryable failure, wait a small delay + # to avoid spamming etcd + time.sleep(1) + + except RendezvousTimeoutError: + logger.info("Rendezvous timeout occurred in EtcdRendezvousHandler") + raise + + except RendezvousClosedError: + logger.info( + "Rendezvous for run_id=%s was observed to be closed", self._run_id + ) + raise + + except RendezvousError: + raise + + except Exception as e: + # In case of a general exception, wait a small delay + # to avoid spamming etcd + # FIXME: there are a few things that fall under this like + # etcd.EtcdKeyNotFound, etc, which could be handled more explicitly. + logger.info("Rendezvous attempt failed, will retry. Reason: %s", e) + time.sleep(1) + + def init_phase(self): + """ + Initially, the rendezvous state is expected to be one of: + + 1. empty (non-existent) - in this case we try to create a new one. + 2. joinable - we try to join it. + 3. final - we announce ourselves as waiting, and go into monitoring mode + + Any other state is considered transitional, and will be retried after + a short delay. + + Returns: + ``(rdzv_version, rank, world_size)`` + + Raises: + RendezvousClosedError - current rendezvous was/is closed + EtcdRendezvousRetryableFailure - observed some intermediate + state, which is best handled by retrying later + """ + try: + active_version = self.try_create_rendezvous() + state = json.loads(active_version.value) + logger.info("New rendezvous state created: %s", state) + except etcd.EtcdAlreadyExist: + active_version, state = self.get_rdzv_state() + # Note: it is possible for above query to fail (etcd.EtcdKeyNotFound), + # but this is ok for us - just means we'll restart from beginning. + logger.info("Observed existing rendezvous state: %s", state) + + if state["status"] == "closed": + raise RendezvousClosedError + + if state["status"] == "joinable": + return self.join_phase(state["version"]) + + if state["status"] == "final": + self.handle_existing_rendezvous(state["version"]) + raise EtcdRendezvousRetryImmediately + + self.try_wait_for_state_change(etcd_index=active_version.etcd_index + 1) + raise EtcdRendezvousRetryableFailure + + def join_phase(self, expected_version): + """ + We observed a rendezvous state in 'joinable' state, and attempt to join this + particular version, and then wait for all other peers to join. + """ + # Failure to join will propagate an exception, causing a re-entry. + active_version, this_rank = self.join_rendezvous(expected_version) + state = json.loads(active_version.value) + logger.info( + "Joined rendezvous version %s as rank %s. Full state: %s", + state["version"], + this_rank, + state, + ) + + # If this worker was first to reach num_min_workers requirement, + # and rendezvous is still joinable (therefore it is elastic), + # then this worker will be responsible for waiting out the "last call" + # timeout and closing (i.e. transitioning to 'frozen') the rendezvous + # afterwards. + # As a safety against a potential failure of this worker (during the + # last call timeout), the rendezvous state is made ephemeral + # when min_num_workers is reached. + + if this_rank == self._num_min_workers - 1 and state["status"] == "joinable": + logger.info("Rank %s is responsible for join last call.", this_rank) + last_call_deadline = time.time() + self._last_call_timeout + self.handle_join_last_call(expected_version, last_call_deadline) + logger.info("Rank %s finished join last call.", this_rank) + + # Wait for rendezvous state to be frozen, which means a fixed set of peers + logger.info("Waiting for remaining peers.") + active_version = self.wait_for_peers(expected_version) + state = json.loads(active_version.value) + + assert ( + state["version"] == expected_version + ), "Logic error: failed to observe version mismatch" + + return self.confirm_phase(expected_version, this_rank) + + def confirm_phase(self, expected_version, this_rank): + """ + Once the rendezvous state transitions from 'joinable' to 'frozen', + we have every participant confirm their membership and setup per-member + keep-alive TTL keys, and then wait for all other participants to confirm, + which would then successfully conclude this rendezvous. + """ + logger.info("All peers arrived. Confirming membership.") + self.confirm_membership(expected_version, this_rank) + + logger.info("Waiting for confirmations from all peers.") + active_version = self.wait_for_final(expected_version) + state = json.loads(active_version.value) + + logger.info( + "Rendezvous version %s is complete. Final state: %s", + state["version"], + state, + ) + + # Rendezvous version number; our rank in it; world size + return state["version"], this_rank, len(state["participants"]) + + def handle_existing_rendezvous(self, expected_version): + """ + Handle the case when there's an existing (state 'final) rendezvous already + in place, and we have to announce ourselves waiting, and wait until + the next rendezvous opportunity. + """ + # If state is 'final' -> increment num_workers_waiting + # Then, observe state changes: + # 1. if it's no longer final -> bail out and re-try + # 2. if keep alives are missing, destroy it and bail out. + active_state = self.announce_self_waiting(expected_version) + logger.info( + "Added self to waiting list. Rendezvous full state: %s", active_state.value + ) + + self.wait_for_rendezvous_to_free(expected_version) + logger.info( + "Previously existing rendezvous state changed. Will re-try joining." + ) + + def try_create_rendezvous(self): + """ + Create new rendezvous state or raise an exception that indicates an unexpected state (e.g. already exists). + + Raises: + RendezvousError - on unexpected state + """ + # Initially active_version is ephemeral - this is to handle the + # possibility that might fail to complete the setup transaction, + # i.e. the transition "setup" -> "joinable". + active_version = self.client.write( + key=self.get_path("/rdzv/active_version"), + value=json.dumps({"status": "setup"}), + prevExist=False, + ttl=CONST_ETCD_SETUP_TTL, + ) + + try: + version_counter = self.client.get(self.get_path("/rdzv/version_counter")) + version_counter.value = str(int(version_counter.value) + 1) + self.client.update(version_counter) + except (etcd.EtcdKeyNotFound, etcd.EtcdCompareFailed) as e: + raise RendezvousError( + "Unexpected state of EtcdRendezvousHandler, worker needs to die." + ) from e + + # Any failure below results in declaring a retryable rendezvous failure. + # The ephemeral /rdzv/active_version will expire and someone can then + # re-try the setup process. + + # Create directory node for participant data + self.client.write( + key=self.get_path(f"/rdzv/v_{version_counter.value}"), + value=None, + dir=True, + prevExist=False, + ) + + # Publish rendezvous version and signal it is ready-to-be-joined. + # If rendezvous was set closed just before this, a retry will happen, + # where the closed condition will be handled. + return self.client.test_and_set( + key=self.get_path("/rdzv/active_version"), + value=json.dumps( + { + "status": "joinable", + "version": version_counter.value, + "participants": [], + } + ), + prev_value=active_version.value, + ) + + def join_rendezvous(self, expected_version): + """Helper method for the join phase.""" + # Use compare-and-swap to add self to rendezvous state: + while True: + cas_delay() + active_version, state = self.get_rdzv_state() + + if state["status"] != "joinable": + raise EtcdRendezvousRetryableFailure( + "Rendezvous state became non-joinable before we could join. " + "Must join next one." + ) + + if state["version"] != expected_version: + raise EtcdRendezvousRetryImmediately( + "Rendezvous version changed. Must try join the new one." + ) + + assert ( + len(state["participants"]) < self._num_max_workers + ), "Logic error: joinable rendezvous should always have space left" + + this_rank = len(state["participants"]) + state["participants"].append(this_rank) + + # When reaching min workers, or changing state to frozen, we'll set + # the active_version node to be ephemeral. + set_ttl: Optional[int] = None + if len(state["participants"]) == self._num_max_workers: + state["status"] = "frozen" + state["keep_alives"] = [] + set_ttl = CONST_ETCD_FROZEN_TTL + elif len(state["participants"]) >= self._num_min_workers: + set_ttl = CONST_ETCD_JOINABLE_EPHEMERAL_TTL + + try: + # Compare-and-swap. + active_version = self.client.test_and_set( + key=self.get_path("/rdzv/active_version"), + value=json.dumps(state), + prev_value=active_version.value, + ttl=set_ttl, + ) + # We succeeded joining. + return active_version, this_rank + + except etcd.EtcdCompareFailed: + logger.info("Join rendezvous CAS unsuccessful, retrying") + + def wait_for_peers(self, expected_version): + """Helper method for the join phase.""" + active_version, state = self.get_rdzv_state() + while True: + if state["status"] == "frozen" and state["version"] == expected_version: + # Success, all peers arrived. + return active_version + + elif state["status"] == "joinable" and state["version"] == expected_version: + # Continue waiting for any interesting events. + active_version, state = self.try_wait_for_state_change( + etcd_index=active_version.etcd_index + 1 + ) + + else: + # No valid transition possible at this point + raise EtcdRendezvousRetryableFailure( + "Rendezvous state transition no longer possible. Must re-enter." + ) + + def confirm_membership(self, expected_version, this_rank): + """Helper method for the confirm phase.""" + # Compare-and-swap loop + while True: + cas_delay() + active_version, state = self.get_rdzv_state() + + if state["status"] != "frozen": + raise EtcdRendezvousRetryImmediately( + "Rendezvous no longer frozen, before we confirmed. " + "Must join next one" + ) + if state["version"] != expected_version: + raise EtcdRendezvousRetryImmediately( + "Rendezvous version changed. Must try join the new one." + ) + + this_lease_key = self.get_path( + f"/rdzv/v_{expected_version}/rank_{this_rank}" + ) + self.client.set(this_lease_key, value=None, ttl=CONST_WORKER_KEEPALIVE_TTL) + + state["keep_alives"].append(this_lease_key) + if len(state["keep_alives"]) == len(state["participants"]): + # Everyone confirmed (this rank is last to do so) + state["status"] = "final" + state["num_workers_waiting"] = 0 + finalize = True + else: + finalize = False + + try: + # Compare-and-swap. If new state is still frozen, keep it ephemeral. + active_version = self.client.test_and_set( + key=self.get_path("/rdzv/active_version"), + value=json.dumps(state), + prev_value=active_version.value, + ttl=None if finalize else CONST_ETCD_FROZEN_TTL, + ) + + self._lease_this_rank_stop = self.setup_lease_renewal( + this_lease_key, ttl=CONST_WORKER_KEEPALIVE_TTL + ) + return active_version + + except etcd.EtcdCompareFailed: + logger.info("Confirm membership CAS unsuccessful, retrying") + + def wait_for_final(self, expected_version): + """Helper method for the confirm phase.""" + active_version, state = self.get_rdzv_state() + while True: + if state["status"] == "final" and state["version"] == expected_version: + # Success. This rendezvous is final, and we accept it. + return active_version + + elif state["status"] == "frozen" and state["version"] == expected_version: + # Continue waiting for any interesting events. + active_version, state = self.try_wait_for_state_change( + etcd_index=active_version.etcd_index + 1 + ) + + else: + # No valid transition possible at this point + raise EtcdRendezvousRetryableFailure( + "Rendezvous state transition no longer possible. Must re-enter." + ) + + def announce_self_waiting(self, expected_version): + """ + Announce this worker is waiting (via num_workers_waiting counter) to join next + rendezvous, but only if state and version match. + """ + while True: + cas_delay() + active_version, state = self.get_rdzv_state() + + if state["status"] != "final" or state["version"] != expected_version: + raise EtcdRendezvousRetryImmediately + + # Increment counter to signal an additional waiting worker. + state["num_workers_waiting"] += 1 + + try: + active_version = self.client.test_and_set( + key=self.get_path("/rdzv/active_version"), + value=json.dumps(state), + prev_value=active_version.value, + ) + return active_version + + except etcd.EtcdCompareFailed: + logger.info("Announce self as waiting CAS unsuccessful, retrying") + + def wait_for_rendezvous_to_free(self, expected_version): + """ + When there's an existing valid rendezvous in state 'final', we have to wait until the next opportunity to join. + + Such opportunity may come from: + + 1. rendezvous state changed by someone else, in which case we unblock and retry. + 2. rendezvous becomes invalid because at least one member failed to renew their + leased keep_alive node. We detect this, and destroy the rendezvous. + """ + active_version, state = self.get_rdzv_state() + while True: + if state["status"] != "final" or state["version"] != expected_version: + return + + # Check if current rendezvous state is valid, in the sense that all + # its members are alive (renewing their lease). + # If not, try destroy this rendezvous, so a new one can be created. + alive_members = self.client.get( + self.get_path(f"/rdzv/v_{expected_version}") + ) + keep_alive_keys = [ch.key for ch in alive_members.children] + + for key in state["keep_alives"]: + if key not in keep_alive_keys: + # This participant didn't renew their lease. We'll declare this + # rendezvous version as dead (but only if it hadn't changed) + logger.info("Keep-alive key %s is not renewed.", key) + logger.info( + "Rendezvous version %s is incomplete. ", expected_version + ) + logger.info("Attempting to destroy it.") + + # Compare-and-delete operation. Throws if compare failed, + # which means rendezvous was already destroyed/re-created/closed, + # and we can try to re-enter the barrier. + self.client.delete( + key=self.get_path("/rdzv/active_version"), + prevValue=active_version.value, + ) + + logger.info( + "Destroyed rendezvous version %s successfully.", + expected_version, + ) + + # We can return (and retry) immediately + return + + # Existing rendezvous seems valid, no reason to destroy it. + # We just have to wait until something changes and re-check. + try: + overall_timeout = ( + max(self._rendezvous_deadline - time.time(), 0.0) + 1.0 + ) + self.client.watch( + key=self.get_path("/rdzv"), + index=active_version.etcd_index + 1, + recursive=True, + timeout=overall_timeout, + ) + except (etcd.EtcdEventIndexCleared, etcd.EtcdWatchTimedOut): + pass + + if time.time() > self._rendezvous_deadline: + raise RendezvousTimeoutError + active_version, state = self.get_rdzv_state() + + def handle_join_last_call(self, expected_version, deadline): + """ + After we reach min number of workers, one particular worker takes on the + responsibility of waiting an additional timeout before closing the join window. + If the worker responsible for this fails, the rendezvous will be destroyed due + to expiring TTL, and the other participants will re-rendezvous. + + Here we expect to see state + Exit gracefully if either: + + 1. state becomes + 2. timeout happens (reaching deadline), in which case + we try the transition to + + Exit with exception otherwise. + """ + active_version, state = self.get_rdzv_state() + while True: + if state["status"] == "frozen" and state["version"] == expected_version: + # Worker set became frozen before last-call timeout. This is possible + # when num_max_workers is reached before the timeout. + return + + if state["status"] != "joinable" or state["version"] != expected_version: + raise EtcdRendezvousRetryableFailure( + "Rendezvous state transition no longer possible. Must re-enter." + ) + + # If timeout occurred, attempt a state transition (joinable -> frozen) + if time.time() >= deadline: + state["status"] = "frozen" + state["keep_alives"] = [] + try: + active_version = self.client.test_and_set( + key=self.get_path("/rdzv/active_version"), + value=json.dumps(state), + prev_value=active_version.value, + ttl=CONST_ETCD_FROZEN_TTL, + ) + # We successfully made this rendezvous frozen. + return + except etcd.EtcdCompareFailed: + logger.info( + "Join last-call transition CAS unsuccessful. Will retry" + ) + cas_delay() + active_version, state = self.get_rdzv_state() + continue + + # Timeout did not occur, so we must refresh TTL, and wait for + # further changes. Note: we only want TTL to be refreshed if + # state is still joinable, hence we use CAS for that here, + # even though we don't change any of the data. + try: + active_version = self.client.test_and_set( + key=self.get_path("/rdzv/active_version"), + value=active_version.value, + prev_value=active_version.value, + ttl=CONST_ETCD_JOINABLE_EPHEMERAL_TTL, + ) + + # Minimize "oversleeping": + timeout = min( + CONST_ETCD_JOINABLE_EPHEMERAL_TTL / 2, + deadline - time.time() + 1.0, # Oversleeping by 1s is ok. + ) + active_version, state = self.try_wait_for_state_change( + etcd_index=active_version.etcd_index + 1, timeout=timeout + ) + except etcd.EtcdCompareFailed: + logger.info("Join last-call TTL refresh CAS unsuccessful, will retry") + cas_delay() + active_version, state = self.get_rdzv_state() + + def set_closed(self): + """ + Mark rendezvous 'closed' for current run_id, which is used to signal other + participants to not attempt to perform (re-)rendezvous. This is useful + when one of the workers decides the job is complete. + """ + while True: + active_version, state = self.get_rdzv_state() + + if state["status"] == "closed": + # Already closed by someone else. + return + + state["status"] = "closed" + try: + self.client.test_and_set( + key=self.get_path("/rdzv/active_version"), + value=json.dumps(state), + prev_value=active_version.value, + ) + return + + except etcd.EtcdCompareFailed: + logger.info("Set closed CAS unsuccessful, retrying") + cas_delay() + + def get_rdzv_state(self): + active_version = self.client.get(key=self.get_path("/rdzv/active_version")) + return active_version, json.loads(active_version.value) + + def try_wait_for_state_change(self, etcd_index, timeout=None): + # Don't sleep past the overall deadline (at least more than by 1s) + overall_timeout = max(self._rendezvous_deadline - time.time(), 0.0) + 1.0 + timeout = overall_timeout if timeout is None else min(timeout, overall_timeout) + + try: + self.client.watch( + self.get_path("/rdzv/active_version"), index=etcd_index, timeout=timeout + ) + except (etcd.EtcdEventIndexCleared, etcd.EtcdWatchTimedOut): + pass + + if time.time() > self._rendezvous_deadline: + raise RendezvousTimeoutError + + # Unfortunately, we have to do another fetch in order to get last etcd_index. + return self.get_rdzv_state() + + def get_path(self, path): + if not path.startswith("/"): + path = "/" + path + + return f"{self._prefix}run_{self._run_id}{path}" + + def create_path_if_not_exists(self, full_path, ttl=None): + try: + self.client.write( + key=full_path, value=None, dir=True, prevExist=False, ttl=ttl + ) + except etcd.EtcdAlreadyExist: + pass + + def setup_lease_renewal(self, full_path, ttl): + # NOTE: For ephemeral key TTL renewal (~lease) to work correctly, + # make sure you don't call any long-blocking methods that do not + # release the Python's GIL! An example of this is calling a pybind11 + # extension function that is blocking / long-running, but is not + # doing a scoped release of the GIL. + def lease_worker(client, path, ttl, stop_event): + while True: + try: + client.refresh(path, ttl=ttl) + except etcd.EtcdKeyNotFound: + break + except ConnectionRefusedError: + # This error usually occurs during test when the server already got terminated but the + # python garbage collector have not yet invoked the __del__ method. + break + + if stop_event.wait(timeout=ttl / 2): + break + + lease_stop_event = threading.Event() + lease_thread = threading.Thread( + target=lease_worker, args=(self.client, full_path, ttl, lease_stop_event) + ) + + lease_thread.daemon = True + lease_thread.start() + + return lease_stop_event + + def store_extra_data(self, rdzv_version, key, value): + node = self.get_path(f"/rdzv/v_{rdzv_version}/extra_data") + try: + # If first time we are storing anything: + extra_data = self.client.write( + key=node, value=json.dumps({key: value}), prevExist=False + ) + return + except etcd.EtcdAlreadyExist: + pass + + # CAS loop, to make sure we don't lose concurrent stores. + while True: + # We never delete extra_data. Failure here should be fatal, no special handling. + extra_data = self.client.get(node) + + new_extra_data_value = json.loads(extra_data.value) + new_extra_data_value[key] = value + + try: + extra_data = self.client.test_and_set( + key=node, + value=json.dumps(new_extra_data_value), + prev_value=extra_data.value, + ) + return + except etcd.EtcdCompareFailed: + logger.info("Store extra_data CAS unsuccessful, retrying") + time.sleep(0.1) + + def load_extra_data(self, rdzv_version, key, timeout=None): + # 'extra_data' node itself, and the directory it is located in: + node = self.get_path(f"/rdzv/v_{rdzv_version}/extra_data") + node_dir = self.get_path(f"/rdzv/v_{rdzv_version}") + + # TODO: implement timeout + # https://github.com/pytorch/elastic/issues/12 + while True: + # Combined wait for the node itself, and the key inside it. + root = self.client.get(node_dir) + + # Find the extra_data node, if it exists + extra_data = [n for n in root.children if n.key == node] + assert len(extra_data) <= 1 + + # Node for extra_data exists, check the desired key inside it. + if len(extra_data) == 1: + extra_data_dict = json.loads(extra_data[0].value) + if key in extra_data_dict: + return extra_data_dict[key] + + # The 'extra_data' node doesn't exist, or they key isn't published yet. + # Wait for interesting events on the extra_data node and retry. + try: + self.client.watch(node, index=root.etcd_index + 1) + except (etcd.EtcdEventIndexCleared, etcd.EtcdWatchTimedOut): + pass + + def setup_kv_store(self, rdzv_version): + store_path = self.get_path(f"/rdzv/v_{rdzv_version}/kv") + self.create_path_if_not_exists(store_path) + return EtcdStore(etcd_client=self.client, etcd_store_prefix=store_path) + + +def _create_etcd_client(params: RendezvousParameters) -> etcd.Client: + """Create a new ``etcd.Client`` from the specified ``RendezvousParameters``.""" + hostname, port = parse_rendezvous_endpoint(params.endpoint, 2379) + + # The communication protocol + protocol = params.config.get("protocol") + if protocol is None: + protocol = "http" + else: + if protocol != "http" and protocol != "https": + raise ValueError("The etcd protocol must be HTTP or HTTPS.") + + # The SSL client certificate + ssl_cert = params.config.get("cert") + if ssl_cert is not None: + cert_key = params.config.get("key") + if cert_key is not None: + # The etcd client expects the certificate key as the second element + # of the `cert` tuple. + ssl_cert = (ssl_cert, cert_key) + + # The root certificate + ca_cert = params.config.get("cacert") + + return etcd.Client( + hostname, + port, + protocol=protocol, + cert=ssl_cert, + ca_cert=ca_cert, + allow_reconnect=True, + ) + + +# Handler for core.distributed "static" registration +def create_rdzv_handler(params: RendezvousParameters) -> RendezvousHandler: + """ + Usage: + + :: + + rdzv_params = RendezvousParameters( + backend="etcd", + endpoint="192.168.0.42:2379", + run_id="123", + min_nodes=4, + max_nodes=8, + timeout=300, + last_call_timeout=30, + etcd_prefix="custom_prefix", + protocol="https", + cacert="/etc/kubernetes/certs/ca.crt", + cert="/etc/kubernetes/certs/client.crt", + key="/etc/kubernetes/certs/client.key") + # -- or -- + rdzv_params = RendezvousParameters( + backend="etcd", + endpoint="192.168.0.42:2379", + run_id="123", + min_nodes=4, + max_nodes=8) + + etcd_rdzv_handler = create_etcd_rendezvous_handler(rdzv_params) + + + Where: + run_id - unique id for this training job instance, + min_nodes - min number of workers expected to join the rendezvous, + max_nodes - max number of workers allowed to join the rendezvous, + defaults to min_workers is not specified. + timeout - total timeout within which next_rendezvous is expected to + succeed; a RendezvousTimeoutError is raised otherwise; + Defaults is 600 (10 minutes). + last_call_timeout - additional wait amount ("last call") after + min number of workers has been reached. + Defaults to 30 seconds. + etcd_prefix - path prefix (from etcd root), inside which all + etcd nodes will be created. + Default is "/torchelastic/p2p". + protocol - http (default) or https to access etcd. + cacert - CA cert to access etcd, only makes sense with https. + cert - client cert to access etcd, only makes sense with https. + key - client key to access etcd, only makes sense with https. + """ + client = _create_etcd_client(params) + + etcd_prefix = params.get("etcd_prefix", "/torchelastic/p2p") + + rdzv = EtcdRendezvous( + client=client, + prefix=etcd_prefix, + run_id=params.run_id, + num_min_workers=params.min_nodes, + num_max_workers=params.max_nodes, + timeout=params.get_as_int("timeout", _DEFAULT_TIMEOUT), + last_call_timeout=params.get_as_int( + "last_call_timeout", _DEFAULT_LAST_CALL_TIMEOUT + ), + ) + return EtcdRendezvousHandler( + rdzv_impl=rdzv, + local_addr=params.local_addr, + ) diff --git a/mindnlp/core/distributed/elastic/rendezvous/etcd_rendezvous_backend.py b/mindnlp/core/distributed/elastic/rendezvous/etcd_rendezvous_backend.py new file mode 100644 index 000000000..82c109bce --- /dev/null +++ b/mindnlp/core/distributed/elastic/rendezvous/etcd_rendezvous_backend.py @@ -0,0 +1,217 @@ +# mypy: allow-untyped-defs +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import binascii +from base64 import b64decode, b64encode +from typing import cast, Optional, Tuple + +import urllib3.exceptions # type: ignore[import] +from etcd import ( # type: ignore[import] + Client as EtcdClient, + EtcdAlreadyExist, + EtcdCompareFailed, + EtcdException, + EtcdKeyNotFound, + EtcdResult, +) + +from core.distributed import Store + +from .api import RendezvousConnectionError, RendezvousParameters, RendezvousStateError +from .dynamic_rendezvous import RendezvousBackend, Token +from .etcd_store import EtcdStore +from .utils import parse_rendezvous_endpoint + + +class EtcdRendezvousBackend(RendezvousBackend): + """Represents an etcd-based rendezvous backend. + + Args: + client: + The ``etcd.Client`` instance to use to communicate with etcd. + run_id: + The run id of the rendezvous. + key_prefix: + The path under which to store the rendezvous state in etcd. + ttl: + The TTL of the rendezvous state. If not specified, defaults to two hours. + """ + + _DEFAULT_TTL = 7200 # 2 hours + + _client: EtcdClient + _key: str + _ttl: int + + def __init__( + self, + client: EtcdClient, + run_id: str, + key_prefix: Optional[str] = None, + ttl: Optional[int] = None, + ) -> None: + if not run_id: + raise ValueError("The run id must be a non-empty string.") + + self._client = client + + if key_prefix: + self._key = key_prefix + "/" + run_id + else: + self._key = run_id + + if ttl and ttl > 0: + self._ttl = ttl + else: + self._ttl = self._DEFAULT_TTL + + @property + def name(self) -> str: + """See base class.""" + return "etcd-v2" + + def get_state(self) -> Optional[Tuple[bytes, Token]]: + """See base class.""" + try: + result = self._client.read(self._key) + except EtcdKeyNotFound: + return None + except (EtcdException, urllib3.exceptions.TimeoutError) as exc: + raise RendezvousConnectionError( + "The connection to etcd has failed. See inner exception for details." + ) from exc + + return self._decode_state(result) + + def set_state( + self, state: bytes, token: Optional[Token] = None + ) -> Optional[Tuple[bytes, Token, bool]]: + """See base class.""" + base64_state = b64encode(state).decode() + + kwargs = {} + + def get_state(): + result = self.get_state() + if result is not None: + tmp = *result, False + # Python 3.6 does not support tuple unpacking in return + # statements. + return tmp + return None + + if token: + try: + token = int(token) + except ValueError: + return get_state() + + if token: + kwargs["prevIndex"] = token + else: + kwargs["prevExist"] = False + + try: + result = self._client.write(self._key, base64_state, self._ttl, **kwargs) + except (EtcdAlreadyExist, EtcdCompareFailed): + result = None + except (EtcdException, urllib3.exceptions.TimeoutError) as exc: + raise RendezvousConnectionError( + "The connection to etcd has failed. See inner exception for details." + ) from exc + + if result is None: + return get_state() + + tmp = *self._decode_state(result), True + return tmp + + def _decode_state(self, result: EtcdResult) -> Tuple[bytes, Token]: + base64_state = result.value.encode() + + try: + state = b64decode(base64_state) + except binascii.Error as exc: + raise RendezvousStateError( + "The state object is corrupt. See inner exception for details." + ) from exc + + return state, result.modifiedIndex + + +def _create_etcd_client(params: RendezvousParameters) -> EtcdClient: + host, port = parse_rendezvous_endpoint(params.endpoint, default_port=2379) + + # The timeout + read_timeout = cast(int, params.get_as_int("read_timeout", 60)) + if read_timeout <= 0: + raise ValueError("The read timeout must be a positive integer.") + + # The communication protocol + protocol = params.get("protocol", "http").strip().lower() + if protocol != "http" and protocol != "https": + raise ValueError("The protocol must be HTTP or HTTPS.") + + # The SSL client certificate + ssl_cert = params.get("ssl_cert") + if ssl_cert: + ssl_cert_key = params.get("ssl_cert_key") + if ssl_cert_key: + # The etcd client expects the certificate key as the second element + # of the `cert` tuple. + ssl_cert = (ssl_cert, ssl_cert_key) + + # The root certificate + ca_cert = params.get("ca_cert") + + try: + return EtcdClient( + host, + port, + read_timeout=read_timeout, + protocol=protocol, + cert=ssl_cert, + ca_cert=ca_cert, + allow_reconnect=True, + ) + except (EtcdException, urllib3.exceptions.TimeoutError) as exc: + raise RendezvousConnectionError( + "The connection to etcd has failed. See inner exception for details." + ) from exc + + +def create_backend(params: RendezvousParameters) -> Tuple[EtcdRendezvousBackend, Store]: + """Create a new :py:class:`EtcdRendezvousBackend` from the specified parameters. + + +--------------+-----------------------------------------------------------+ + | Parameter | Description | + +==============+===========================================================+ + | read_timeout | The read timeout, in seconds, for etcd operations. | + | | Defaults to 60 seconds. | + +--------------+-----------------------------------------------------------+ + | protocol | The protocol to use to communicate with etcd. Valid | + | | values are "http" and "https". Defaults to "http". | + +--------------+-----------------------------------------------------------+ + | ssl_cert | The path to the SSL client certificate to use along with | + | | HTTPS. Defaults to ``None``. | + +--------------+-----------------------------------------------------------+ + | ssl_cert_key | The path to the private key of the SSL client certificate | + | | to use along with HTTPS. Defaults to ``None``. | + +--------------+-----------------------------------------------------------+ + | ca_cert | The path to the rool SSL authority certificate. Defaults | + | | to ``None``. | + +--------------+-----------------------------------------------------------+ + """ + client = _create_etcd_client(params) + + backend = EtcdRendezvousBackend( + client, params.run_id, key_prefix="/torch/elastic/rendezvous" + ) + + store = EtcdStore(client, "/torch/elastic/store") + + return backend, store diff --git a/mindnlp/core/distributed/elastic/rendezvous/etcd_server.py b/mindnlp/core/distributed/elastic/rendezvous/etcd_server.py new file mode 100644 index 000000000..99623e0bb --- /dev/null +++ b/mindnlp/core/distributed/elastic/rendezvous/etcd_server.py @@ -0,0 +1,248 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import atexit +import logging +import os +import shlex +import shutil +import socket +import subprocess +import tempfile +import time +from typing import Optional, TextIO, Union + + +try: + import etcd # type: ignore[import] +except ModuleNotFoundError: + pass + + +logger = logging.getLogger(__name__) + + +def find_free_port(): + """ + Find a free port and binds a temporary socket to it so that the port can be "reserved" until used. + + .. note:: the returned socket must be closed before using the port, + otherwise a ``address already in use`` error will happen. + The socket should be held and closed as close to the + consumer of the port as possible since otherwise, there + is a greater chance of race-condition where a different + process may see the port as being free and take it. + + Returns: a socket binded to the reserved free port + + Usage:: + + sock = find_free_port() + port = sock.getsockname()[1] + sock.close() + use_port(port) + """ + addrs = socket.getaddrinfo( + host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM + ) + + for addr in addrs: + family, type, proto, _, _ = addr + try: + s = socket.socket(family, type, proto) + s.bind(("localhost", 0)) + s.listen(0) + return s + except OSError as e: + s.close() # type: ignore[possibly-undefined] + print(f"Socket creation attempt failed: {e}") + raise RuntimeError("Failed to create a socket") + + +def stop_etcd(subprocess, data_dir: Optional[str] = None): + if subprocess and subprocess.poll() is None: + logger.info("stopping etcd server") + subprocess.terminate() + subprocess.wait() + + if data_dir: + logger.info("deleting etcd data dir: %s", data_dir) + shutil.rmtree(data_dir, ignore_errors=True) + + +class EtcdServer: + """ + .. note:: tested on etcd server v3.4.3. + + Starts and stops a local standalone etcd server on a random free + port. Useful for single node, multi-worker launches or testing, + where a sidecar etcd server is more convenient than having to + separately setup an etcd server. + + This class registers a termination handler to shutdown the etcd + subprocess on exit. This termination handler is NOT a substitute for + calling the ``stop()`` method. + + The following fallback mechanism is used to find the etcd binary: + + 1. Uses env var TORCHELASTIC_ETCD_BINARY_PATH + 2. Uses ``/bin/etcd`` if one exists + 3. Uses ``etcd`` from ``PATH`` + + Usage + :: + + server = EtcdServer("/usr/bin/etcd", 2379, "/tmp/default.etcd") + server.start() + client = server.get_client() + # use client + server.stop() + + Args: + etcd_binary_path: path of etcd server binary (see above for fallback path) + """ + + def __init__(self, data_dir: Optional[str] = None): + self._port = -1 + self._host = "localhost" + + root = os.path.dirname(__file__) + default_etcd_bin = os.path.join(root, "bin/etcd") + self._etcd_binary_path = os.environ.get( + "TORCHELASTIC_ETCD_BINARY_PATH", default_etcd_bin + ) + if not os.path.isfile(self._etcd_binary_path): + self._etcd_binary_path = "etcd" + + self._base_data_dir = ( + data_dir if data_dir else tempfile.mkdtemp(prefix="torchelastic_etcd_data") + ) + self._etcd_cmd = None + self._etcd_proc: Optional[subprocess.Popen] = None + + def _get_etcd_server_process(self) -> subprocess.Popen: + if not self._etcd_proc: + raise RuntimeError( + "No etcd server process started. Call etcd_server.start() first" + ) + else: + return self._etcd_proc + + def get_port(self) -> int: + """Return the port the server is running on.""" + return self._port + + def get_host(self) -> str: + """Return the host the server is running on.""" + return self._host + + def get_endpoint(self) -> str: + """Return the etcd server endpoint (host:port).""" + return f"{self._host}:{self._port}" + + def start( + self, + timeout: int = 60, + num_retries: int = 3, + stderr: Union[int, TextIO, None] = None, + ) -> None: + """ + Start the server, and waits for it to be ready. When this function returns the sever is ready to take requests. + + Args: + timeout: time (in seconds) to wait for the server to be ready + before giving up. + num_retries: number of retries to start the server. Each retry + will wait for max ``timeout`` before considering it as failed. + stderr: the standard error file handle. Valid values are + `subprocess.PIPE`, `subprocess.DEVNULL`, an existing file + descriptor (a positive integer), an existing file object, and + `None`. + + Raises: + TimeoutError: if the server is not ready within the specified timeout + """ + curr_retries = 0 + while True: + try: + data_dir = os.path.join(self._base_data_dir, str(curr_retries)) + os.makedirs(data_dir, exist_ok=True) + return self._start(data_dir, timeout, stderr) + except Exception as e: + curr_retries += 1 + stop_etcd(self._etcd_proc) + logger.warning( + "Failed to start etcd server, got error: %s, retrying", str(e) + ) + if curr_retries >= num_retries: + shutil.rmtree(self._base_data_dir, ignore_errors=True) + raise + atexit.register(stop_etcd, self._etcd_proc, self._base_data_dir) + + def _start( + self, data_dir: str, timeout: int = 60, stderr: Union[int, TextIO, None] = None + ) -> None: + sock = find_free_port() + sock_peer = find_free_port() + self._port = sock.getsockname()[1] + peer_port = sock_peer.getsockname()[1] + + etcd_cmd = shlex.split( + " ".join( + [ + self._etcd_binary_path, + "--enable-v2", + "--data-dir", + data_dir, + "--listen-client-urls", + f"http://{self._host}:{self._port}", + "--advertise-client-urls", + f"http://{self._host}:{self._port}", + "--listen-peer-urls", + f"http://{self._host}:{peer_port}", + ] + ) + ) + + logger.info("Starting etcd server: [%s]", etcd_cmd) + + sock.close() + sock_peer.close() + self._etcd_proc = subprocess.Popen(etcd_cmd, close_fds=True, stderr=stderr) + self._wait_for_ready(timeout) + + def get_client(self): + """Return an etcd client object that can be used to make requests to this server.""" + return etcd.Client( + host=self._host, port=self._port, version_prefix="/v2", read_timeout=10 + ) + + def _wait_for_ready(self, timeout: int = 60) -> None: + client = etcd.Client( + host=f"{self._host}", port=self._port, version_prefix="/v2", read_timeout=5 + ) + max_time = time.time() + timeout + + while time.time() < max_time: + if self._get_etcd_server_process().poll() is not None: + # etcd server process finished + exitcode = self._get_etcd_server_process().returncode + raise RuntimeError( + f"Etcd server process exited with the code: {exitcode}" + ) + try: + logger.info("etcd server ready. version: %s", client.version) + return + except Exception: + time.sleep(1) + raise TimeoutError("Timed out waiting for etcd server to be ready!") + + def stop(self) -> None: + """Stop the server and cleans up auto generated resources (e.g. data dir).""" + logger.info("EtcdServer stop method called") + stop_etcd(self._etcd_proc, self._base_data_dir) diff --git a/mindnlp/core/distributed/elastic/rendezvous/etcd_store.py b/mindnlp/core/distributed/elastic/rendezvous/etcd_store.py new file mode 100644 index 000000000..b4155a6f6 --- /dev/null +++ b/mindnlp/core/distributed/elastic/rendezvous/etcd_store.py @@ -0,0 +1,212 @@ +# mypy: allow-untyped-defs +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import datetime +import random +import time +from base64 import b64decode, b64encode +from typing import Optional + +import etcd # type: ignore[import] + +# pyre-ignore[21]: Could not find name `Store` in `core.distributed`. +from core.distributed import Store + + +# Delay (sleep) for a small random amount to reduce CAS failures. +# This does not affect correctness, but will reduce requests to etcd server. +def cas_delay(): + time.sleep(random.uniform(0, 0.1)) + + +# pyre-fixme[11]: Annotation `Store` is not defined as a type. +class EtcdStore(Store): + """ + Implement a c10 Store interface by piggybacking on the rendezvous etcd instance. + + This is the store object returned by ``EtcdRendezvous``. + """ + + def __init__( + self, + etcd_client, + etcd_store_prefix, + # Default timeout same as in c10d/Store.hpp + timeout: Optional[datetime.timedelta] = None, + ): + super().__init__() # required for pybind trampoline. + + self.client = etcd_client + self.prefix = etcd_store_prefix + + if timeout is not None: + self.set_timeout(timeout) + + if not self.prefix.endswith("/"): + self.prefix += "/" + + def set(self, key, value): + """ + Write a key/value pair into ``EtcdStore``. + + Both key and value may be either Python ``str`` or ``bytes``. + """ + self.client.set(key=self.prefix + self._encode(key), value=self._encode(value)) + + def get(self, key) -> bytes: + """ + Get a value by key, possibly doing a blocking wait. + + If key is not immediately present, will do a blocking wait + for at most ``timeout`` duration or until the key is published. + + + Returns: + value ``(bytes)`` + + Raises: + LookupError - If key still not published after timeout + """ + b64_key = self.prefix + self._encode(key) + kvs = self._try_wait_get([b64_key]) + + if kvs is None: + raise LookupError(f"Key {key} not found in EtcdStore") + + return self._decode(kvs[b64_key]) + + def add(self, key, num: int) -> int: + """ + Atomically increment a value by an integer amount. + + The integer is represented as a string using base 10. If key is not present, + a default value of ``0`` will be assumed. + + Returns: + the new (incremented) value + + + """ + b64_key = self._encode(key) + # c10d Store assumes value is an integer represented as a decimal string + try: + # Assume default value "0", if this key didn't yet: + node = self.client.write( + key=self.prefix + b64_key, + value=self._encode(str(num)), # i.e. 0 + num + prevExist=False, + ) + return int(self._decode(node.value)) + except etcd.EtcdAlreadyExist: + pass + + while True: + # Note: c10d Store does not have a method to delete keys, so we + # can be sure it's still there. + node = self.client.get(key=self.prefix + b64_key) + new_value = self._encode(str(int(self._decode(node.value)) + num)) + try: + node = self.client.test_and_set( + key=node.key, value=new_value, prev_value=node.value + ) + return int(self._decode(node.value)) + except etcd.EtcdCompareFailed: + cas_delay() + + def wait(self, keys, override_timeout: Optional[datetime.timedelta] = None): + """ + Wait until all of the keys are published, or until timeout. + + Raises: + LookupError - if timeout occurs + """ + b64_keys = [self.prefix + self._encode(key) for key in keys] + kvs = self._try_wait_get(b64_keys, override_timeout) + if kvs is None: + raise LookupError("Timeout while waiting for keys in EtcdStore") + # No return value on success + + def check(self, keys) -> bool: + """Check if all of the keys are immediately present (without waiting).""" + b64_keys = [self.prefix + self._encode(key) for key in keys] + kvs = self._try_wait_get( + b64_keys, + override_timeout=datetime.timedelta(microseconds=1), # as if no wait + ) + return kvs is not None + + # + # Encode key/value data in base64, so we can store arbitrary binary data + # in EtcdStore. Input can be `str` or `bytes`. + # In case of `str`, utf-8 encoding is assumed. + # + def _encode(self, value) -> str: + if type(value) == bytes: + return b64encode(value).decode() + elif type(value) == str: + return b64encode(value.encode()).decode() + raise ValueError("Value must be of type str or bytes") + + # + # Decode a base64 string (of type `str` or `bytes`). + # Return type is `bytes`, which is more convenient with the Store interface. + # + def _decode(self, value) -> bytes: + if type(value) == bytes: + return b64decode(value) + elif type(value) == str: + return b64decode(value.encode()) + raise ValueError("Value must be of type str or bytes") + + # + # Get all of the (base64-encoded) etcd keys at once, or wait until all the keys + # are published or timeout occurs. + # This is a helper method for the public interface methods. + # + # On success, a dictionary of {etcd key -> etcd value} is returned. + # On timeout, None is returned. + # + def _try_wait_get(self, b64_keys, override_timeout=None): + timeout = self.timeout if override_timeout is None else override_timeout # type: ignore[attr-defined] + deadline = time.time() + timeout.total_seconds() + + while True: + # Read whole directory (of keys), filter only the ones waited for + all_nodes = None + try: + all_nodes = self.client.get(key=self.prefix) + req_nodes = { + node.key: node.value + for node in all_nodes.children + if node.key in b64_keys + } + + if len(req_nodes) == len(b64_keys): + # All keys are available + return req_nodes + except etcd.EtcdKeyNotFound: + pass + + watch_timeout = deadline - time.time() + if watch_timeout <= 0: + return None + + try: + index = all_nodes.etcd_index + 1 if all_nodes else 0 + self.client.watch( + key=self.prefix, + recursive=True, + timeout=watch_timeout, + index=index, + ) + except etcd.EtcdWatchTimedOut: + if time.time() >= deadline: + return None + else: + continue + except etcd.EtcdEventIndexCleared: + continue diff --git a/mindnlp/core/distributed/elastic/rendezvous/registry.py b/mindnlp/core/distributed/elastic/rendezvous/registry.py new file mode 100644 index 000000000..b6ba47a86 --- /dev/null +++ b/mindnlp/core/distributed/elastic/rendezvous/registry.py @@ -0,0 +1,96 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import sys + +from .api import ( + rendezvous_handler_registry as handler_registry, + RendezvousHandler, + RendezvousParameters, +) +from .dynamic_rendezvous import create_handler + + +if sys.version_info < (3, 10): + from importlib_metadata import entry_points +else: + from importlib.metadata import entry_points + +log = logging.getLogger(__name__) + +__all__ = ["get_rendezvous_handler"] + + +def _create_static_handler(params: RendezvousParameters) -> RendezvousHandler: + from . import static_tcp_rendezvous + + return static_tcp_rendezvous.create_rdzv_handler(params) + + +def _create_etcd_handler(params: RendezvousParameters) -> RendezvousHandler: + from . import etcd_rendezvous + + return etcd_rendezvous.create_rdzv_handler(params) + + +def _create_etcd_v2_handler(params: RendezvousParameters) -> RendezvousHandler: + from .etcd_rendezvous_backend import create_backend + + backend, store = create_backend(params) + + return create_handler(store, backend, params) + + +def _create_c10d_handler(params: RendezvousParameters) -> RendezvousHandler: + from .c10d_rendezvous_backend import create_backend + + backend, store = create_backend(params) + + return create_handler(store, backend, params) + + +def _register_default_handlers() -> None: + handler_registry.register("etcd", _create_etcd_handler) + handler_registry.register("etcd-v2", _create_etcd_v2_handler) + handler_registry.register("c10d", _create_c10d_handler) + handler_registry.register("static", _create_static_handler) + + +def _register_out_of_tree_handlers() -> None: + discovered_handler_generators = entry_points(group="torchrun.handlers") + + for handler_generator in discovered_handler_generators: + try: + get_handler = discovered_handler_generators[handler_generator.name].load() + handler_registry.register(handler_generator.name, get_handler()) + except Exception: + log.warning( + "Exception while registering out of tree plugin %s: ", + handler_generator.name, + exc_info=True, + ) + + +def get_rendezvous_handler(params: RendezvousParameters) -> RendezvousHandler: + """ + Obtain a reference to a :py:class`RendezvousHandler`. + + Custom rendezvous handlers can be registered by + + :: + + from core.distributed.elastic.rendezvous import rendezvous_handler_registry + from core.distributed.elastic.rendezvous.registry import get_rendezvous_handler + + def create_my_rdzv(params: RendezvousParameters): + return MyCustomRdzv(params) + + rendezvous_handler_registry.register("my_rdzv_backend_name", create_my_rdzv) + + my_rdzv_handler = get_rendezvous_handler("my_rdzv_backend_name", RendezvousParameters) + """ + return handler_registry.create_handler(params) diff --git a/mindnlp/core/distributed/elastic/rendezvous/static_tcp_rendezvous.py b/mindnlp/core/distributed/elastic/rendezvous/static_tcp_rendezvous.py new file mode 100644 index 000000000..0ab6e641b --- /dev/null +++ b/mindnlp/core/distributed/elastic/rendezvous/static_tcp_rendezvous.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import datetime +import logging +from typing import cast, Optional + +from core.distributed import PrefixStore, Store, TCPStore +from core.distributed.elastic.rendezvous import ( + RendezvousHandler, + RendezvousInfo, + RendezvousParameters, + RendezvousStoreInfo, +) +from core.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint + + +__all__ = ["StaticTCPRendezvous", "create_rdzv_handler"] + +logger = logging.getLogger(__name__) + +_default_timeout_seconds = 600 + + +class StaticTCPRendezvous(RendezvousHandler): + """ + Static rendezvous that is a wrapper around the TCPStore. + + Creates TCPStore based on the input parameters with the + listener on the agent with group_rank=0 + """ + + def __init__( + self, + master_addr: str, + master_port: int, + rank: int, + world_size: int, + run_id: str, + timeout: int, + ): + self.master_addr = master_addr + self.master_port = master_port + self.rank = rank + self.world_size = world_size + self.run_id = run_id + self.timeout = datetime.timedelta(seconds=timeout) + self._store: Optional[Store] = None + + def get_backend(self) -> str: + return "static" + + @property + def use_agent_store(self) -> bool: + return True + + def next_rendezvous(self) -> RendezvousInfo: + logger.info("Creating TCPStore as the c10d::Store implementation") + is_master = self.rank == 0 + if not self._store: + self._store = TCPStore( # type: ignore[call-arg] + self.master_addr, + self.master_port, + self.world_size, + is_master, + self.timeout, + multi_tenant=True, + ) + store = PrefixStore(self.run_id, self._store) + # TCPStore server instance is used by trainer code + bootstrap_store_info = RendezvousStoreInfo(self.master_addr, self.master_port) + return RendezvousInfo( + store, + self.rank, + self.world_size, + bootstrap_store_info, + ) + + def is_closed(self): + return False + + def set_closed(self): + pass + + def num_nodes_waiting(self): + return 0 + + def get_run_id(self) -> str: + return self.run_id + + def shutdown(self) -> bool: + return True + + +def create_rdzv_handler(params: RendezvousParameters) -> RendezvousHandler: + if "rank" not in params.config: + raise ValueError( + "rank is absent in RendezvousParameters." + "Try add --node-rank to the cmd request" + ) + endpoint = params.endpoint.strip() + if not endpoint: + raise ValueError( + "endpoint is absent in RendezvousParameters" + "Try add --master-port and --master-addr to the cmd request" + ) + master_addr, master_port = parse_rendezvous_endpoint(endpoint, -1) + if master_port == -1: + raise ValueError( + f"Port is absent in endpoint: {endpoint}. Try launching with --master-port" + ) + world_size = params.max_nodes + rank = cast(int, params.config.get("rank")) + run_id = params.run_id + if "timeout" in params.config: + timeout = int(params.config["timeout"]) + else: + timeout = _default_timeout_seconds + + return StaticTCPRendezvous( + master_addr, master_port, rank, world_size, run_id, timeout + ) diff --git a/mindnlp/core/distributed/elastic/rendezvous/utils.py b/mindnlp/core/distributed/elastic/rendezvous/utils.py new file mode 100644 index 000000000..f946209a9 --- /dev/null +++ b/mindnlp/core/distributed/elastic/rendezvous/utils.py @@ -0,0 +1,284 @@ +# mypy: allow-untyped-defs +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import ipaddress +import random +import re +import socket +import time +import weakref +from datetime import timedelta +from threading import Event, Thread +from typing import Any, Callable, Dict, Optional, Tuple, Union + + +__all__ = ["parse_rendezvous_endpoint"] + + +def _parse_rendezvous_config(config_str: str) -> Dict[str, str]: + """Extract key-value pairs from a rendezvous configuration string. + + Args: + config_str: + A string in format =,...,=. + """ + config: Dict[str, str] = {} + + config_str = config_str.strip() + if not config_str: + return config + + key_values = config_str.split(",") + for kv in key_values: + key, *values = kv.split("=", 1) + + key = key.strip() + if not key: + raise ValueError( + "The rendezvous configuration string must be in format " + "=,...,=." + ) + + value: Optional[str] + if values: + value = values[0].strip() + else: + value = None + if not value: + raise ValueError( + f"The rendezvous configuration option '{key}' must have a value specified." + ) + + config[key] = value + return config + + +def _try_parse_port(port_str: str) -> Optional[int]: + """Try to extract the port number from ``port_str``.""" + if port_str and re.match(r"^[0-9]{1,5}$", port_str): + return int(port_str) + return None + + +def parse_rendezvous_endpoint( + endpoint: Optional[str], default_port: int +) -> Tuple[str, int]: + """Extract the hostname and the port number from a rendezvous endpoint. + + Args: + endpoint: + A string in format [:]. + default_port: + The port number to use if the endpoint does not include one. + + Returns: + A tuple of hostname and port number. + """ + if endpoint is not None: + endpoint = endpoint.strip() + + if not endpoint: + return ("localhost", default_port) + + # An endpoint that starts and ends with brackets represents an IPv6 address. + if endpoint[0] == "[" and endpoint[-1] == "]": + host, *rest = endpoint, *[] + else: + host, *rest = endpoint.rsplit(":", 1) + + # Sanitize the IPv6 address. + if len(host) > 1 and host[0] == "[" and host[-1] == "]": + host = host[1:-1] + + if len(rest) == 1: + port = _try_parse_port(rest[0]) + if port is None or port >= 2**16: + raise ValueError( + f"The port number of the rendezvous endpoint '{endpoint}' must be an integer " + "between 0 and 65536." + ) + else: + port = default_port + + if not re.match(r"^[\w\.:-]+$", host): + raise ValueError( + f"The hostname of the rendezvous endpoint '{endpoint}' must be a dot-separated list of " + "labels, an IPv4 address, or an IPv6 address." + ) + + return host, port + + +def _matches_machine_hostname(host: str) -> bool: + """Indicate whether ``host`` matches the hostname of this machine. + + This function compares ``host`` to the hostname as well as to the IP + addresses of this machine. Note that it may return a false negative if this + machine has CNAME records beyond its FQDN or IP addresses assigned to + secondary NICs. + """ + if host == "localhost": + return True + + try: + addr = ipaddress.ip_address(host) + except ValueError: + addr = None + + if addr and addr.is_loopback: + return True + + try: + host_addr_list = socket.getaddrinfo( + host, None, proto=socket.IPPROTO_TCP, flags=socket.AI_CANONNAME + ) + except (ValueError, socket.gaierror) as _: + host_addr_list = [] + + host_ip_list = [host_addr_info[4][0] for host_addr_info in host_addr_list] + + this_host = socket.gethostname() + if host == this_host: + return True + + addr_list = socket.getaddrinfo( + this_host, None, proto=socket.IPPROTO_TCP, flags=socket.AI_CANONNAME + ) + for addr_info in addr_list: + # If we have an FQDN in the addr_info, compare it to `host`. + if addr_info[3] and addr_info[3] == host: + return True + + # Otherwise if `host` represents an IP address, compare it to our IP + # address. + if addr and addr_info[4][0] == str(addr): + return True + + # If the IP address matches one of the provided host's IP addresses + if addr_info[4][0] in host_ip_list: + return True + + return False + + +def _delay(seconds: Union[float, Tuple[float, float]]) -> None: + """Suspend the current thread for ``seconds``. + + Args: + seconds: + Either the delay, in seconds, or a tuple of a lower and an upper + bound within which a random delay will be picked. + """ + if isinstance(seconds, tuple): + seconds = random.uniform(*seconds) + # Ignore delay requests that are less than 10 milliseconds. + if seconds >= 0.01: + time.sleep(seconds) + + +class _PeriodicTimer: + """Represent a timer that periodically runs a specified function. + + Args: + interval: + The interval, in seconds, between each run. + function: + The function to run. + """ + + # The state of the timer is hold in a separate context object to avoid a + # reference cycle between the timer and the background thread. + class _Context: + interval: float + function: Callable[..., None] + args: Tuple[Any, ...] + kwargs: Dict[str, Any] + stop_event: Event + + _name: Optional[str] + _thread: Optional[Thread] + _finalizer: Optional[weakref.finalize] + + # The context that is shared between the timer and the background thread. + _ctx: _Context + + def __init__( + self, + interval: timedelta, + function: Callable[..., None], + *args: Any, + **kwargs: Any, + ) -> None: + self._name = None + + self._ctx = self._Context() + self._ctx.interval = interval.total_seconds() + self._ctx.function = function # type: ignore[assignment] + self._ctx.args = args or () + self._ctx.kwargs = kwargs or {} + self._ctx.stop_event = Event() + + self._thread = None + self._finalizer = None + + @property + def name(self) -> Optional[str]: + """Get the name of the timer.""" + return self._name + + def set_name(self, name: str) -> None: + """Set the name of the timer. + + The specified name will be assigned to the background thread and serves + for debugging and troubleshooting purposes. + """ + if self._thread: + raise RuntimeError("The timer has already started.") + + self._name = name + + def start(self) -> None: + """Start the timer.""" + if self._thread: + raise RuntimeError("The timer has already started.") + + self._thread = Thread( + target=self._run, + name=self._name or "PeriodicTimer", + args=(self._ctx,), + daemon=True, + ) + + # We avoid using a regular finalizer (a.k.a. __del__) for stopping the + # timer as joining a daemon thread during the interpreter shutdown can + # cause deadlocks. The weakref.finalize is a superior alternative that + # provides a consistent behavior regardless of the GC implementation. + self._finalizer = weakref.finalize( + self, self._stop_thread, self._thread, self._ctx.stop_event + ) + + # We do not attempt to stop our background thread during the interpreter + # shutdown. At that point we do not even know whether it still exists. + self._finalizer.atexit = False + + self._thread.start() + + def cancel(self) -> None: + """Stop the timer at the next opportunity.""" + if self._finalizer: + self._finalizer() + + @staticmethod + def _run(ctx) -> None: + while not ctx.stop_event.wait(ctx.interval): + ctx.function(*ctx.args, **ctx.kwargs) + + @staticmethod + def _stop_thread(thread, stop_event): + stop_event.set() + + thread.join() diff --git a/mindnlp/core/distributed/elastic/timer/__init__.py b/mindnlp/core/distributed/elastic/timer/__init__.py new file mode 100644 index 000000000..3c85cba34 --- /dev/null +++ b/mindnlp/core/distributed/elastic/timer/__init__.py @@ -0,0 +1,54 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Expiration timers are set up on the same process as the agent and +used from your script to deal with stuck workers. When you go into +a code-block that has the potential to get stuck you can acquire +an expiration timer, which instructs the timer server to kill the +process if it does not release the timer by the self-imposed expiration +deadline. + +Usage:: + + from mindnlp import coreelastic.timer as timer + from mindnlp import coreelastic.agent.server as agent + + def main(): + start_method = "spawn" + message_queue = mp.get_context(start_method).Queue() + server = timer.LocalTimerServer(message, max_interval=0.01) + server.start() # non-blocking + + spec = WorkerSpec( + fn=trainer_func, + args=(message_queue,), + ...) + agent = agent.LocalElasticAgent(spec, start_method) + agent.run() + + def trainer_func(message_queue): + timer.configure(timer.LocalTimerClient(message_queue)) + with timer.expires(after=60): # 60 second expiry + # do some work + +In the example above if ``trainer_func`` takes more than 60 seconds to +complete, then the worker process is killed and the agent retries the worker group. +""" + +from .api import ( # noqa: F401 + configure, + expires, + TimerClient, + TimerRequest, + TimerServer, +) +from .file_based_local_timer import ( # noqa: F401 + FileTimerClient, + FileTimerRequest, + FileTimerServer, +) +from .local_timer import LocalTimerClient, LocalTimerServer # noqa: F401 diff --git a/mindnlp/core/distributed/elastic/timer/api.py b/mindnlp/core/distributed/elastic/timer/api.py new file mode 100644 index 000000000..ae587255c --- /dev/null +++ b/mindnlp/core/distributed/elastic/timer/api.py @@ -0,0 +1,283 @@ +# mypy: allow-untyped-defs +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import abc +import logging +import threading +import time +from contextlib import contextmanager +from inspect import getframeinfo, stack +from typing import Any, Dict, List, Optional, Set + + +__all__ = [ + "TimerRequest", + "TimerClient", + "RequestQueue", + "TimerServer", + "configure", + "expires", +] + +logger = logging.getLogger(__name__) + + +class TimerRequest: + """ + Data object representing a countdown timer acquisition and release + that is used between the ``TimerClient`` and ``TimerServer``. + A negative ``expiration_time`` should be interpreted as a "release" + request. + + .. note:: the type of ``worker_id`` is implementation specific. + It is whatever the TimerServer and TimerClient implementations + have on to uniquely identify a worker. + """ + + __slots__ = ["worker_id", "scope_id", "expiration_time"] + + def __init__(self, worker_id: Any, scope_id: str, expiration_time: float): + self.worker_id = worker_id + self.scope_id = scope_id + self.expiration_time = expiration_time + + def __eq__(self, other): + if isinstance(other, TimerRequest): + return ( + self.worker_id == other.worker_id + and self.scope_id == other.scope_id + and self.expiration_time == other.expiration_time + ) + return False + + +class TimerClient(abc.ABC): + """ + Client library to acquire and release countdown timers by communicating + with the TimerServer. + """ + + @abc.abstractmethod + def acquire(self, scope_id: str, expiration_time: float) -> None: + """ + Acquires a timer for the worker that holds this client object + given the scope_id and expiration_time. Typically registers + the timer with the TimerServer. + """ + + @abc.abstractmethod + def release(self, scope_id: str): + """ + Releases the timer for the ``scope_id`` on the worker this + client represents. After this method is + called, the countdown timer on the scope is no longer in effect. + """ + + +class RequestQueue(abc.ABC): + """ + Consumer queue holding timer acquisition/release requests + """ + + @abc.abstractmethod + def size(self) -> int: + """ + Returns the size of the queue at the time this method is called. + Note that by the time ``get`` is called the size of the queue + may have increased. The size of the queue should not decrease + until the ``get`` method is called. That is, the following assertion + should hold: + + size = q.size() + res = q.get(size, timeout=0) + assert size == len(res) + + -- or -- + + size = q.size() + res = q.get(size * 2, timeout=1) + assert size <= len(res) <= size * 2 + """ + + @abc.abstractmethod + def get(self, size: int, timeout: float) -> List[TimerRequest]: + """ + Gets up to ``size`` number of timer requests in a blocking fashion + (no more than ``timeout`` seconds). + """ + + +class TimerServer(abc.ABC): + """ + Entity that monitors active timers and expires them + in a timely fashion. This server is responsible for + reaping workers that have expired timers. + """ + + def __init__( + self, request_queue: RequestQueue, max_interval: float, daemon: bool = True + ): + """ + :param request_queue: Consumer ``RequestQueue`` + :param max_interval: max time (in seconds) to wait + for an item in the request_queue + :param daemon: whether to run the watchdog thread as a daemon + """ + super().__init__() + self._request_queue = request_queue + self._max_interval = max_interval + self._daemon = daemon + self._watchdog_thread: Optional[threading.Thread] = None + self._stop_signaled = False + + @abc.abstractmethod + def register_timers(self, timer_requests: List[TimerRequest]) -> None: + """ + Processes the incoming timer requests and registers them with the server. + The timer request can either be a acquire-timer or release-timer request. + Timer requests with a negative expiration_time should be interpreted + as a release-timer request. + """ + + @abc.abstractmethod + def clear_timers(self, worker_ids: Set[Any]) -> None: + """ + Clears all timers for the given ``worker_ids``. + """ + + @abc.abstractmethod + def get_expired_timers(self, deadline: float) -> Dict[str, List[TimerRequest]]: + """ + Returns all expired timers for each worker_id. An expired timer + is a timer for which the expiration_time is less than or equal to + the provided deadline. + """ + + @abc.abstractmethod + def _reap_worker(self, worker_id: Any) -> bool: + """ + Reaps the given worker. Returns True if the worker has been + successfully reaped, False otherwise. If any uncaught exception + is thrown from this method, the worker is considered reaped + and all associated timers will be removed. + """ + + def _reap_worker_no_throw(self, worker_id: Any) -> bool: + """ + Wraps ``_reap_worker(worker_id)``, if an uncaught exception is + thrown, then it considers the worker as reaped. + """ + try: + return self._reap_worker(worker_id) + except Exception: + logger.exception( + "Uncaught exception thrown from _reap_worker(), " + "check that the implementation correctly catches exceptions", + ) + return True + + def _watchdog_loop(self): + while not self._stop_signaled: + try: + self._run_watchdog() + except Exception: + logger.exception("Error running watchdog") + + def _run_watchdog(self): + batch_size = max(1, self._request_queue.size()) + timer_requests = self._request_queue.get(batch_size, self._max_interval) + self.register_timers(timer_requests) + now = time.time() + reaped_worker_ids = set() + for worker_id, expired_timers in self.get_expired_timers(now).items(): + logger.info( + "Reaping worker_id=[%s]." " Expired timers: %s", + worker_id, + self._get_scopes(expired_timers), + ) + if self._reap_worker_no_throw(worker_id): + logger.info("Successfully reaped worker=[%s]", worker_id) + reaped_worker_ids.add(worker_id) + else: + logger.error( + "Error reaping worker=[%s]. Will retry on next watchdog.", worker_id + ) + self.clear_timers(reaped_worker_ids) + + def _get_scopes(self, timer_requests): + return [r.scope_id for r in timer_requests] + + def start(self) -> None: + logger.info( + "Starting %s..." " max_interval=%s," " daemon=%s", + type(self).__name__, + self._max_interval, + self._daemon, + ) + self._watchdog_thread = threading.Thread( + target=self._watchdog_loop, daemon=self._daemon + ) + logger.info("Starting watchdog thread...") + self._watchdog_thread.start() + + def stop(self) -> None: + logger.info("Stopping %s", type(self).__name__) + self._stop_signaled = True + if self._watchdog_thread: + logger.info("Stopping watchdog thread...") + self._watchdog_thread.join(self._max_interval) + self._watchdog_thread = None + else: + logger.info("No watchdog thread running, doing nothing") + + +_timer_client: Optional[TimerClient] = None + + +def configure(timer_client: TimerClient): + """ + Configures a timer client. Must be called before using ``expires``. + """ + global _timer_client + _timer_client = timer_client + logger.info("Timer client configured to: %s", type(_timer_client).__name__) + + +@contextmanager +def expires( + after: float, scope: Optional[str] = None, client: Optional[TimerClient] = None +): + """ + Acquires a countdown timer that expires in ``after`` seconds from now, + unless the code-block that it wraps is finished within the timeframe. + When the timer expires, this worker is eligible to be reaped. The + exact meaning of "reaped" depends on the client implementation. In + most cases, reaping means to terminate the worker process. + Note that the worker is NOT guaranteed to be reaped at exactly + ``time.now() + after``, but rather the worker is "eligible" for being + reaped and the ``TimerServer`` that the client talks to will ultimately + make the decision when and how to reap the workers with expired timers. + + Usage:: + + core.distributed.elastic.timer.configure(LocalTimerClient()) + with expires(after=10): + core.distributed.all_reduce(...) + """ + if client is None: + if _timer_client is None: + raise RuntimeError("Configure timer client before using countdown timers.") + client = _timer_client + if scope is None: + # grab the caller file + lineno + caller = getframeinfo(stack()[1][0]) + scope = f"{caller.filename}#{caller.lineno}" + expiration = time.time() + after + client.acquire(scope, expiration) + try: + yield + finally: + client.release(scope) diff --git a/mindnlp/core/distributed/elastic/timer/debug_info_logging.py b/mindnlp/core/distributed/elastic/timer/debug_info_logging.py new file mode 100644 index 000000000..7815c8aea --- /dev/null +++ b/mindnlp/core/distributed/elastic/timer/debug_info_logging.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, List + +from core.distributed.elastic.utils.logging import get_logger + + +logger = get_logger(__name__) + +__all__ = ["log_debug_info_for_expired_timers"] + + +def log_debug_info_for_expired_timers( + run_id: str, + expired_timers: Dict[int, List[str]], +): + if expired_timers: + logger.info("Timers expired for run:[%s] [%s].", run_id, expired_timers) diff --git a/mindnlp/core/distributed/elastic/timer/file_based_local_timer.py b/mindnlp/core/distributed/elastic/timer/file_based_local_timer.py new file mode 100644 index 000000000..537c2c3a7 --- /dev/null +++ b/mindnlp/core/distributed/elastic/timer/file_based_local_timer.py @@ -0,0 +1,396 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import io +import json +import os +import select +import signal +import sys +import threading +import time +from typing import Callable, Dict, List, Optional, Set, Tuple + +from core.distributed.elastic.timer.api import TimerClient, TimerRequest +from core.distributed.elastic.timer.debug_info_logging import ( + log_debug_info_for_expired_timers, +) +from core.distributed.elastic.utils.logging import get_logger + + +__all__ = ["FileTimerClient", "FileTimerRequest", "FileTimerServer"] + +logger = get_logger(__name__) + + +class FileTimerRequest(TimerRequest): + """ + Data object representing a countdown timer acquisition and release + that is used between the ``FileTimerClient`` and ``FileTimerServer``. + A negative ``expiration_time`` should be interpreted as a "release" + request. + ``signal`` is the signal to reap the worker process from the server + process. + """ + + __slots__ = ["version", "worker_pid", "scope_id", "expiration_time", "signal"] + + def __init__( + self, worker_pid: int, scope_id: str, expiration_time: float, signal: int = 0 + ) -> None: + self.version = 1 + self.worker_pid = worker_pid + self.scope_id = scope_id + self.expiration_time = expiration_time + self.signal = signal + + def __eq__(self, other) -> bool: + if isinstance(other, FileTimerRequest): + return ( + self.version == other.version + and self.worker_pid == other.worker_pid + and self.scope_id == other.scope_id + and self.expiration_time == other.expiration_time + and self.signal == other.signal + ) + return False + + def to_json(self) -> str: + return json.dumps( + { + "version": self.version, + "pid": self.worker_pid, + "scope_id": self.scope_id, + "expiration_time": self.expiration_time, + "signal": self.signal, + }, + ) + + +class FileTimerClient(TimerClient): + """ + Client side of ``FileTimerServer``. This client is meant to be used + on the same host that the ``FileTimerServer`` is running on and uses + pid to uniquely identify a worker. + This client uses a named_pipe to send timer requests to the + ``FileTimerServer``. This client is a producer while the + ``FileTimerServer`` is a consumer. Multiple clients can work with + the same ``FileTimerServer``. + + Args: + + file_path: str, the path of a FIFO special file. ``FileTimerServer`` + must have created it by calling os.mkfifo(). + + signal: signal, the signal to use to kill the process. Using a + negative or zero signal will not kill the process. + """ + + def __init__( + self, + file_path: str, + signal=(signal.SIGKILL if sys.platform != "win32" else signal.CTRL_C_EVENT), # type: ignore[attr-defined] + ) -> None: + super().__init__() + self._file_path = file_path + self.signal = signal + + def _open_non_blocking(self) -> Optional[io.TextIOWrapper]: + try: + fd = os.open(self._file_path, os.O_WRONLY | os.O_NONBLOCK) + return os.fdopen(fd, "wt") + except Exception: + return None + + def _send_request(self, request: FileTimerRequest) -> None: + # The server may have crashed or may haven't started yet. + # In such case, calling open() in blocking model blocks the client. + # To avoid such issue, open it in non-blocking mode, and an OSError will + # be raised if the server is not there. + file = self._open_non_blocking() + if file is None: + raise BrokenPipeError( + "Could not send the FileTimerRequest because FileTimerServer is not available." + ) + with file: + json_request = request.to_json() + # Write request with no greater than select.PIPE_BUF is guarantee to be atomic. + if len(json_request) > select.PIPE_BUF: + raise RuntimeError( + f"FileTimerRequest larger than {select.PIPE_BUF} bytes " + f"is not supported: {json_request}" + ) + file.write(json_request + "\n") + + def acquire(self, scope_id: str, expiration_time: float) -> None: + self._send_request( + request=FileTimerRequest( + worker_pid=os.getpid(), + scope_id=scope_id, + expiration_time=expiration_time, + signal=self.signal, + ), + ) + + def release(self, scope_id: str) -> None: + self._send_request( + request=FileTimerRequest( + worker_pid=os.getpid(), scope_id=scope_id, expiration_time=-1, signal=0 + ), + ) + + +class FileTimerServer: + """ + Server that works with ``FileTimerClient``. Clients are expected to be + running on the same host as the process that is running this server. + Each host in the job is expected to start its own timer server locally + and each server instance manages timers for local workers (running on + processes on the same host). + + Args: + + file_path: str, the path of a FIFO special file to be created. + + max_interval: float, max interval in seconds for each watchdog loop. + + daemon: bool, running the watchdog thread in daemon mode or not. + A daemon thread will not block a process to stop. + log_event: Callable[[Dict[str, str]], None], an optional callback for + logging the events in JSON format. + """ + + def __init__( + self, + file_path: str, + run_id: str, + max_interval: float = 10, + daemon: bool = True, + log_event: Optional[Callable[[str, Optional[FileTimerRequest]], None]] = None, + ) -> None: + self._file_path = file_path + self._run_id = run_id + self._max_interval = max_interval + self._daemon = daemon + self._timers: Dict[Tuple[int, str], FileTimerRequest] = {} + self._stop_signaled = False + self._watchdog_thread: Optional[threading.Thread] = None + + self._is_client_started = False + if os.path.exists(self._file_path): + os.remove(self._file_path) + os.mkfifo(self._file_path) + # For test only. Count the number of requests received. + self._request_count = 0 + # For test only. Process all requests and stop the server. + self._run_once = False + self._log_event = ( + log_event if log_event is not None else lambda name, request: None + ) + self._last_progress_time = int(time.time()) + + def start(self) -> None: + logger.info( + "Starting %s... max_interval=%s, daemon=%s, file_path=%s", + type(self).__name__, + self._max_interval, + self._daemon, + self._file_path, + ) + self._watchdog_thread = threading.Thread( + target=self._watchdog_loop, daemon=self._daemon + ) + logger.info("Starting watchdog thread...") + self._watchdog_thread.start() + self._log_event("watchdog started", None) + + def stop(self) -> None: + logger.info("Stopping %s", type(self).__name__) + self._stop_signaled = True + if self._watchdog_thread: + logger.info("Stopping watchdog thread...") + self._watchdog_thread.join(self._max_interval) + self._watchdog_thread = None + else: + logger.info("No watchdog thread running, doing nothing") + if os.path.exists(self._file_path): + os.remove(self._file_path) + self._log_event("watchdog stopped", None) + + def run_once(self) -> None: + self._run_once = True + if self._watchdog_thread: + logger.info("Stopping watchdog thread...") + self._watchdog_thread.join() + self._watchdog_thread = None + else: + logger.info("No watchdog thread running, doing nothing") + if os.path.exists(self._file_path): + os.remove(self._file_path) + + @staticmethod + def is_process_running(pid: int): + """ + function to check process is running or not + """ + try: + # Check if the process exists and we can send signals to it + os.kill(pid, 0) + return True + except OSError: + return False + + def _watchdog_loop(self) -> None: + # Open the pipe in blocking mode blocks the server thread. + # This is fine for the following reasons: + # 1. No client case usually does not happen. + # 2. We are running the watchdog loop in a separate daemon + # thread, which will not block the process to stop. + with open(self._file_path) as fd: + self._is_client_started = True + while not self._stop_signaled: + try: + run_once = self._run_once + self._run_watchdog(fd) + if run_once: + break + self._last_progress_time = int(time.time()) + except Exception: + logger.exception("Error running watchdog") + + def _run_watchdog(self, fd: io.TextIOWrapper) -> None: + timer_requests = self._get_requests(fd, self._max_interval) + self.register_timers(timer_requests) + now = time.time() + reaped_worker_pids = set() + + all_expired_timers = self.get_expired_timers(now) + log_debug_info_for_expired_timers( + self._run_id, + { + pid: [expired_timer.to_json() for expired_timer in expired_timers] + for pid, expired_timers in all_expired_timers.items() + }, + ) + + for worker_pid, expired_timers in all_expired_timers.items(): + logger.info( + "Reaping worker_pid=[%s]. Expired timers: %s", + worker_pid, + self._get_scopes(expired_timers), + ) + reaped_worker_pids.add(worker_pid) + # In case we have multiple expired timers, we find the first timer + # with a valid signal (>0) in the expiration time order. + expired_timers.sort(key=lambda timer: timer.expiration_time) + signal = 0 + expired_timer = None + for timer in expired_timers: + self._log_event("timer expired", timer) + if timer.signal > 0: + signal = timer.signal + expired_timer = timer + break + if signal <= 0: + logger.info( + "No signal specified with worker=[%s]. Do not reap it.", worker_pid + ) + continue + if self._reap_worker(worker_pid, signal): + logger.info( + "Successfully reaped worker=[%s] with signal=%s", worker_pid, signal + ) + self._log_event("kill worker process", expired_timer) + else: + logger.error( + "Error reaping worker=[%s]. Will retry on next watchdog.", + worker_pid, + ) + self.clear_timers(reaped_worker_pids) + + def _get_scopes(self, timer_requests: List[FileTimerRequest]) -> List[str]: + return [r.scope_id for r in timer_requests] + + def _get_requests( + self, fd: io.TextIOWrapper, max_interval: float + ) -> List[FileTimerRequest]: + start = time.time() + requests = [] + while not self._stop_signaled or self._run_once: + # For named pipe, readline() is blocking when at least one writer opens. + # It returns only when flush() is called at the writer side. + # Note that flush() is automatically called inside close(). + # After the last writer closes, readline() is not blocking. + # It will return an empty string when it's at end-of-file. + # Since the client side always opens the pipe, writes a message and closes + # the pipe immediately, the readline() call below is not blocking for long. + json_request = fd.readline() + if len(json_request) == 0: + if self._run_once: + break + time.sleep(min(max_interval, 1)) + else: + request = json.loads(json_request) + pid = request["pid"] + scope_id = request["scope_id"] + expiration_time = request["expiration_time"] + signal = request["signal"] + requests.append( + FileTimerRequest( + worker_pid=pid, + scope_id=scope_id, + expiration_time=expiration_time, + signal=signal, + ) + ) + now = time.time() + if now - start > max_interval: + break + return requests + + def register_timers(self, timer_requests: List[FileTimerRequest]) -> None: + for request in timer_requests: + pid = request.worker_pid + scope_id = request.scope_id + expiration_time = request.expiration_time + self._request_count += 1 + + key = (pid, scope_id) + # negative expiration is a proxy for a release call + if expiration_time < 0: + if key in self._timers: + del self._timers[key] + else: + self._timers[key] = request + + def clear_timers(self, worker_pids: Set[int]) -> None: + for pid, scope_id in list(self._timers.keys()): + if pid in worker_pids or not FileTimerServer.is_process_running(pid): + del self._timers[(pid, scope_id)] + + def get_expired_timers(self, deadline: float) -> Dict[int, List[FileTimerRequest]]: + # pid -> [timer_requests...] + expired_timers: Dict[int, List[FileTimerRequest]] = {} + for request in self._timers.values(): + if request.expiration_time <= deadline: + expired_scopes = expired_timers.setdefault(request.worker_pid, []) + expired_scopes.append(request) + return expired_timers + + def _reap_worker(self, worker_pid: int, signal: int) -> bool: + try: + os.kill(worker_pid, signal) + return True + except ProcessLookupError: + logger.info("Process with pid=%s does not exist. Skipping", worker_pid) + return True + except Exception: + logger.exception("Error terminating pid=%s", worker_pid) + return False + + def get_last_progress_time(self) -> int: + return self._last_progress_time if self._is_client_started else int(time.time()) diff --git a/mindnlp/core/distributed/elastic/timer/local_timer.py b/mindnlp/core/distributed/elastic/timer/local_timer.py new file mode 100644 index 000000000..d3562877a --- /dev/null +++ b/mindnlp/core/distributed/elastic/timer/local_timer.py @@ -0,0 +1,128 @@ +# mypy: allow-untyped-defs +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import logging +import multiprocessing as mp +import os +import signal +import time +from queue import Empty +from typing import Any, Dict, List, Set, Tuple + +from .api import RequestQueue, TimerClient, TimerRequest, TimerServer + + +__all__ = ["LocalTimerClient", "MultiprocessingRequestQueue", "LocalTimerServer"] + +logger = logging.getLogger(__name__) + + +class LocalTimerClient(TimerClient): + """ + Client side of ``LocalTimerServer``. This client is meant to be used + on the same host that the ``LocalTimerServer`` is running on and uses + pid to uniquely identify a worker. This is particularly useful in situations + where one spawns a subprocess (trainer) per GPU on a host with multiple + GPU devices. + """ + + def __init__(self, mp_queue): + super().__init__() + self._mp_queue = mp_queue + + def acquire(self, scope_id, expiration_time): + pid = os.getpid() + acquire_request = TimerRequest(pid, scope_id, expiration_time) + self._mp_queue.put(acquire_request) + + def release(self, scope_id): + pid = os.getpid() + release_request = TimerRequest(pid, scope_id, -1) + self._mp_queue.put(release_request) + + +class MultiprocessingRequestQueue(RequestQueue): + """ + A ``RequestQueue`` backed by python ``multiprocessing.Queue`` + """ + + def __init__(self, mp_queue: mp.Queue): + super().__init__() + self._mp_queue = mp_queue + + def size(self) -> int: + return self._mp_queue.qsize() + + def get(self, size, timeout: float) -> List[TimerRequest]: + requests = [] + wait = timeout + for _ in range(0, size): + start = time.time() + + try: + r = self._mp_queue.get(block=True, timeout=wait) + except Empty: + break + + requests.append(r) + wait = wait - (time.time() - start) + if wait <= 0: + break + + return requests + + +class LocalTimerServer(TimerServer): + """ + Server that works with ``LocalTimerClient``. Clients are expected to be + subprocesses to the parent process that is running this server. Each host + in the job is expected to start its own timer server locally and each + server instance manages timers for local workers (running on processes + on the same host). + """ + + def __init__( + self, mp_queue: mp.Queue, max_interval: float = 60, daemon: bool = True + ): + super().__init__(MultiprocessingRequestQueue(mp_queue), max_interval, daemon) + self._timers: Dict[Tuple[Any, str], TimerRequest] = {} + + def register_timers(self, timer_requests: List[TimerRequest]) -> None: + for request in timer_requests: + pid = request.worker_id + scope_id = request.scope_id + expiration_time = request.expiration_time + + # negative expiration is a proxy for a release call + if expiration_time < 0: + self._timers.pop((pid, scope_id), None) + else: + self._timers[(pid, scope_id)] = request + + def clear_timers(self, worker_ids: Set[int]) -> None: + for pid, scope_id in list(self._timers.keys()): + if pid in worker_ids: + self._timers.pop((pid, scope_id)) + + def get_expired_timers(self, deadline: float) -> Dict[Any, List[TimerRequest]]: + # pid -> [timer_requests...] + expired_timers: Dict[Any, List[TimerRequest]] = {} + for request in self._timers.values(): + if request.expiration_time <= deadline: + expired_scopes = expired_timers.setdefault(request.worker_id, []) + expired_scopes.append(request) + return expired_timers + + def _reap_worker(self, worker_id: int) -> bool: + try: + os.kill(worker_id, signal.SIGKILL) + return True + except ProcessLookupError: + logger.info("Process with pid=%s does not exist. Skipping", worker_id) + return True + except Exception: + logger.exception("Error terminating pid=%s", worker_id) + return False diff --git a/mindnlp/core/distributed/elastic/utils/__init__.py b/mindnlp/core/distributed/elastic/utils/__init__.py new file mode 100644 index 000000000..5fbc76bf7 --- /dev/null +++ b/mindnlp/core/distributed/elastic/utils/__init__.py @@ -0,0 +1,9 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .api import get_env_variable_or_raise, get_socket_with_port, macros # noqa: F401 diff --git a/mindnlp/core/distributed/elastic/utils/api.py b/mindnlp/core/distributed/elastic/utils/api.py new file mode 100644 index 000000000..bff91438b --- /dev/null +++ b/mindnlp/core/distributed/elastic/utils/api.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import socket +from string import Template +from typing import Any, List + + +def get_env_variable_or_raise(env_name: str) -> str: + r""" + Tries to retrieve environment variable. Raises ``ValueError`` + if no environment variable found. + + Args: + env_name (str): Name of the env variable + """ + value = os.environ.get(env_name, None) + if value is None: + msg = f"Environment variable {env_name} expected, but not set" + raise ValueError(msg) + return value + + +def get_socket_with_port() -> socket.socket: + addrs = socket.getaddrinfo( + host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM + ) + for addr in addrs: + family, type, proto, _, _ = addr + s = socket.socket(family, type, proto) + try: + s.bind(("localhost", 0)) + s.listen(0) + return s + except OSError: + s.close() + raise RuntimeError("Failed to create a socket") + + +class macros: + """ + Defines simple macros for caffe2.distributed.launch cmd args substitution + """ + + local_rank = "${local_rank}" + + @staticmethod + def substitute(args: List[Any], local_rank: str) -> List[str]: + args_sub = [] + for arg in args: + if isinstance(arg, str): + sub = Template(arg).safe_substitute(local_rank=local_rank) + args_sub.append(sub) + else: + args_sub.append(arg) + return args_sub diff --git a/mindnlp/core/distributed/elastic/utils/distributed.py b/mindnlp/core/distributed/elastic/utils/distributed.py new file mode 100644 index 000000000..d8cc412e2 --- /dev/null +++ b/mindnlp/core/distributed/elastic/utils/distributed.py @@ -0,0 +1,184 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import datetime +import os +import socket +from contextlib import closing +from typing import Optional + +from mindnlp import core.distributed as dist +from core.distributed.elastic.utils.logging import get_logger +from core.distributed.elastic.utils.store import barrier + + +__all__ = ["create_c10d_store", "get_free_port", "get_socket_with_port"] + +logger = get_logger(__name__) + +_ADDRESS_IN_USE = "Address already in use" +_SOCKET_TIMEOUT = "Socket Timeout" + +_TCP_STORE_INIT = "_tcp_store/num_members" + + +def create_c10d_store( + is_server: bool, + server_addr: str, + server_port: int = -1, + world_size: int = 1, + timeout: float = (60 * 10), # 10 min + wait_for_workers: bool = True, + retries=3, + use_libuv: Optional[bool] = None, +): + if use_libuv is not None: + logger.warning( + "argument use_libuv is deprecated and ignored. Set USE_LIBUV environment " + 'variable to "0" to disable libuv, or "1" to enable it. If the env var ' + "is not set, libuv will be used by default." + ) + + # check os.environ for use_libuv + use_libuv = os.environ.get("USE_LIBUV", "1") == "1" # libuv is the default option + + if server_port == -1 and world_size > 1: + raise ValueError( + f"server_port must be specified when world_size > 1, got server_port={server_port}, world_size={world_size}" + ) + + if server_port != -1: + logger.info("sever_port: %s, specified, ignoring retries", server_port) + + # only retry when server_port is NOT static + attempt = retries if server_port == -1 else 1 + while True: + if server_port != -1: + port = server_port + else: + port = get_free_port() + + logger.info( + "Creating c10d store on %s:%s\n" + " world_size : %s\n" + " is_server : %s\n" + " timeout(sec): %s\n" + " use_libuv : %s\n", + server_addr, + port, + world_size, + is_server, + timeout, + use_libuv, + ) + + try: + store = dist.TCPStore( + host_name=server_addr, + port=port, + world_size=world_size, + is_master=is_server, + timeout=datetime.timedelta(seconds=timeout), + wait_for_workers=wait_for_workers, + use_libuv=use_libuv, + ) + # skips full rank check when we don't have to wait for all workers + if wait_for_workers: + _check_full_rank(store, world_size, timeout=timeout) + logger.info("Successfully created c10d store") + return store + except RuntimeError as e: + # this is brittle, but the underlying exception type is not properly pybinded + # so we parse the error msg for now, interestingly this is how torch itself + # detects timeouts and port conflicts in their own unittests + # see - caffe2/torch/testing/_internal/common_utils.py + # TODO properly map the exceptions in pybind (c10d/init.cpp) + if str(e) == _ADDRESS_IN_USE: # this will only happen on the server + if attempt < retries: + logger.warning( + "port: %s already in use, attempt: [%s/%s]", + port, + attempt, + retries, + ) + attempt += 1 + else: + raise RuntimeError( + f"on {server_addr}, port: {port} already in use" + ) from e + else: + raise + + +def _check_full_rank(store, world_size, timeout): + try: + barrier(store, world_size, key_prefix=_TCP_STORE_INIT, barrier_timeout=timeout) + except RuntimeError as e: + if str(e) == _SOCKET_TIMEOUT: + raise TimeoutError( + f"timed out waiting for all {world_size} members to join" + ) from e + else: + raise + + +def get_free_port(): + """ + Returns an unused port on localhost. + + This function finds an unused port on localhost by opening to socket to bind + to a port and then closing it. + + Returns: + int: an unused port on localhost + + Example: + >>> # xdoctest: +SKIP("Nondeterministic") + >>> get_free_port() + 63976 + + ..note: + The port returned by :func:`get_free_port` is not reserved and may be + taken by another process after this function returns. + """ + sock = get_socket_with_port() + with closing(sock): + return sock.getsockname()[1] + + +def get_socket_with_port() -> socket.socket: + """ + Returns a free port on localhost that is "reserved" by binding a temporary + socket on it. Close the socket before passing the port to the entity + that requires it. Usage example + + :: + + sock = _get_socket_with_port() + with closing(sock): + port = sock.getsockname()[1] + sock.close() + # there is still a race-condition that some other process + # may grab this port before func() runs + func(port) + """ + + addrs = socket.getaddrinfo( + host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM + ) + for addr in addrs: + family, type, proto, _, _ = addr + s = socket.socket(family, type, proto) + try: + s.bind(("localhost", 0)) + s.listen(0) + return s + except OSError as e: + s.close() + logger.warning("Socket creation attempt failed.", exc_info=e) + raise RuntimeError("Failed to create a socket") diff --git a/mindnlp/core/distributed/elastic/utils/log_level.py b/mindnlp/core/distributed/elastic/utils/log_level.py new file mode 100644 index 000000000..0785e1520 --- /dev/null +++ b/mindnlp/core/distributed/elastic/utils/log_level.py @@ -0,0 +1,14 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +def get_log_level() -> str: + """ + Return default log level for pycore. + """ + return "WARNING" diff --git a/mindnlp/core/distributed/elastic/utils/logging.py b/mindnlp/core/distributed/elastic/utils/logging.py new file mode 100644 index 000000000..40eb1bd9b --- /dev/null +++ b/mindnlp/core/distributed/elastic/utils/logging.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import inspect +import logging +import os +import warnings +from typing import Optional + +from core.distributed.elastic.utils.log_level import get_log_level + + +def get_logger(name: Optional[str] = None): + """ + Util function to set up a simple logger that writes + into stderr. The loglevel is fetched from the LOGLEVEL + env. variable or WARNING as default. The function will use the + module name of the caller if no name is provided. + + Args: + name: Name of the logger. If no name provided, the name will + be derived from the call stack. + """ + + # Derive the name of the caller, if none provided + # Use depth=2 since this function takes up one level in the call stack + return _setup_logger(name or _derive_module_name(depth=2)) + + +def _setup_logger(name: Optional[str] = None): + logger = logging.getLogger(name) + logger.setLevel(os.environ.get("LOGLEVEL", get_log_level())) + return logger + + +def _derive_module_name(depth: int = 1) -> Optional[str]: + """ + Derives the name of the caller module from the stack frames. + + Args: + depth: The position of the frame in the stack. + """ + try: + stack = inspect.stack() + assert depth < len(stack) + # FrameInfo is just a named tuple: (frame, filename, lineno, function, code_context, index) + frame_info = stack[depth] + + module = inspect.getmodule(frame_info[0]) + if module: + module_name = module.__name__ + else: + # inspect.getmodule(frame_info[0]) does NOT work (returns None) in + # binaries built with @mode/opt + # return the filename (minus the .py extension) as modulename + filename = frame_info[1] + module_name = os.path.splitext(os.path.basename(filename))[0] + return module_name + except Exception as e: + warnings.warn( + f"Error deriving logger module name, using . Exception: {e}", + RuntimeWarning, + ) + return None diff --git a/mindnlp/core/distributed/elastic/utils/store.py b/mindnlp/core/distributed/elastic/utils/store.py new file mode 100644 index 000000000..853a534e4 --- /dev/null +++ b/mindnlp/core/distributed/elastic/utils/store.py @@ -0,0 +1,225 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from contextlib import contextmanager +from datetime import timedelta +from typing import Callable, Iterable, List, Optional + +from mindnlp import core + + +DistStoreError = core._C._DistStoreError + +_NUM_MEMBERS = "/num_members" +_LAST_MEMBER_CHECKIN = "/last_member" +_TRACE = "/TRACE" +_TRACING_GATE = "/TRACING_GATE" +_MAX_TRACE_MISSING_RANKS = 16 + + +__all__ = ["store_timeout", "get_all", "synchronize", "barrier"] + + +@contextmanager +def store_timeout(store, timeout: float): + """ + This sets the timeout and then restores the old timeout when the context + manager exits. + + Args: + store: the store to set the timeout on + timeout: the timeout to set + """ + + old_timeout = store.timeout + store.set_timeout(timedelta(seconds=timeout)) + yield + store.set_timeout(old_timeout) + + +def get_all(store, rank: int, prefix: str, world_size: int): + r""" + Given a store and a prefix, the method goes through the array of keys + of the following format: ``{prefix}{idx}``, where idx is in a range + from 0 to size, and tries to retrieve the data. + + The Rank0 process waits at the end to make sure all other processes + finished the procedure before exiting. + + Usage + + :: + + values = get_all(store, 'torchelastic/data', 3) + value1 = values[0] # retrieves the data for key torchelastic/data0 + value2 = values[1] # retrieves the data for key torchelastic/data1 + value3 = values[2] # retrieves the data for key torchelastic/data2 + + """ + data_arr = store.multi_get([f"{prefix}{idx}" for idx in range(world_size)]) + + barrier_key = _barrier_nonblocking( + store=store, + world_size=world_size, + key_prefix=f"{prefix}/finished", + ) + if rank == 0: + # Rank0 runs the TCPStore daemon, as a result it needs to exit last. + # Otherwise, the barrier may timeout if rank0 process finished the work + # before other processes finished `get_all` method + store.wait([barrier_key]) + + return data_arr + + +def synchronize( + store, + data: bytes, + rank: int, + world_size: int, + key_prefix: str, + timeout: float = 300, +) -> List[bytes]: + """ + Synchronizes ``world_size`` agents between each other using the underlying c10d store. + The ``data`` will be available on each of the agents. + + Note: The data on the path is not deleted, as a result there can be stale data if + you use the same key_prefix twice. + + Time complexity: O(N) per worker, O(N^2) globally. + """ + with store_timeout(store, timeout): + store.set(f"{key_prefix}{rank}", data) + agent_data = get_all(store, rank, key_prefix, world_size) + return agent_data + + +def _try_detecting_missing_ranks( + store, + world_size: int, + key_prefix: str, + rank: int, + rank_decoder: Callable[[int], str], + trace_timeout: float, +) -> Optional[Iterable[str]]: + store.set(f"{key_prefix}{rank}{_TRACE}", "") + + def _find_missing_ranks(): + missing_rank_info = set() + ranks_missing = 0 + for i in range(1, world_size): + # reduce noise, assuming in general 8 ranks per node + # It is valuable to know that 1 or >1 nodes have timed-out. + if ranks_missing >= _MAX_TRACE_MISSING_RANKS: + break + try: + if ranks_missing == 0: + store.wait( + [f"{key_prefix}{i}{_TRACE}"], timedelta(seconds=trace_timeout) + ) + else: + # use a shortest timeout, some ranks have failed to check-in + store.wait([f"{key_prefix}{i}{_TRACE}"], timedelta(milliseconds=1)) + except DistStoreError: + ranks_missing += 1 + missing_rank_info.add(rank_decoder(i)) + return missing_rank_info + + def _checkin(): + try: + store.wait([f"{key_prefix}{_TRACING_GATE}"]) + return [f"[]"] + except DistStoreError: + # in case rank0 is the source of the timeout, original exception will be raised + return None + + if rank == 0: + missing_rank_info = _find_missing_ranks() + store.set(f"{key_prefix}{_TRACING_GATE}", "") + return missing_rank_info + else: + return _checkin() + + +def _barrier_nonblocking(store, world_size: int, key_prefix: str) -> str: + """ + Does all the non-blocking operations for a barrier and returns the final key + that can be waited on. + """ + num_members_key = key_prefix + _NUM_MEMBERS + last_member_key = key_prefix + _LAST_MEMBER_CHECKIN + + idx = store.add(num_members_key, 1) + if idx == world_size: + store.set(last_member_key, "") + + return last_member_key + + +def barrier( + store, + world_size: int, + key_prefix: str, + barrier_timeout: float = 300, + rank: Optional[int] = None, + rank_tracing_decoder: Optional[Callable[[int], str]] = None, + trace_timeout: float = 10, +) -> None: + """ + A global lock between agents. This will pause all workers until at least + ``world_size`` workers respond. + + This uses a fast incrementing index to assign waiting ranks and a success + flag set by the last worker. + + Time complexity: O(1) per worker, O(N) globally. + + Optionally, passing rank will enable tracing of missing ranks on timeouts. + `rank_tracing_decoder` lambda arg can be used to convert rank data + into a more meaninful information at an app level (e.g. hostname). + + Note: Since the data is not removed from the store, the barrier can be used + once per unique ``key_prefix``. + """ + + if rank is None: + assert rank_tracing_decoder is None, "Tracing requires rank information" + + with store_timeout(store, barrier_timeout): + last_member_key = _barrier_nonblocking( + store=store, world_size=world_size, key_prefix=key_prefix + ) + try: + store.wait([last_member_key]) + except DistStoreError as e: + if rank is None: + raise e + else: + missing_ranks = _try_detecting_missing_ranks( + store, + world_size, + key_prefix, + rank, + rank_tracing_decoder or (lambda x: str(x)), + trace_timeout, + ) + if missing_ranks is not None: + raise DistStoreError( + "Timed out waiting on barrier on " + "rank {}, for key prefix: {} (world_size={}, missing_ranks={}, timeout={})".format( + rank, + key_prefix, + world_size, + f"[{', '.join(missing_ranks)}]", + barrier_timeout, + ) + ) from None + else: + raise e diff --git a/mindnlp/core/distributed/examples/memory_tracker_example.py b/mindnlp/core/distributed/examples/memory_tracker_example.py new file mode 100644 index 000000000..69114b8fd --- /dev/null +++ b/mindnlp/core/distributed/examples/memory_tracker_example.py @@ -0,0 +1,33 @@ +# mypy: allow-untyped-defs +from mindnlp import corevision + +from mindnlp import core +from core.distributed._tools import MemoryTracker + + +def run_one_model(net: core.nn.Module, input: core.Tensor): + net.cuda() + input = input.cuda() + + # Create the memory Tracker + mem_tracker = MemoryTracker() + # start_monitor before the training iteration starts + mem_tracker.start_monitor(net) + + # run one training iteration + net.zero_grad(True) + loss = net(input) + if isinstance(loss, dict): + loss = loss["out"] + loss.sum().backward() + net.zero_grad(set_to_none=True) + + # stop monitoring after the training iteration ends + mem_tracker.stop() + # print the memory stats summary + mem_tracker.summary() + # plot the memory traces at operator level + mem_tracker.show_traces() + + +run_one_model(torchvision.models.resnet34(), core.rand(32, 3, 224, 224, device="cuda")) diff --git a/mindnlp/core/distributed/fsdp/__init__.py b/mindnlp/core/distributed/fsdp/__init__.py new file mode 100644 index 000000000..fa3888cbd --- /dev/null +++ b/mindnlp/core/distributed/fsdp/__init__.py @@ -0,0 +1 @@ +FullyShardedDataParallel = None diff --git a/mindnlp/core/distributed/launch.py b/mindnlp/core/distributed/launch.py new file mode 100644 index 000000000..516a0d74c --- /dev/null +++ b/mindnlp/core/distributed/launch.py @@ -0,0 +1,208 @@ +# mypy: allow-untyped-defs +r""" +Module ``core.distributed.launch``. + +``core.distributed.launch`` is a module that spawns up multiple distributed +training processes on each of the training nodes. + +.. warning:: + + This module is going to be deprecated in favor of :ref:`torchrun `. + +The utility can be used for single-node distributed training, in which one or +more processes per node will be spawned. The utility can be used for either +CPU training or GPU training. If the utility is used for GPU training, +each distributed process will be operating on a single GPU. This can achieve +well-improved single-node training performance. It can also be used in +multi-node distributed training, by spawning up multiple processes on each node +for well-improved multi-node distributed training performance as well. +This will especially be beneficial for systems with multiple Infiniband +interfaces that have direct-GPU support, since all of them can be utilized for +aggregated communication bandwidth. + +In both cases of single-node distributed training or multi-node distributed +training, this utility will launch the given number of processes per node +(``--nproc-per-node``). If used for GPU training, this number needs to be less +or equal to the number of GPUs on the current system (``nproc_per_node``), +and each process will be operating on a single GPU from *GPU 0 to +GPU (nproc_per_node - 1)*. + +**How to use this module:** + +1. Single-Node multi-process distributed training + +:: + + python -m core.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE + YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other + arguments of your training script) + +2. Multi-Node multi-process distributed training: (e.g. two nodes) + + +Node 1: *(IP: 192.168.1.1, and has a free port: 1234)* + +:: + + python -m core.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE + --nnodes=2 --node-rank=0 --master-addr="192.168.1.1" + --master-port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 + and all other arguments of your training script) + +Node 2: + +:: + + python -m core.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE + --nnodes=2 --node-rank=1 --master-addr="192.168.1.1" + --master-port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 + and all other arguments of your training script) + +3. To look up what optional arguments this module offers: + +:: + + python -m core.distributed.launch --help + + +**Important Notices:** + +1. This utility and multi-process distributed (single-node or +multi-node) GPU training currently only achieves the best performance using +the NCCL distributed backend. Thus NCCL backend is the recommended backend to +use for GPU training. + +2. In your training program, you must parse the command-line argument: +``--local-rank=LOCAL_PROCESS_RANK``, which will be provided by this module. +If your training program uses GPUs, you should ensure that your code only +runs on the GPU device of LOCAL_PROCESS_RANK. This can be done by: + +Parsing the local_rank argument + +:: + + >>> # xdoctest: +SKIP + >>> import argparse + >>> parser = argparse.ArgumentParser() + >>> parser.add_argument("--local-rank", "--local_rank", type=int) + >>> args = parser.parse_args() + +Set your device to local rank using either + +:: + + >>> core.cuda.set_device(args.local_rank) # before your code runs + +or + +:: + + >>> with core.cuda.device(args.local_rank): + >>> # your code to run + >>> ... + +.. versionchanged:: 2.0.0 + + The launcher will passes the ``--local-rank=`` argument to your script. + From PyTorch 2.0.0 onwards, the dashed ``--local-rank`` is preferred over the + previously used underscored ``--local_rank``. + + For backward compatibility, it may be necessary for users to handle both + cases in their argument parsing code. This means including both ``"--local-rank"`` + and ``"--local_rank"`` in the argument parser. If only ``"--local_rank"`` is + provided, the launcher will trigger an error: "error: unrecognized arguments: + --local-rank=". For training code that only supports PyTorch 2.0.0+, + including ``"--local-rank"`` should be sufficient. + +3. In your training program, you are supposed to call the following function +at the beginning to start the distributed backend. It is strongly recommended +that ``init_method=env://``. Other init methods (e.g. ``tcp://``) may work, +but ``env://`` is the one that is officially supported by this module. + +:: + + >>> core.distributed.init_process_group(backend='YOUR BACKEND', + >>> init_method='env://') + +4. In your training program, you can either use regular distributed functions +or use :func:`core.nn.parallel.DistributedDataParallel` module. If your +training program uses GPUs for training and you would like to use +:func:`core.nn.parallel.DistributedDataParallel` module, +here is how to configure it. + +:: + + >>> model = core.nn.parallel.DistributedDataParallel(model, + >>> device_ids=[args.local_rank], + >>> output_device=args.local_rank) + +Please ensure that ``device_ids`` argument is set to be the only GPU device id +that your code will be operating on. This is generally the local rank of the +process. In other words, the ``device_ids`` needs to be ``[args.local_rank]``, +and ``output_device`` needs to be ``args.local_rank`` in order to use this +utility + +5. Another way to pass ``local_rank`` to the subprocesses via environment variable +``LOCAL_RANK``. This behavior is enabled when you launch the script with +``--use-env=True``. You must adjust the subprocess example above to replace +``args.local_rank`` with ``os.environ['LOCAL_RANK']``; the launcher +will not pass ``--local-rank`` when you specify this flag. + +.. warning:: + + ``local_rank`` is NOT globally unique: it is only unique per process + on a machine. Thus, don't use it to decide if you should, e.g., + write to a networked filesystem. See + https://github.com/pytorch/pytorch/issues/12042 for an example of + how things can go wrong if you don't do this correctly. + + + +""" + +from typing_extensions import deprecated as _deprecated + +from core.distributed.run import get_args_parser, run + + +def parse_args(args): + parser = get_args_parser() + parser.add_argument( + "--use-env", + "--use_env", + default=False, + action="store_true", + help="Use environment variable to pass " + "'local rank'. For legacy reasons, the default value is False. " + "If set to True, the script will not pass " + "--local-rank as argument, and will instead set LOCAL_RANK.", + ) + return parser.parse_args(args) + + +def launch(args): + if args.no_python and not args.use_env: + raise ValueError( + "When using the '--no-python' flag," + " you must also set the '--use-env' flag." + ) + run(args) + + +@_deprecated( + "The module core.distributed.launch is deprecated\n" + "and will be removed in future. Use torchrun.\n" + "Note that --use-env is set by default in torchrun.\n" + "If your script expects `--local-rank` argument to be set, please\n" + "change it to read from `os.environ['LOCAL_RANK']` instead. See \n" + "https://pycore.org/docs/stable/distributed.html#launch-utility for \n" + "further instructions\n", + category=FutureWarning, +) +def main(args=None): + args = parse_args(args) + launch(args) + + +if __name__ == "__main__": + main() diff --git a/mindnlp/core/distributed/launcher/__init__.py b/mindnlp/core/distributed/launcher/__init__.py new file mode 100644 index 000000000..caeffef5d --- /dev/null +++ b/mindnlp/core/distributed/launcher/__init__.py @@ -0,0 +1,14 @@ +#!/usr/bin/env/python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from core.distributed.launcher.api import ( # noqa: F401 + elastic_launch, + launch_agent, + LaunchConfig, +) diff --git a/mindnlp/core/distributed/launcher/api.py b/mindnlp/core/distributed/launcher/api.py new file mode 100644 index 000000000..8e8971090 --- /dev/null +++ b/mindnlp/core/distributed/launcher/api.py @@ -0,0 +1,289 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import sys +import uuid +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +from mindnlp import core.distributed.elastic.rendezvous.registry as rdzv_registry +from core.distributed.elastic import events, metrics +from core.distributed.elastic.agent.server.api import WorkerSpec +from core.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent +from core.distributed.elastic.multiprocessing import ( + DefaultLogsSpecs, + LogsSpecs, + SignalException, +) +from core.distributed.elastic.multiprocessing.errors import ChildFailedError +from core.distributed.elastic.rendezvous import RendezvousParameters +from core.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint +from core.distributed.elastic.utils.logging import get_logger + + +__all__ = ["LaunchConfig", "elastic_launch", "launch_agent"] + +logger = get_logger(__name__) + + +@dataclass +class LaunchConfig: + """ + Creates a rendezvous config. + + Args: + min_nodes: Minimum amount of nodes that the user function will + be launched on. Elastic agent ensures that the user + function start only when the min_nodes amount enters + the rendezvous. + max_nodes: Maximum amount of nodes that the user function + will be launched on. + nproc_per_node: On each node the elastic agent will launch + this amount of workers that will execute user + defined function. + rdzv_backend: rdzv_backend to use in the rendezvous (zeus-adapter, etcd). + rdzv_endpoint: The endpoint of the rdzv sync. storage. + rdzv_configs: Key, value pair that specifies rendezvous specific configuration. + rdzv_timeout: Legacy argument that specifies timeout for the rendezvous. It is going + to be removed in future versions, see the note below. The default timeout is 900 seconds. + run_id: The unique run id of the job (if not passed a unique one will be + deduced from run environment - flow workflow id in flow - or auto generated). + role: User defined role of the worker (defaults to "trainer"). + max_restarts: The maximum amount of restarts that elastic agent will conduct + on workers before failure. + monitor_interval: The interval in seconds that is used by the elastic_agent + as a period of monitoring workers. + start_method: The method is used by the elastic agent to start the + workers (spawn, fork, forkserver). + metrics_cfg: configuration to initialize metrics. + local_addr: address of the local node if any. If not set, a lookup on the local + machine's FQDN will be performed. + local_ranks_filter: ranks for which to show logs in console. If not set, show from all. + ..note: + `rdzv_timeout` is a legacy argument that will be removed in future. + Set the timeout via `rdzv_configs['timeout']` + + """ + + min_nodes: int + max_nodes: int + nproc_per_node: int + logs_specs: Optional[LogsSpecs] = None + run_id: str = "" + role: str = "default_role" + rdzv_endpoint: str = "" + rdzv_backend: str = "etcd" + rdzv_configs: Dict[str, Any] = field(default_factory=dict) + rdzv_timeout: int = -1 + max_restarts: int = 3 + monitor_interval: float = 0.1 + start_method: str = "spawn" + log_line_prefix_template: Optional[str] = None + metrics_cfg: Dict[str, str] = field(default_factory=dict) + local_addr: Optional[str] = None + + def __post_init__(self): + default_timeout = 900 + if self.rdzv_timeout != -1: + self.rdzv_configs["timeout"] = self.rdzv_timeout + elif "timeout" not in self.rdzv_configs: + self.rdzv_configs["timeout"] = default_timeout + + # Post-processing to enable refactoring to introduce logs_specs due to non-torchrun API usage + if self.logs_specs is None: + self.logs_specs = DefaultLogsSpecs() + + +class elastic_launch: + """ + Launches an torchelastic agent on the container that invoked the entrypoint. + + 1. Pass the ``entrypoint`` arguments as non ``kwargs`` (e.g. no named parameters)/ + ``entrypoint`` can be a function or a command. + 2. The return value is a map of each worker's output mapped + by their respective global rank. + + Usage + + :: + + def worker_fn(foo): + # ... + + def main(): + # entrypoint is a function. + outputs = elastic_launch(LaunchConfig, worker_fn)(foo) + # return rank 0's output + return outputs[0] + + # entrypoint is a command and ``script.py`` is the python module. + outputs = elastic_launch(LaunchConfig, "script.py")(args) + outputs = elastic_launch(LaunchConfig, "python")("script.py") + """ + + def __init__( + self, + config: LaunchConfig, + entrypoint: Union[Callable, str, None], + ): + self._config = config + self._entrypoint = entrypoint + + def __call__(self, *args): + return launch_agent(self._config, self._entrypoint, list(args)) + + +def _get_entrypoint_name( + entrypoint: Union[Callable, str, None], args: List[Any] +) -> str: + """Retrieve entrypoint name with the rule: + 1. If entrypoint is a function, use ``entrypoint.__qualname__``. + 2. If entrypoint is a string, check its value: + 2.1 if entrypoint equals to ``sys.executable`` (like "python"), use the first element from ``args`` + which does not start with hifen letter (for example, "-u" will be skipped). + 2.2 otherwise, use ``entrypoint`` value. + 3. Otherwise, return empty string. + """ + if isinstance(entrypoint, Callable): # type: ignore[arg-type] + return entrypoint.__name__ # type: ignore[union-attr] + elif isinstance(entrypoint, str): + if entrypoint == sys.executable: + return next((arg for arg in args if arg[0] != "-"), "") + else: + return entrypoint + else: + return "" + + +def _get_addr_and_port( + rdzv_parameters: RendezvousParameters, +) -> Tuple[Optional[str], Optional[int]]: + if rdzv_parameters.backend != "static": + return (None, None) + endpoint = rdzv_parameters.endpoint + endpoint = endpoint.strip() + if not endpoint: + raise ValueError( + "Endpoint is missing in endpoint. Try to add --master-addr and --master-port" + ) + master_addr, master_port = parse_rendezvous_endpoint(endpoint, default_port=-1) + if master_port == -1: + raise ValueError( + f"port is missing in endpoint: {endpoint}. Try to specify --master-port" + ) + return (master_addr, master_port) + + +def launch_agent( + config: LaunchConfig, + entrypoint: Union[Callable, str, None], + args: List[Any], +) -> Dict[int, Any]: + if not config.run_id: + run_id = str(uuid.uuid4().int) + logger.warning("config has no run_id, generated a random run_id: %s", run_id) + config.run_id = run_id + + entrypoint_name = _get_entrypoint_name(entrypoint, args) + + logger.info( + "Starting elastic_operator with launch configs:\n" + " entrypoint : %(entrypoint)s\n" + " min_nodes : %(min_nodes)s\n" + " max_nodes : %(max_nodes)s\n" + " nproc_per_node : %(nproc_per_node)s\n" + " run_id : %(run_id)s\n" + " rdzv_backend : %(rdzv_backend)s\n" + " rdzv_endpoint : %(rdzv_endpoint)s\n" + " rdzv_configs : %(rdzv_configs)s\n" + " max_restarts : %(max_restarts)s\n" + " monitor_interval : %(monitor_interval)s\n" + " log_dir : %(log_dir)s\n" + " metrics_cfg : %(metrics_cfg)s\n", + { + "entrypoint": entrypoint_name, + "min_nodes": config.min_nodes, + "max_nodes": config.max_nodes, + "nproc_per_node": config.nproc_per_node, + "run_id": config.run_id, + "rdzv_backend": config.rdzv_backend, + "rdzv_endpoint": config.rdzv_endpoint, + "rdzv_configs": config.rdzv_configs, + "max_restarts": config.max_restarts, + "monitor_interval": config.monitor_interval, + "log_dir": config.logs_specs.root_log_dir, # type: ignore[union-attr] + "metrics_cfg": config.metrics_cfg, + }, + ) + + rdzv_parameters = RendezvousParameters( + backend=config.rdzv_backend, + endpoint=config.rdzv_endpoint, + run_id=config.run_id, + min_nodes=config.min_nodes, + max_nodes=config.max_nodes, + local_addr=config.local_addr, + **config.rdzv_configs, + ) + + master_addr, master_port = _get_addr_and_port(rdzv_parameters) + + spec = WorkerSpec( + role=config.role, + local_world_size=config.nproc_per_node, + entrypoint=entrypoint, + args=tuple(args), + rdzv_handler=rdzv_registry.get_rendezvous_handler(rdzv_parameters), + max_restarts=config.max_restarts, + monitor_interval=config.monitor_interval, + master_addr=master_addr, + master_port=master_port, + local_addr=config.local_addr, + ) + + agent = LocalElasticAgent( + spec=spec, + logs_specs=config.logs_specs, # type: ignore[arg-type] + start_method=config.start_method, + log_line_prefix_template=config.log_line_prefix_template, + ) + + shutdown_rdzv = True + try: + metrics.initialize_metrics(metrics.MetricsConfig(config.metrics_cfg)) + + result = agent.run() + # records that agent.run() has succeeded NOT that workers have succeeded + events.record(agent.get_event_succeeded()) + + if result.is_failed(): + # ChildFailedError is treated specially by @record + # if the error files for the failed children exist + # @record will copy the first error (root cause) + # to the error file of the launcher process. + raise ChildFailedError( + name=entrypoint_name, + failures=result.failures, + ) + + return result.return_values + except ChildFailedError: + raise + except SignalException: + # when the agent dies with a signal do NOT shutdown the rdzv_handler + # since this closes the rendezvous on this rdzv_id permanently and + # prevents any additional scaling events + shutdown_rdzv = False + events.record(agent.get_event_failed()) + raise + except Exception: + events.record(agent.get_event_failed()) + raise + finally: + if shutdown_rdzv: + spec.rdzv_handler.shutdown() diff --git a/mindnlp/core/distributed/logging_handlers.py b/mindnlp/core/distributed/logging_handlers.py new file mode 100644 index 000000000..b1b02a635 --- /dev/null +++ b/mindnlp/core/distributed/logging_handlers.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Dict, List + + +__all__: List[str] = [] + +_log_handlers: Dict[str, logging.Handler] = { + "default": logging.NullHandler(), +} diff --git a/mindnlp/core/distributed/nn/__init__.py b/mindnlp/core/distributed/nn/__init__.py new file mode 100644 index 000000000..7c8b68e3e --- /dev/null +++ b/mindnlp/core/distributed/nn/__init__.py @@ -0,0 +1,7 @@ +from mindnlp import core + +from .functional import * # noqa: F403 + + +if core.distributed.rpc.is_available(): + from .api.remote_module import RemoteModule diff --git a/mindnlp/core/distributed/nn/api/__init__.py b/mindnlp/core/distributed/nn/api/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/mindnlp/core/distributed/nn/api/remote_module.py b/mindnlp/core/distributed/nn/api/remote_module.py new file mode 100644 index 000000000..a419f801f --- /dev/null +++ b/mindnlp/core/distributed/nn/api/remote_module.py @@ -0,0 +1,762 @@ +#!/usr/bin/python3 +# mypy: allow-untyped-defs +import collections +import io +import sys +import types +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + Mapping, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, +) + +from mindnlp import core +from mindnlp import core.distributed.rpc as rpc +from mindnlp.core import device, dtype, nn, Tensor +from core.distributed import _remote_device +from core.distributed.nn.jit import instantiator +from core.distributed.rpc.internal import _internal_rpc_pickler +from core.nn import Module +from core.nn.parameter import Parameter +from core.utils.hooks import RemovableHandle + + +__all__ = ["RemoteModule"] + +_grad_t = Union[Tuple[Tensor, ...], Tensor] +# See https://mypy.readthedocs.io/en/latest/generics.html#generic-methods-and-generic-self for the use +# of `T` to annotate `self`. Many methods of `Module` return `self` and we want those return values to be +# the type of the subclass, not the looser type of `Module`. +T = TypeVar("T", bound="Module") + +_NON_SCRIPTABLE_REMOTE_MODULE_MODULE = ( + instantiator.instantiate_non_scriptable_remote_module_template() +) + +_REMOTE_MODULE_PICKLED_ATTRIBUTES = ( + "on", + "device", + "is_device_map_set", + "is_scriptable", + "generated_methods", + "module_rref", +) + +_SerializedRemoteModule = collections.namedtuple("_SerializedRemoteModule", _REMOTE_MODULE_PICKLED_ATTRIBUTES) # type: ignore[misc] + +# These attributes are mostly from RemoteModule's parent class and are intentionally not pickled. +# A new attribute of RemoteModule should be either in _REMOTE_MODULE_PICKLED_ATTRIBUTES +# or _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING. +# Otherwise, it will not be pickled. +_REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING = ( + "training", + "_parameters", + "_buffers", + "_non_persistent_buffers_set", + "_backward_hooks", + "_backward_pre_hooks", + "_is_full_backward_hook", + "_forward_hooks", + "_forward_hooks_with_kwargs", + "_forward_hooks_always_called", + "_forward_pre_hooks", + "_forward_pre_hooks_with_kwargs", + "_state_dict_hooks", + "_state_dict_pre_hooks", + "_load_state_dict_pre_hooks", + "_load_state_dict_post_hooks", + "_state_dict_pre_hooks", + "_modules", + # The two attributes below are generated methods, not available at pickling time. + "forward_async", + "forward", +) + + +# RPC handler. +def _instantiate_template(module_interface_cls, enable_moving_cpu_tensors_to_cuda): + instantiator.instantiate_scriptable_remote_module_template( + module_interface_cls, enable_moving_cpu_tensors_to_cuda + ) + + +def _create_module(module_cls, args, kwargs, device): + module = module_cls(*args, **kwargs) + if not isinstance(module, nn.Module): + raise ValueError( + "Expect `module_cls(*args, **kwargs)` returns an instance of , " + f"but it returns an instance of {type(module)}." + ) + module.to(device) + return module + + +def _create_module_with_interface( + module_cls, args, kwargs, device, module_interface_cls +): + module = _create_module(module_cls, args, kwargs, device) + if module_interface_cls is not None: + module = core.jit.script(module) + return rpc.RRef(module, module_interface_cls) + + +def _param_rrefs(module_rref, recurse) -> List[rpc.RRef[Parameter]]: + ret: List[rpc.RRef[Parameter]] = [ + rpc.RRef(param) for param in module_rref.local_value().parameters(recurse) + ] + return ret + + +def _raise_not_supported(name: str) -> None: + raise ValueError(f"Method ``{name}`` not supported for RemoteModule") + + +class _RemoteModule(nn.Module): + def __new__(cls, *args, **kwargs): + # Use __new__ for logging purposes. + core._C._log_api_usage_once("core.distributed.nn.api.remote_module") + return super().__new__(cls) + + def __init__( + self, + remote_device: str, + module_cls: Type[nn.Module], + args: Optional[Tuple] = None, + kwargs: Optional[Dict[str, Any]] = None, + _module_interface_cls: Any = None, + ): + """ + RemoteModule instance can only be created after RPC initialization. + + It creates a user-specified module on a specified remote node. + It behaves like a regular ``nn.Module`` except that the ``forward`` method is + executed on the remote node. + It takes care of autograd recording to ensure the backward pass propagates + gradients back to the corresponding remote module. + It can be shared across processors using `RPC framework `__, + without incurring any overheads of copying the actual module, + which is equivalent to an :class:`~core.distributed.rpc.RRef` + pointing to the remote module. + + The arguments of ``forward_async`` and ``forward`` are the same as + the ``forward`` method of the module returned by the ``module_cls``. + + Apart from ``forward_async`` and ``forward``, no other methods are supported from nn.Module for now. + + Particularly, to create a hybrid model, typically the local modules should be + created outside of remote modules, rather than as submodules of any remote module (by calling ``add_module``). + Hybrid Example: + >>> class HybridModel(nn.Module): + >>> def __init__(self) -> None: + >>> nn.Module.__init__(self) + >>> self.remote_embedding = RemoteModule(...) + >>> self.local_linear = nn.Linear(...) + + For example, if ``module_cls`` returns an instance of ``nn.Linear``, + that has ``forward`` method signature, ``def forward(input: Tensor) -> Tensor:``, + the generated ``RemoteModule`` will have 2 methods in signature of + ``def forward(input: Tensor) -> Tensor:`` and + ``def forward_async(input: Tensor) -> Future[Tensor]:``. + + .. note:: + If the remote module is placed on a cuda device, + any input CPU tensors will be automatically moved to the same cuda device, + and GPU tensors are returned over the wire according to the device map of the remote worker on TensorPipe RPC backend. + + Args: + remote_device (str): Device on the destination worker where we'd like to place this module. + The device can be a local device or a remote device specified by one of the following remote + formats: + + 1. "rank:/" (ex: "rank:0/cuda:0"). + 2. "/" (ex: "trainer0/cuda:0"). + + In addition, the device field can be optional and the default value is "cpu". + module_cls (nn.Module): For example, + >>> class MyModule(nn.Module): + >>> def forward(input): + >>> return input + 1 + >>> + >>> module_cls = MyModule + args (Sequence, optional): args to be passed to ``module_cls``. + kwargs (Dict, optional): kwargs to be passed to ``module_cls``. + _module_interface_cls (type, optional): The TorchScript interface type for the module + to be created. The type object should be decorated by @core.jit.interface. + If not provided, the generated RemoteModule is not torchscript-able. + Warning, this is an experimental API and susceptible to frequent changes. + + Returns: + A remote module instance which wraps the :class:`~nn.Module` created by the + user-provided ``module_cls``, it has a blocking ``forward`` method and an + asynchronous ``forward_async`` method that returns a future of the ``forward`` call + on the user-provided module on the remote side. + + Example:: + Run the following code in two different processes: + + >>> # xdoctest: +SKIP("distributed") + >>> # On worker 0: + >>> from mindnlp import core + >>> from mindnlp import core.distributed.rpc as rpc + >>> from mindnlp.core import nn, Tensor + >>> from core.distributed.nn.api.remote_module import RemoteModule + >>> + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> remote_linear_module = RemoteModule( + >>> "worker1/cpu", nn.Linear, args=(20, 30), + >>> ) + >>> input = core.randn(128, 20) + >>> ret_fut = remote_linear_module.forward_async(input) + >>> ret = ret_fut.wait() + >>> rpc.shutdown() + + >>> # On worker 1: + >>> from mindnlp import core + >>> from mindnlp import core.distributed.rpc as rpc + >>> + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> rpc.shutdown() + """ + super().__init__() + + enable_moving_cpu_tensors_to_cuda = self._prepare_init(remote_device) + + # Default arguments preparation. + args = args if args is not None else () + kwargs = kwargs if kwargs is not None else {} + + if _module_interface_cls is not None: + # Users reply on this field to know if this generated RemoteModule is TorchScript-able. + self.is_scriptable = True + + # Instantiate template on remote side. + fut = rpc.rpc_async( + self.on, + _instantiate_template, + (_module_interface_cls, enable_moving_cpu_tensors_to_cuda), + ) + + self._init_template( + _module_interface_cls, enable_moving_cpu_tensors_to_cuda + ) + + # Instantiate template on remote side. + fut = rpc.rpc_async( + self.on, + _instantiate_template, + (_module_interface_cls, enable_moving_cpu_tensors_to_cuda), + ) + + # Create the module on the remote side. + fut.wait() # Ensure remote_module_cls is available on remote side. + + # TODO: We need to change this to rpc.remote, and make it async (see the else branch below). + # For that we need to be able to apply _module_interface_cls to the RRef returned by rpc.remote + # See https://github.com/pytorch/pytorch/issues/58098 for more context. + self.module_rref = rpc.rpc_sync( + self.on, + _create_module_with_interface, + (module_cls, args, kwargs, self.device, _module_interface_cls), + ) + else: + self.is_scriptable = False + self.generated_methods = ( + _NON_SCRIPTABLE_REMOTE_MODULE_MODULE._generated_methods + ) + # Create the module on the remote side. + self.module_rref = rpc.remote( + self.on, + _create_module, + (module_cls, args, kwargs, self.device), + ) + + self._install_generated_methods() + self._check_attribute_picklability() + + def remote_parameters(self, recurse: bool = True) -> List[rpc.RRef[Parameter]]: + """ + Return a list of :class:`~core.distributed.rpc.RRef` pointing to the remote module's parameters. + + This can typically be used in conjunction + with :class:`~core.distributed.optim.DistributedOptimizer`. + + Args: + recurse (bool): if True, then returns parameters of the remote + module and all submodules of the remote module. Otherwise, + returns only parameters that are direct members of the + remote module. + + Returns: + A list of :class:`~core.distributed.rpc.RRef` (``List[RRef[nn.Parameter]]``) + to remote module's parameters. + """ + return rpc.rpc_sync(self.on, _param_rrefs, args=(self.module_rref, recurse)) + + def get_module_rref(self) -> rpc.RRef[nn.Module]: + """Return an :class:`~core.distributed.rpc.RRef` (``RRef[nn.Module]``) pointing to the remote module.""" + return self.module_rref + + @core.jit.export + def __getstate__(self): + raise RuntimeError( + "Cannot pickle RemoteModule in python pickler. RemoteModule can only be pickled when using RPC" + ) + + @core.jit.export + def __setstate__(self, state): + raise RuntimeError( + "Cannot unpickle RemoteModule in python pickler. RemoteModule can only be unpickled when using RPC" + ) + + def register_buffer( + self, name: str, tensor: Optional[Tensor], persistent: bool = True + ) -> None: + _raise_not_supported(self.register_buffer.__name__) + + def register_parameter(self, name: str, param: Optional[Parameter]) -> None: + _raise_not_supported(self.register_parameter.__name__) + + def add_module(self, name: str, module: Optional[Module]) -> None: + _raise_not_supported(self.add_module.__name__) + + def apply(self: T, fn: Callable[[Module], None]) -> T: # type: ignore[return] + _raise_not_supported(self.apply.__name__) + + def cuda(self: T, device: Optional[Union[int, device]] = None) -> T: # type: ignore[return] + _raise_not_supported(self.cuda.__name__) + + def ipu(self: T, device: Optional[Union[int, device]] = None) -> T: # type: ignore[return] + _raise_not_supported(self.ipu.__name__) + + def xpu(self: T, device: Optional[Union[int, device]] = None) -> T: # type: ignore[return] + _raise_not_supported(self.xpu.__name__) + + def cpu(self: T) -> T: # type: ignore[return] + _raise_not_supported(self.cpu.__name__) + + def type(self: T, dst_type: Union[dtype, str]) -> T: # type: ignore[return] + _raise_not_supported(self.type.__name__) + + def float(self: T) -> T: # type: ignore[return] + _raise_not_supported(self.float.__name__) + + def double(self: T) -> T: # type: ignore[return] + _raise_not_supported(self.double.__name__) + + def half(self: T) -> T: # type: ignore[return] + _raise_not_supported(self.half.__name__) + + def bfloat16(self: T) -> T: # type: ignore[return] + _raise_not_supported(self.bfloat16.__name__) + + def to(self, *args, **kwargs) -> T: # type: ignore[misc, return, type-var] + _raise_not_supported(self.to.__name__) + + def register_backward_hook( # type: ignore[return] + self, hook: Callable[[Module, _grad_t, _grad_t], Union[None, _grad_t]] + ) -> RemovableHandle: + _raise_not_supported(self.register_backward_hook.__name__) + + def register_forward_pre_hook( # type: ignore[return] + self, + hook: Union[ + Callable[[T, Tuple[Any, ...]], Optional[Any]], + Callable[ + [T, Tuple[Any, ...], Dict[str, Any]], + Optional[Tuple[Any, Dict[str, Any]]], + ], + ], + prepend: bool = False, + with_kwargs: bool = False, + ) -> RemovableHandle: + _raise_not_supported(self.register_forward_pre_hook.__name__) + + def register_forward_hook( # type: ignore[return, override] + self, + hook: Union[ + Callable[[T, Tuple[Any, ...], Any], Optional[Any]], + Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]], + ], + prepend: bool = False, + with_kwargs: bool = False, + ) -> RemovableHandle: + _raise_not_supported(self.register_forward_hook.__name__) + + def state_dict(self, *args, **kwargs): + _raise_not_supported(self.state_dict.__name__) + + def load_state_dict( + self, + state_dict: Mapping[str, Any], + strict: bool = True, + assign: bool = False, + ): + _raise_not_supported(self.load_state_dict.__name__) + + def parameters(self, recurse: bool = True) -> Iterator[Parameter]: + raise ValueError( + "Method ``parameters`` not supported for RemoteModule. Please use ``remote_parameters`` instead." + ) + + def named_parameters( # type: ignore[return] + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True + ) -> Iterator[Tuple[str, Parameter]]: + _raise_not_supported(self.named_parameters.__name__) + + def buffers(self, recurse: bool = True) -> Iterator[Tensor]: # type: ignore[return] + _raise_not_supported(self.buffers.__name__) + + def named_buffers( # type: ignore[return] + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True + ) -> Iterator[Tuple[str, Tensor]]: + _raise_not_supported(self.named_buffers.__name__) + + def children(self) -> Iterator[Module]: # type: ignore[return] + _raise_not_supported(self.children.__name__) + + def named_children(self) -> Iterator[Tuple[str, Module]]: # type: ignore[return] + _raise_not_supported(self.named_children.__name__) + + def modules(self) -> Iterator[Module]: # type: ignore[return] + _raise_not_supported(self.modules.__name__) + + def named_modules( + self, + memo: Optional[Set[Module]] = None, + prefix: str = "", + remove_duplicate: bool = True, + ): + _raise_not_supported(self.named_modules.__name__) + + def train(self: T, mode: bool = True) -> T: + return self.module_rref.rpc_sync().train() # type: ignore[operator, union-attr] + + def eval(self: T) -> T: + return self.module_rref.rpc_sync().eval() # type: ignore[operator, union-attr] + + def requires_grad_(self: T, requires_grad: bool = True) -> T: # type: ignore[return] + _raise_not_supported(self.requires_grad_.__name__) + + def zero_grad(self, set_to_none: bool = True) -> None: + _raise_not_supported(self.zero_grad.__name__) + + def share_memory(self: T) -> T: # type: ignore[return] + _raise_not_supported(self.share_memory.__name__) + + def extra_repr(self) -> str: # type: ignore[return] + _raise_not_supported(self.extra_repr.__name__) + + def _prepare_init(self, remote_device_str: str) -> bool: + """Prepare the initialization and returns whether to enable automatically moving CPU tensors to CUDA devices.""" + # Sanity check. + assert rpc._is_current_rpc_agent_set(), "RemoteModule only works in RPC." + + remote_device = _remote_device(remote_device_str) + self.on = ( + remote_device.worker_name() + if remote_device.worker_name() is not None + else remote_device.rank() + ) + self.device = str(remote_device.device()) + agent = rpc._get_current_rpc_agent() + # If the device map of the remote worker is set, + # then enable moving any input CPU tensors to the same cuda device. + self.is_device_map_set = bool( + agent._get_device_map(agent.get_worker_info(self.on)) # type: ignore[arg-type] + ) + # ``enable_moving_cpu_tensors_to_cuda`` is less strict than ``is_device_map_set``: + # If ``enable_moving_cpu_tensors_to_cuda`` is true, but the device map is not set, + # then any CPU tensors can still be moved to a cuda device to run forward, + # but the output must be moved back to CPU before being sent over the wire. + enable_moving_cpu_tensors_to_cuda = core.device(self.device).type == "cuda" + return enable_moving_cpu_tensors_to_cuda + + def _init_template(self, module_interface_cls, enable_moving_cpu_tensors_to_cuda): + """Instantiate template on local side.""" + generated_module = instantiator.instantiate_scriptable_remote_module_template( + module_interface_cls, enable_moving_cpu_tensors_to_cuda + ) + self.generated_methods = generated_module._generated_methods + + def _check_attribute_picklability(self): + """Check if all the attribute has explicitly defined whether to be pickled (i.e., picklability).""" + for k in self.__dict__.keys(): + if ( + k not in _REMOTE_MODULE_PICKLED_ATTRIBUTES + and k not in _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING + ): + raise AttributeError( + f"Attribute {k} must be either in ``_REMOTE_MODULE_PICKLED_ATTRIBUTES`` or " + "``_REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING``." + ) + + def _install_generated_methods(self): + for method in self.generated_methods: + method_name = method.__name__ + method = core.jit.export(method) + setattr(self, method_name, types.MethodType(method, self)) + + @staticmethod + def init_from_module_rref( + remote_device: str, + module_rref: rpc.RRef[nn.Module], + _module_interface_cls: Any = None, + ): + """ + Besides the constructor, a RemoteModule instance can also be initialized given a module RRef. + + This alternate initialization method can be particularly useful if we want to create multiple + RemoteModule instances that share the same underlying module and reduce memory consumption. + + Moreover, this also provides a workaround for passing script RemoteModule over RPC, + which is not supported. The recommended way is as follows: + + 1. the sender creates a RemoteModule; + 2. the sender sends its ``module_rref`` over RPC; + 3. the receiver calls this method to initialize another RemoteModule using the same ``module_rref``. + + Example:: + Run the following code in two different processes: + + >>> # xdoctest: +SKIP("distributed") + >>> # On worker 0: + >>> from mindnlp import core + >>> from mindnlp import core.distributed.rpc as rpc + >>> from mindnlp.core import nn, Tensor + >>> from core.distributed.nn.api.remote_module import RemoteModule + >>> + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> remote_module = RemoteModule( + >>> "worker1/cpu", nn.Linear, args=(20, 30), + >>> ) + >>> + >>> remote_module1 = rpc.rpc_sync( + >>> "worker1/cpu", + >>> RemoteModule.init_from_module_rref, + >>> ("worker1/cpu", remote_module1.get_module_rref()), + >>> ) + >>> rpc.shutdown() + + >>> # On worker 1: + >>> from mindnlp import core + >>> from mindnlp import core.distributed.rpc as rpc + >>> + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> rpc.shutdown() + + Args: + remote_device (str): Device on the destination worker where we'd like to place this module. + The device can be a local device or a remote device specified by one of the following remote + formats: + + 1. "rank:/" (ex: "rank:0/cuda:0"). + 2. "/" (ex: "trainer0/cuda:0"). + + In addition, the device field can be optional and the default value is "cpu". + module_rref (RRef[nn.Module]): The module reference shared by both the caller and + the created remote module. + _module_interface_cls (type, optional): The TorchScript interface type for the module + to be created. The type object should be decorated by @core.jit.interface. + If not provided, the generated RemoteModule is not torchscript-able. + Warning, this is an experimental API and susceptible to frequent changes. + + Returns: + A remote module instance which wraps the :class:`~nn.Module` created by the + user-provided ``module_rref``, it has a blocking ``forward`` method and an + asynchronous ``forward_async`` method that returns a future of the ``forward`` call + on the user-provided module on the remote side. + """ + # NOTE: if a new attribute is added to this class, also need to add it + # to ``_REMOTE_MODULE_PICKLED_ATTRIBUTES`` for pickling/unpickling. + + remote_module = object.__new__(RemoteModule) + + enable_moving_cpu_tensors_to_cuda = remote_module._prepare_init(remote_device) + + if _module_interface_cls is not None: + # Users reply on this field to know if this generated RemoteModule is TorchScript-able. + remote_module.is_scriptable = True + + remote_module._init_template( + _module_interface_cls, enable_moving_cpu_tensors_to_cuda + ) + else: + remote_module.is_scriptable = False + remote_module.generated_methods = ( + _NON_SCRIPTABLE_REMOTE_MODULE_MODULE._generated_methods + ) + remote_module.module_rref = module_rref + + remote_module._install_generated_methods() + remote_module._check_attribute_picklability() + + return remote_module + + +class RemoteModule(_RemoteModule): + """ + A RemoteModule instance can only be created after RPC initialization. + + It creates a user-specified module on a specified remote node. + It behaves like a regular ``nn.Module`` except that the ``forward`` method is + executed on the remote node. + It takes care of autograd recording to ensure the backward pass propagates + gradients back to the corresponding remote module. + + It generates two methods ``forward_async`` and ``forward`` based on the + signature of the ``forward`` method of ``module_cls``. ``forward_async`` + runs asynchronously and returns a Future. The arguments of ``forward_async`` + and ``forward`` are the same as the ``forward`` method of the module + returned by the ``module_cls``. + + For example, if ``module_cls`` returns an instance of ``nn.Linear``, + that has ``forward`` method signature: ``def forward(input: Tensor) -> Tensor:``, + the generated ``RemoteModule`` will have 2 methods with the signatures: + + | ``def forward(input: Tensor) -> Tensor:`` + | ``def forward_async(input: Tensor) -> Future[Tensor]:`` + + Args: + remote_device (str): Device on the destination worker where we'd like to place this module. + The format should be "/", where the device field can be parsed as core.device type. + E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0". + In addition, the device field can be optional and the default value is "cpu". + module_cls (nn.Module): Class for the module to be created remotely. For example, + + >>> class MyModule(nn.Module): + >>> def forward(input): + >>> return input + 1 + >>> + >>> module_cls = MyModule + + args (Sequence, optional): args to be passed to ``module_cls``. + kwargs (Dict, optional): kwargs to be passed to ``module_cls``. + + Returns: + A remote module instance which wraps the :class:`~nn.Module` created by the + user-provided ``module_cls``, it has a blocking ``forward`` method and an + asynchronous ``forward_async`` method that returns a future of the ``forward`` call + on the user-provided module on the remote side. + + Example:: + Run the following code in two different processes: + + >>> # xdoctest: +SKIP("distributed") + >>> # On worker 0: + >>> from mindnlp import core + >>> from mindnlp import core.distributed.rpc as rpc + >>> from mindnlp.core import nn, Tensor + >>> from core.distributed.nn.api.remote_module import RemoteModule + >>> + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> remote_linear_module = RemoteModule( + >>> "worker1/cpu", nn.Linear, args=(20, 30), + >>> ) + >>> input = core.randn(128, 20) + >>> ret_fut = remote_linear_module.forward_async(input) + >>> ret = ret_fut.wait() + >>> rpc.shutdown() + + >>> # On worker 1: + >>> from mindnlp import core + >>> from mindnlp import core.distributed.rpc as rpc + >>> + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> rpc.shutdown() + + Furthermore, a more practical example that is combined with + `DistributedDataParallel `__ (DDP) + can be found in this `tutorial `__. + """ + + def __init__( + self, + remote_device: str, + module_cls: Type[nn.Module], + args: Optional[Tuple] = None, + kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__(remote_device, module_cls, args, kwargs) + + +def _remote_module_receiver( + *remote_module_pickled_attrs, +): + """Deserializes a RemoteModule.""" + serialized_remote_module = _SerializedRemoteModule._make( + remote_module_pickled_attrs + ) + m = object.__new__(RemoteModule) + m.__dict__.update(serialized_remote_module._asdict()) + + # Unpickling the attribute `module_rref` must invoke RRef's `_deserialize()` method. + m.module_rref = rpc.PyRRef._deserialize(m.module_rref) + + # Install generated methods when unpickled. + for method in m.generated_methods: + method_name = method.__name__ + method = core.jit.export(method) + setattr(m, method_name, types.MethodType(method, m)) + + return m + + +def _remote_module_reducer(remote_module): + """Serialize a RemoteModule.""" + pickled_attrs = {} + for k, v in remote_module.__dict__.items(): + # Pickling the attribute `module_rref` must invoke RRef's `_serialize()` method. + if k == "module_rref": + pickled_attrs[k] = v._serialize() + elif k in _REMOTE_MODULE_PICKLED_ATTRIBUTES: + pickled_attrs[k] = v + # Check if unpickled attributes are all in _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING. + elif k not in _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING: + print( + f"The new attribute ``{k}`` of RemoteModule is ignored during RPC pickling. " + "To pickle this attribute, please add it to ``_REMOTE_MODULE_PICKLED_ATTRIBUTES``. " + "Otherwise, please explicitly add it to ``_REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING``.", + file=sys.stderr, + ) + + return ( + _remote_module_receiver, + tuple(pickled_attrs.values()), + ) + + +def _recursive_script_module_receiver( + recursive_script_module_serialized, +): + """Deserializes a RecursiveScriptModule that does not contain a script RemoteModule.""" + f = io.BytesIO(recursive_script_module_serialized) + m = core.jit.load(f) + return m + + +def _recursive_script_module_reducer(recursive_script_module): + """Serialize a RecursiveScriptModule that does not contain a script RemoteModule, and raises an error otherwise.""" + if hasattr(recursive_script_module._c, "module_rref"): + raise RuntimeError( + "Passing a script RemoteModule over RPC is not supported. Please create a RemoteModule in the sender, " + "send the `module_rref` to the receiver, and create a new instance on the receiver end by passing this `module_rref`." + ) + + f = io.BytesIO() + core.jit.save(recursive_script_module, f) + return (_recursive_script_module_receiver, (f.getvalue(),)) + + +_internal_rpc_pickler._register_reducer(RemoteModule, _remote_module_reducer) +_internal_rpc_pickler._register_reducer( + core.jit.RecursiveScriptModule, _recursive_script_module_reducer +) diff --git a/mindnlp/core/distributed/nn/functional.py b/mindnlp/core/distributed/nn/functional.py new file mode 100644 index 000000000..5f72df508 --- /dev/null +++ b/mindnlp/core/distributed/nn/functional.py @@ -0,0 +1,452 @@ +# mypy: allow-untyped-defs +from mindnlp import core +from mindnlp import core.distributed as dist +from core.autograd import Function + +# The two imports below are not always available depending on the +# USE_DISTRIBUTED compile flag. Make sure they raise import error +# if we're trying to use them. +from core.distributed import group, ReduceOp + + +def broadcast(tensor, src, group=group.WORLD): + """ + Broadcasts the tensor to the whole group. + + ``tensor`` must have the same number of elements in all processes + participating in the collective. + + Arguments: + tensor (Tensor): Data to be sent if ``src`` is the rank of current + process. + src (int): Source rank. + group (ProcessGroup, optional): The process group to work on. + + Returns: + Tensor: Received tensor from the broadcast op. + + """ + return _Broadcast.apply(src, group, tensor) + + +def gather(tensor, dst=0, group=group.WORLD): + """ + Gathers a list of tensors in a single process. + + Arguments: + tensor (Tensor): Input tensor. + dst (int, optional): Destination rank (default is 0). + group (ProcessGroup, optional): The process group to work on. + + Returns: + tuple[Tensor]: List of appropriately-sized tensors with the gathered data. + """ + return _Gather.apply(dst, group, tensor) + + +def scatter(tensors, src=0, group=group.WORLD): + """ + Scatters a list of tensors to all processes in a group. + + Each process will receive exactly one tensor and store its data in the + ``tensor`` argument. + + Arguments: + tensors (list[Tensor]): List of tensors to scatter on the source rank. + Receivers must pass ``None`. + src (int, optional): Source rank (default is 0). + group (ProcessGroup, optional): The process group to work on. + + Returns: + Tensor: Output tensor from the scatter operation. + + """ + return _Scatter.apply(src, group, *tensors) + + +def reduce(tensor, dst, op=ReduceOp.SUM, group=group.WORLD): + """ + Reduces the tensor data across all machines. + + Only the process with rank ``dst`` is going to receive the final result. + + Arguments: + tensor (Tensor): Input of the collective. + dst (int): Destination rank. + op (optional): One of the values from + ``core.distributed.ReduceOp`` + enum. Specifies an operation used for element-wise reductions. + group (ProcessGroup, optional): The process group to work on. + + Returns: + Tensor: Output of the collective. + + """ + return _Reduce.apply(dst, op, group, tensor) + + +def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=group.WORLD): + """ + Reduces, then scatters a list of tensors to all processes in a group. + + Arguments: + output (Tensor): Output tensor. + input_list (list[Tensor]): List of tensors to reduce and scatter. + op (optional): One of the values from + ``core.distributed.ReduceOp`` + enum. Specifies an operation used for element-wise reductions. + group (ProcessGroup, optional): The process group to work on. + + Returns: + Tensor: Output of the collective. + + """ + return _Reduce_Scatter.apply(op, group, output, *input_list) + + +def all_gather(tensor, group=group.WORLD): + """ + Gathers tensors from the whole group in a list. + + Arguments: + tensor (Tensor): Tensor to be broadcast from current process. + group (ProcessGroup, optional): The process group to work on. + + Returns: + tuple([Tensor]): Output of the collective. + + """ + return _AllGather.apply(group, tensor) + + +def _all_gather_base(output_tensor, input_tensor, group=group.WORLD): + """ + Single tensor all gather. Gathers a single tensor from all ranks, and puts them in a single output tensor. + + Args: + output_tensor (Tensor): Output tensor. It should contain + correctly-sized tensors to be used for output of the collective. + input_tensor (Tensor): Tensor to be broadcast from current process. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + + Examples: + >>> # All tensors below are of core.int64 dtype. + >>> # We have 2 process groups, 2 ranks. + >>> # xdoctest: +SKIP("incorrect want text") + >>> output_tensor = core.zeros(2, dtype=core.int64) + >>> output_tensor + [tensor([0, 0])] # Rank 0 and 1 + >>> tensor = core.arange(1, dtype=core.int64) + 1 + rank + >>> tensor + tensor([1]) # Rank 0 + tensor([2]) # Rank 1 + >>> dist.all_gather_base(output_tensor, tensor) + >>> output_tensor + tensor([1,2]) # Rank 0 + tensor([1,2]) # Rank 1 + + .. warning:: + `_all_gather_base` is experimental and subject to change. + It is the caller's responsibility to ensure the output_tensor + is correctly sized. + + """ + return _AllGatherBase.apply(output_tensor, input_tensor, group) + + +def all_to_all(output_tensor_list, input_tensor_list, group=group.WORLD): + """ + Each process scatters list of input tensors to all processes in a group and return gathered list of tensors in output list. + + Arguments: + output_tensor_list (list[Tensor]): list of tensors to gather one per rank. + input_tensor_list (list[Tensor]): List of tensors to scatter one per rank. + group (ProcessGroup, optional): The process group to work on. + + Returns: + tuple([Tensor]): Output of the collective. + + """ + return _AlltoAll.apply(group, output_tensor_list, *input_tensor_list) + + +def all_to_all_single( + output, + input, + output_split_sizes=None, + input_split_sizes=None, + group=group.WORLD, +): + """ + Each process splits input tensor and then scatters the split list to all processes in a group. + + Then concatenate the received tensors from all the processes in the group and return single output tensor. + + Arguments: + output (Tensor): Gathered concatenated output tensor. + input (Tensor): Input tensor to scatter. + output_split_sizes: (list[Int], optional): Output split sizes for dim 0 + if specified None or empty, dim 0 of ``output`` tensor must divide + equally by ``world_size``. + input_split_sizes: (list[Int], optional): Input split sizes for dim 0 + if specified None or empty, dim 0 of ``input`` tensor must divide + equally by ``world_size``. + + Returns: + Tensor: Output of the collective. + + """ + return _AlltoAllSingle.apply( + group, output, output_split_sizes, input_split_sizes, input + ) + + +def all_reduce(tensor, op=ReduceOp.SUM, group=group.WORLD): + """ + Reduces the tensor data across all machines in such a way that all get the final result. + + After the call the returned tensor is going to be bitwise + identical in all processes. + + Arguments: + tensor (Tensor): Input of the collective. + op (optional): One of the values from + ``core.distributed.ReduceOp`` + enum. Specifies an operation used for element-wise reductions. + group (ProcessGroup, optional): The process group to work on. + + Returns: + Tensor: Output of the collective + + """ + return _AllReduce.apply(op, group, tensor) + + +class _Broadcast(Function): + @staticmethod + def forward(ctx, src, group, tensor): + ctx.src = src + ctx.group = group + ctx.rank = dist.get_rank(group=group) + # core.distributed makes all the calls in place + # we allocate new tensors to avoid this + tensor = tensor.clone() + dist.broadcast(tensor, src, group=group) + return tensor + + @staticmethod + def backward(ctx, grad_output): + gx = _Reduce.apply(ctx.src, ReduceOp.SUM, ctx.group, grad_output) + if ctx.src != ctx.rank: + gx.zero_() + return (None, None, gx) + + +class _Gather(Function): + @staticmethod + def forward(ctx, dst, group, tensor): + ctx.dst = dst + ctx.group = group + # Need to create a list of tensors here to do the + # aggregation, get it from the group size + # tensor should be correctly sized for the method + # gathering + tensor_list = [ + core.zeros_like(tensor) for i in range(dist.get_world_size(group=group)) + ] + + tensor = tensor.contiguous() + if dist.get_rank(group=group) == dst: + dist.gather(tensor, tensor_list, dst, group=group) + else: + dist.gather(tensor, None, dst, group=group) + return tuple(tensor_list) + + @staticmethod + def backward(ctx, *grad_outputs): + return (None, None) + (_Scatter.apply(ctx.dst, ctx.group, *grad_outputs),) + + +class _Scatter(Function): + @staticmethod + def forward(ctx, src, group, *tensors): + ctx.src = src + ctx.group = group + assert all(t.size() == tensors[0].size() for t in tensors) + output = core.zeros_like(tensors[0]) + if dist.get_rank(group=group) == src: + dist.scatter(output, list(tensors), src, group=group) + else: + dist.scatter(output, None, src, group=group) + return output + + @staticmethod + def backward(ctx, grad_output): + return (None, None) + _Gather.apply(ctx.src, ctx.group, grad_output) + + +class _Reduce(Function): + @staticmethod + def forward(ctx, src, op, group, tensor): + ctx.src = src + ctx.group = group + tensor = tensor.clone() + dist.reduce(tensor, src, op=op, group=group) + return tensor + + @staticmethod + def backward(ctx, grad_output): + return (None, None, None) + (_Broadcast.apply(ctx.src, ctx.group, grad_output),) + + +class _Reduce_Scatter(Function): + @staticmethod + def forward(ctx, op, group, tensor, *input_tensor_list): + ctx.group = group + # Need contiguous tensors for collectives. + tensor = tensor.contiguous() + input_tensor_list = tuple(t.contiguous() for t in input_tensor_list) + dist.reduce_scatter(tensor, list(input_tensor_list), op=op, group=group) + return tensor + + @staticmethod + def backward(ctx, grad_output): + return (None, None, None) + _AllGather.apply(ctx.group, grad_output) + + +class _AllGather(Function): + @staticmethod + def forward(ctx, group, tensor): + # Need contiguous tensors for collectives. + tensor = tensor.contiguous() + + ctx.group = group + out_tensor_list = [ + core.empty_like(tensor) for _ in range(dist.get_world_size(group=group)) + ] + + dist.all_gather(out_tensor_list, tensor, group=group) + return tuple(out_tensor_list) + + @staticmethod + def backward(ctx, *grad_outputs): + if dist.get_backend(group=ctx.group) is dist.Backend.NCCL: + rank = dist.get_rank(group=ctx.group) + gx = core.empty_like(grad_outputs[rank]) + gx = _Reduce_Scatter.apply(ReduceOp.SUM, ctx.group, gx, *grad_outputs) + else: + # As many backends doesn't support ReduceScatter, we use AlltoAll with .sum() + # to emulate the ReduceScatter behavior + tensor_list = [core.empty_like(tensor) for tensor in grad_outputs] + gxs = _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs) + gx = core.sum(core.stack(gxs), dim=0) + return (None, gx) + + +class _AllGatherBase(Function): + @staticmethod + def forward(ctx, output_tensor, input_tensor, group): + ctx.group = group + dist._all_gather_base(output_tensor, input_tensor.contiguous(), group=group) + return output_tensor + + @staticmethod + def backward(ctx, grad_output): + if dist.get_backend(group=ctx.group) is dist.Backend.NCCL: + world_size = dist.get_world_size(group=ctx.group) + out_size = list(grad_output.size()) + if out_size[0] % world_size != 0: + raise RuntimeError( + f"Tensor with dimensions: {out_size} does " + f"not have first dimension divisible by world_size: {world_size}" + ) + out_size[0] = out_size[0] // dist.get_world_size(group=ctx.group) + gx = core.empty( + out_size, device=grad_output.device, dtype=grad_output.dtype + ) + dist._reduce_scatter_base(gx, grad_output, ReduceOp.SUM, ctx.group) + else: + raise RuntimeError("Backend not supported!") + return (None, gx, None) + + +class _AlltoAll(Function): + @staticmethod + def forward(ctx, group, out_tensor_list, *tensors): + ctx.group = group + ctx.input_tensor_size_list = [ + tensors[i].size() for i in range(dist.get_world_size(group=group)) + ] + my_rank = dist.get_rank(group=group) + tensors = tuple(t.contiguous() for t in tensors) + # Implement it on means of scatter/gather, send/recv async operations have issues + if dist.get_backend(group=group) is dist.Backend.GLOO: + for i in range(dist.get_world_size(group=group)): + to_send = None + if i == my_rank: + to_send = list(tensors) + dist.scatter(out_tensor_list[i], to_send, i, group=group) + else: + dist.all_to_all( + out_tensor_list, + list(tensors), + group=group, + ) + return tuple(out_tensor_list) + + @staticmethod + def backward(ctx, *grad_outputs): + tensor_list = [ + core.empty( + size, device=grad_outputs[0].device, dtype=grad_outputs[0].dtype + ) + for size in ctx.input_tensor_size_list + ] + return (None, None) + _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs) + + +class _AlltoAllSingle(Function): + @staticmethod + def forward(ctx, group, output, output_split_sizes, input_split_sizes, input): + ctx.group = group + ctx.input_size = input.size() + ctx.output_split_sizes = input_split_sizes + ctx.input_split_sizes = output_split_sizes + dist.all_to_all_single( + output, + input, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + ) + return output + + @staticmethod + def backward(ctx, grad_output): + tensor = core.empty( + ctx.input_size, device=grad_output.device, dtype=grad_output.dtype + ) + return (None, None, None, None) + ( + _AlltoAllSingle.apply( + ctx.group, + tensor, + ctx.output_split_sizes, + ctx.input_split_sizes, + grad_output.contiguous(), + ), + ) + + +class _AllReduce(Function): + @staticmethod + def forward(ctx, op, group, tensor): + ctx.group = group + ctx.op = op + tensor = tensor.clone() + dist.all_reduce(tensor, op=op, group=group) + return tensor + + @staticmethod + def backward(ctx, grad_output): + return (None, None) + (_AllReduce.apply(ctx.op, ctx.group, grad_output),) diff --git a/mindnlp/core/distributed/nn/jit/__init__.py b/mindnlp/core/distributed/nn/jit/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/mindnlp/core/distributed/nn/jit/instantiator.py b/mindnlp/core/distributed/nn/jit/instantiator.py new file mode 100644 index 000000000..f3d2009c3 --- /dev/null +++ b/mindnlp/core/distributed/nn/jit/instantiator.py @@ -0,0 +1,154 @@ +#!/usr/bin/python3 +# mypy: allow-untyped-defs +import importlib +import logging +import os +import sys +import tempfile +from typing import Optional + +from mindnlp import core +from core.distributed.nn.jit.templates.remote_module_template import ( + get_remote_module_template, +) + + +logger = logging.getLogger(__name__) + + +_FILE_PREFIX = "_remote_module_" +_TEMP_DIR = tempfile.TemporaryDirectory() +INSTANTIATED_TEMPLATE_DIR_PATH = _TEMP_DIR.name +logger.info("Created a temporary directory at %s", INSTANTIATED_TEMPLATE_DIR_PATH) +sys.path.append(INSTANTIATED_TEMPLATE_DIR_PATH) + + +def get_arg_return_types_from_interface(module_interface): + assert getattr( + module_interface, "__torch_script_interface__", False + ), "Expect a TorchScript class interface decorated by @core.jit.interface." + qualified_name = core._jit_internal._qualified_name(module_interface) + cu = core.jit._state._python_cu + module_interface_c = cu.get_interface(qualified_name) + assert ( + "forward" in module_interface_c.getMethodNames() + ), f"Expect forward in interface methods, while it has {module_interface_c.getMethodNames()}" + method_schema = module_interface_c.getMethod("forward") + + arg_str_list = [] + arg_type_str_list = [] + assert method_schema is not None + for argument in method_schema.arguments: + arg_str_list.append(argument.name) + + if argument.has_default_value(): + default_value_str = f" = {argument.default_value}" + else: + default_value_str = "" + arg_type_str = f"{argument.name}: {argument.type}{default_value_str}" + arg_type_str_list.append(arg_type_str) + + arg_str_list = arg_str_list[1:] # Remove "self". + args_str = ", ".join(arg_str_list) + + arg_type_str_list = arg_type_str_list[1:] # Remove "self". + arg_types_str = ", ".join(arg_type_str_list) + + assert len(method_schema.returns) == 1 + argument = method_schema.returns[0] + return_type_str = str(argument.type) + + return args_str, arg_types_str, return_type_str + + +def _write(out_path, text): + old_text: Optional[str] + try: + with open(out_path) as f: + old_text = f.read() + except OSError: + old_text = None + if old_text != text: + with open(out_path, "w") as f: + logger.info("Writing %s", out_path) + f.write(text) + else: + logger.info("Skipped writing %s", out_path) + + +def _do_instantiate_remote_module_template( + generated_module_name, str_dict, enable_moving_cpu_tensors_to_cuda +): + generated_code_text = get_remote_module_template( + enable_moving_cpu_tensors_to_cuda + ).format(**str_dict) + out_path = os.path.join( + INSTANTIATED_TEMPLATE_DIR_PATH, f"{generated_module_name}.py" + ) + _write(out_path, generated_code_text) + + # From importlib doc, + # > If you are dynamically importing a module that was created since + # the interpreter began execution (e.g., created a Python source file), + # you may need to call invalidate_caches() in order for the new module + # to be noticed by the import system. + importlib.invalidate_caches() + generated_module = importlib.import_module(f"{generated_module_name}") + return generated_module + + +def instantiate_scriptable_remote_module_template( + module_interface_cls, enable_moving_cpu_tensors_to_cuda=True +): + if not getattr(module_interface_cls, "__torch_script_interface__", False): + raise ValueError( + f"module_interface_cls {module_interface_cls} must be a type object decorated by " + "@core.jit.interface" + ) + + # Generate the template instance name. + module_interface_cls_name = core._jit_internal._qualified_name( + module_interface_cls + ).replace(".", "_") + generated_module_name = f"{_FILE_PREFIX}{module_interface_cls_name}" + + # Generate type annotation strs. + assign_module_interface_cls_str = ( + f"from {module_interface_cls.__module__} import " + f"{module_interface_cls.__name__} as module_interface_cls" + ) + args_str, arg_types_str, return_type_str = get_arg_return_types_from_interface( + module_interface_cls + ) + kwargs_str = "" + arrow_and_return_type_str = f" -> {return_type_str}" + arrow_and_future_return_type_str = f" -> Future[{return_type_str}]" + + str_dict = dict( + assign_module_interface_cls=assign_module_interface_cls_str, + arg_types=arg_types_str, + arrow_and_return_type=arrow_and_return_type_str, + arrow_and_future_return_type=arrow_and_future_return_type_str, + args=args_str, + kwargs=kwargs_str, + jit_script_decorator="@core.jit.script", + ) + return _do_instantiate_remote_module_template( + generated_module_name, str_dict, enable_moving_cpu_tensors_to_cuda + ) + + +def instantiate_non_scriptable_remote_module_template(): + generated_module_name = f"{_FILE_PREFIX}non_scriptable" + str_dict = dict( + assign_module_interface_cls="module_interface_cls = None", + args="*args", + kwargs="**kwargs", + arg_types="*args, **kwargs", + arrow_and_return_type="", + arrow_and_future_return_type="", + jit_script_decorator="", + ) + # For a non-scriptable template, always enable moving CPU tensors to a cuda device, + # because there is no syntax limitation on the extra handling caused by the script. + return _do_instantiate_remote_module_template(generated_module_name, str_dict, True) diff --git a/mindnlp/core/distributed/nn/jit/templates/__init__.py b/mindnlp/core/distributed/nn/jit/templates/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/mindnlp/core/distributed/nn/jit/templates/remote_module_template.py b/mindnlp/core/distributed/nn/jit/templates/remote_module_template.py new file mode 100644 index 000000000..d56ca0930 --- /dev/null +++ b/mindnlp/core/distributed/nn/jit/templates/remote_module_template.py @@ -0,0 +1,108 @@ +#!/usr/bin/python3 +# mypy: allow-untyped-defs + + +def get_remote_module_template(enable_moving_cpu_tensors_to_cuda: bool): + return _TEMPLATE_PREFIX + ( + _REMOTE_FORWARD_TEMPLATE_ENABLE_MOVING_CPU_TENSORS_TO_CUDA + if enable_moving_cpu_tensors_to_cuda + else _REMOTE_FORWARD_TEMPLATE + ) + + +_TEMPLATE_PREFIX = """from typing import * + +from mindnlp import core +from mindnlp import core.distributed.rpc as rpc +from mindnlp.core import Tensor +from core._jit_internal import Future +from core.distributed.rpc import RRef +from typing import Tuple # pyre-ignore: unused import + + +{assign_module_interface_cls} + + +def forward_async(self, {arg_types}){arrow_and_future_return_type}: + args = (self.module_rref, self.device, self.is_device_map_set, {args}) + kwargs = {{{kwargs}}} + return rpc.rpc_async( + self.module_rref.owner(), + _remote_forward, + args, + kwargs, + ) + + +def forward(self, {arg_types}){arrow_and_return_type}: + args = (self.module_rref, self.device, self.is_device_map_set, {args}) + kwargs = {{{kwargs}}} + ret_fut = rpc.rpc_async( + self.module_rref.owner(), + _remote_forward, + args, + kwargs, + ) + return ret_fut.wait() + + +_generated_methods = [ + forward_async, + forward, +] + + +{jit_script_decorator} +""" + +# This template may cause typing error (the mismatch between ``Tuple[()]`` and ``Tuple[Any]``) +# even if the code is only used for instantiation but not execution. +# Therefore, only include handling moving CPU tensors to a cuda device if necessary. +# TODO: Merge these two templates together in the future once TorchScript syntax is improved. +_REMOTE_FORWARD_TEMPLATE_ENABLE_MOVING_CPU_TENSORS_TO_CUDA = """ +def _remote_forward( + module_rref: RRef[module_interface_cls], device: str, is_device_map_set: bool, {arg_types}){arrow_and_return_type}: + module = module_rref.local_value() + device = core.device(device) + + if device.type != "cuda": + return module.forward({args}, {kwargs}) + + # If the module is on a cuda device, + # move any CPU tensor in args or kwargs to the same cuda device. + # Since torch script does not support generator expression, + # have to use concatenation instead of + # ``tuple(i.to(device) if isinstance(i, Tensor) else i for i in *args)``. + args = ({args},) + out_args: Tuple[()] = () + for arg in args: + arg = (arg.to(device),) if isinstance(arg, Tensor) else (arg,) + out_args = out_args + arg + + kwargs = {{{kwargs}}} + for k, v in kwargs.items(): + if isinstance(v, Tensor): + kwargs[k] = kwargs[k].to(device) + + if is_device_map_set: + return module.forward(*out_args, {kwargs}) + + # If the device map is empty, then only CPU tensors are allowed to send over wire, + # so have to move any GPU tensor to CPU in the output. + # Since torch script does not support generator expression, + # have to use concatenation instead of + # ``tuple(i.cpu() if isinstance(i, Tensor) else i for i in module.forward(*out_args, {kwargs}))``. + ret: Tuple[()] = () + for i in module.forward(*out_args, {kwargs}): + i = (i.cpu(),) if isinstance(i, Tensor) else (i,) + ret = ret + i + return ret +""" + +_REMOTE_FORWARD_TEMPLATE = """ +def _remote_forward( + module_rref: RRef[module_interface_cls], device: str, is_device_map_set: bool, {arg_types}){arrow_and_return_type}: + module = module_rref.local_value() + + return module.forward({args}, {kwargs}) +""" diff --git a/mindnlp/core/distributed/optim/__init__.py b/mindnlp/core/distributed/optim/__init__.py new file mode 100644 index 000000000..2d4f1ac5d --- /dev/null +++ b/mindnlp/core/distributed/optim/__init__.py @@ -0,0 +1,43 @@ +""" +:mod:`core.distributed.optim` exposes DistributedOptimizer, which takes a list +of remote parameters (:class:`~core.distributed.rpc.RRef`) and runs the +optimizer locally on the workers where the parameters live. The distributed +optimizer can use any of the local optimizer :ref:`optimizer-algorithms` to +apply the gradients on each worker. +""" +import warnings + +from mindnlp import core +from mindnlp.core import optim + +from .apply_optimizer_in_backward import ( + _apply_optimizer_in_backward, + _get_in_backward_optimizers, +) +from .functional_adadelta import _FunctionalAdadelta +from .functional_adagrad import _FunctionalAdagrad +from .functional_adam import _FunctionalAdam +from .functional_adamax import _FunctionalAdamax +from .functional_adamw import _FunctionalAdamW +from .functional_rmsprop import _FunctionalRMSprop +from .functional_rprop import _FunctionalRprop +from .functional_sgd import _FunctionalSGD +from .named_optimizer import _NamedOptimizer +from .utils import as_functional_optim + + +# DistributedOptimizer imports core.distributed.rpc names, so gate availability +# based on RPC being available. +if hasattr(core._C, "_rpc_init"): + from .optimizer import DistributedOptimizer + +from .post_localSGD_optimizer import PostLocalSGDOptimizer +from .zero_redundancy_optimizer import ZeroRedundancyOptimizer + + +__all__ = [ + "as_functional_optim", + "DistributedOptimizer", + "PostLocalSGDOptimizer", + "ZeroRedundancyOptimizer", +] diff --git a/mindnlp/core/distributed/optim/_deprecation_warning.py b/mindnlp/core/distributed/optim/_deprecation_warning.py new file mode 100644 index 000000000..5af6feb1c --- /dev/null +++ b/mindnlp/core/distributed/optim/_deprecation_warning.py @@ -0,0 +1,16 @@ +import warnings + +from mindnlp import core + + +@core.jit.ignore # type: ignore[misc] +def _scripted_functional_optimizer_deprecation_warning(stacklevel: int = 0) -> None: + with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "`TorchScript` support for functional optimizers is deprecated " + "and will be removed in a future PyTorch release. " + "Consider using the `core.compile` optimizer instead.", + DeprecationWarning, + stacklevel=stacklevel + 2, + ) diff --git a/mindnlp/core/distributed/optim/apply_optimizer_in_backward.py b/mindnlp/core/distributed/optim/apply_optimizer_in_backward.py new file mode 100644 index 000000000..02b393727 --- /dev/null +++ b/mindnlp/core/distributed/optim/apply_optimizer_in_backward.py @@ -0,0 +1,120 @@ +from typing import Any, Dict, Iterable, List, no_type_check, Type + +from mindnlp import core + + +__all__: List[str] = [] + +# WeakTensorKeyDictionary to store relevant meta-data for the Tensor/Parameter +# without changing it's life-time. +# NOTE: Alternative is to add the meta-data as an attribute to the tensor, +# but that will serialize the meta-data if Tensor is serialized. +param_to_optim_hook_handle_map = core.utils.weak.WeakTensorKeyDictionary() +param_to_acc_grad_map = core.utils.weak.WeakTensorKeyDictionary() + + +@no_type_check +def _apply_optimizer_in_backward( + optimizer_class: Type[core.optim.Optimizer], + params: Iterable[core.nn.Parameter], + optimizer_kwargs: Dict[str, Any], + register_hook: bool = True, +) -> None: + """ + Upon ``backward()``, the optimizer specified for each parameter will fire after + the gradient has been accumulated into the parameter. + + Note - gradients for these parameters will be set to None after ``backward()``. + This means that any other optimizer not specified via `_apply_optimizer_in_backward` + over this parameter will be a no-op. + + Args: + optimizer_class: (Type[core.optim.Optimizer]): Optimizer to apply to parameter + params: (Iterator[nn.Parameter]): parameters to apply optimizer state to + optimizer_kwargs: (Dict[str, Any]): kwargs to pass to optimizer constructor + register_hook: (bool): whether to register a hook that runs the optimizer + after gradient for this parameter is accumulated. This is the default + way that optimizer in backward is implemented, but specific use cases + (such as DDP) may wish to override this to implement custom behavior. + (Default = True) + + Example:: + params_generator = model.parameters() + param_1 = next(params_generator) + remainder_params = list(params_generator) + + apply_optimizer_in_backward(core.optim.SGD, [param_1], {"lr": .02}) + apply_optimizer_in_backward(core.optim.Adam, remainder_params, {"lr": .04}) + + model(...).sum().backward() # after backward, parameters will already + # have their registered optimizer(s) applied. + + """ + core._C._log_api_usage_once("core.distributed.optim.apply_optimizer_in_backward") + + @no_type_check + def _apply_optimizer_in_backward_to_param(param: core.nn.Parameter) -> None: + # view_as creates a node in autograd graph that allows us access to the + # parameter's AccumulateGrad autograd function object. We register a + # hook on this object to fire the optimizer when the gradient for + # this parameter is ready (has been accumulated into .grad field) + + # Don't create a new acc_grad if we already have one + # i.e. for shared parameters or attaching multiple optimizers to a param. + if param not in param_to_acc_grad_map: + param_to_acc_grad_map[param] = param.view_as(param).grad_fn.next_functions[ + 0 + ][0] + + optimizer = optimizer_class([param], **optimizer_kwargs) + + if not hasattr(param, "_in_backward_optimizers"): + param._in_backward_optimizers = [] # type: ignore[attr-defined] + # TODO: Remove these attributes once we have a better way of accessing + # optimizer classes and kwargs for a parameter. + param._optimizer_classes = [] # type: ignore[attr-defined] + param._optimizer_kwargs = [] # type: ignore[attr-defined] + + param._in_backward_optimizers.append(optimizer) # type: ignore[attr-defined] + param._optimizer_classes.append(optimizer_class) # type: ignore[attr-defined] + param._optimizer_kwargs.append(optimizer_kwargs) # type: ignore[attr-defined] + + if not register_hook: + return + + def optimizer_hook(*_unused) -> None: + for opt in param._in_backward_optimizers: # type: ignore[attr-defined] + opt.step() + + param.grad = None + + handle = param_to_acc_grad_map[param].register_hook(optimizer_hook) # type: ignore[attr-defined] + if param not in param_to_optim_hook_handle_map: + param_to_optim_hook_handle_map[param] = [] + param_to_optim_hook_handle_map[param].append(handle) + + for param in params: + _apply_optimizer_in_backward_to_param(param) + + +def _get_in_backward_optimizers(module: core.nn.Module) -> List[core.optim.Optimizer]: + """ + Return a list of in-backward optimizers applied to ``module``'s parameters. Note that these + optimizers are not intended to directly have their ``step`` or ``zero_grad`` methods called + by the user and are intended to be used for things like checkpointing. + + Args: + module: (core.nn.Module): model to retrieve in-backward optimizers for + + Returns: + List[core.optim.Optimizer]: the in-backward optimizers. + + Example:: + _apply_optimizer_in_backward(core.optim.SGD, model.parameters(), {'lr': 0.01}) + optims = _get_optimizers_in_backward(model) + """ + optims: List[core.optim.Optimizer] = [] + for param in module.parameters(): + optims.extend(getattr(param, "_in_backward_optimizers", [])) + + return optims diff --git a/mindnlp/core/distributed/optim/functional_adadelta.py b/mindnlp/core/distributed/optim/functional_adadelta.py new file mode 100644 index 000000000..bbdb18cdb --- /dev/null +++ b/mindnlp/core/distributed/optim/functional_adadelta.py @@ -0,0 +1,111 @@ +# mypy: allow-untyped-defs +from typing import Dict, List, Optional + +from mindnlp import core +from mindnlp import core.optim._functional as F +from mindnlp.core import Tensor +from core.distributed.optim._deprecation_warning import ( + _scripted_functional_optimizer_deprecation_warning, +) + + +__all__: List[str] = [] + + +# Define a TorchScript compatible Functional Adadelta Optimizer +# where we use these optimizer in a functional way. +# Instead of using the `param.grad` when updating parameters, +# we explicitly allow the distributed optimizer pass gradients to +# the `step` function. In this way, we could separate the gradients +# and parameters and allow multithreaded trainer to update the +# parameters without data traces on accumulating to the same .grad. +# NOTE: This should be only used by distributed optimizer internals +# and not meant to expose to the user. +@core.jit.script +class _FunctionalAdadelta: + def __init__( + self, + params: List[Tensor], + lr: float = 1.0, + rho: float = 0.9, + eps: float = 1e-6, + weight_decay: float = 0.0, + foreach: bool = False, + maximize: bool = False, + _allow_empty_param_list: bool = False, + ): + _scripted_functional_optimizer_deprecation_warning(stacklevel=2) + self.defaults = { + "lr": lr, + "rho": rho, + "eps": eps, + "weight_decay": weight_decay, + } + self.foreach = foreach + self.maximize = maximize + + if len(params) == 0 and not _allow_empty_param_list: + raise ValueError("optimizer got an empty parameter list") + + # NOTE: we only have one param_group and don't allow user to add additional + # param group as it's not a common use case. + self.param_group = {"params": params} + + self.state = core.jit.annotate(Dict[core.Tensor, Dict[str, core.Tensor]], {}) + + def step(self, gradients: List[Optional[Tensor]]): + params = self.param_group["params"] + params_with_grad = [] + grads = [] + square_avgs = [] + acc_deltas = [] + state_steps = [] + lr = self.defaults["lr"] + rho = self.defaults["rho"] + eps = self.defaults["eps"] + weight_decay = self.defaults["weight_decay"] + + if len(params) != len(gradients): + raise ValueError( + "the gradients passed in does not equal to the size of the parameters!" + + f"Params length: {len(params)}. " + + f"Gradients length: {len(gradients)}" + ) + has_complex = False + for param, gradient in zip(params, gradients): + if gradient is not None: + has_complex |= core.is_complex(param) + params_with_grad.append(param) + grads.append(gradient) + # Lazy state initialization + if param not in self.state: + self.state[param] = {} + state = self.state[param] + state["step"] = core.tensor(0.0) + state["square_avg"] = core.zeros_like( + param, memory_format=core.preserve_format + ) + state["acc_delta"] = core.zeros_like( + param, memory_format=core.preserve_format + ) + + state = self.state[param] + square_avgs.append(state["square_avg"]) + acc_deltas.append(state["acc_delta"]) + state_steps.append(state["step"]) + + with core.no_grad(): + F.adadelta( + params_with_grad, + grads, + square_avgs, + acc_deltas, + state_steps, + lr=lr, + rho=rho, + eps=eps, + weight_decay=weight_decay, + foreach=self.foreach, + maximize=self.maximize, + has_complex=has_complex, + ) diff --git a/mindnlp/core/distributed/optim/functional_adagrad.py b/mindnlp/core/distributed/optim/functional_adagrad.py new file mode 100644 index 000000000..80d7d26ae --- /dev/null +++ b/mindnlp/core/distributed/optim/functional_adagrad.py @@ -0,0 +1,115 @@ +# mypy: allow-untyped-defs +from typing import Dict, List, Optional + +from mindnlp import core +from mindnlp import core.optim._functional as F +from mindnlp.core import Tensor +from core.distributed.optim._deprecation_warning import ( + _scripted_functional_optimizer_deprecation_warning, +) + + +__all__: List[str] = [] + + +# Define a TorchScript compatible Functional Adagrad Optimizer +# where we use these optimizer in a functional way. +# Instead of using the `param.grad` when updating parameters, +# we explicitly let the user pass gradients to the `step` function +# this is so that we could separate the gradients and parameters +# and allow multithreaded trainer to update the parameters +# without data traces on accumulating to the same .grad. +# NOTE: This should be only used by distributed optimizer internals +# and not meant to expose to the user. +@core.jit.script +class _FunctionalAdagrad: + def __init__( + self, + params: List[Tensor], + lr: float = 1e-2, + lr_decay: float = 0.0, + weight_decay: float = 0.0, + initial_accumulator_value: float = 0.0, + warmup_lr_multiplier: float = 1.0, + warmup_num_iters: float = 0.0, + eps: float = 1e-10, + coalesce_grad: bool = True, + foreach: bool = False, + fused: bool = False, + maximize: bool = False, + _allow_empty_param_list: bool = False, + ): + _scripted_functional_optimizer_deprecation_warning(stacklevel=2) + self.defaults = { + "lr": lr, + "lr_decay": lr_decay, + "eps": eps, + "weight_decay": weight_decay, + "initial_accumulator_value": initial_accumulator_value, + "warmup_lr_multiplier": warmup_lr_multiplier, + "warmup_num_iters": warmup_num_iters, + } + self.coalesce_grad = coalesce_grad + self.foreach = foreach + self.fused = fused + self.maximize = maximize + self.state = core.jit.annotate(Dict[core.Tensor, Dict[str, core.Tensor]], {}) + + if len(params) == 0 and not _allow_empty_param_list: + raise ValueError("optimizer got an empty parameter list") + + # NOTE: we only have one param_group and don't allow user to add additional + # param group as it's not a common use case. + self.param_group = {"params": params} + + # TODO: no union or any types in TorchScript, make step a scalar tensor instead + # This is also needed by if we want to share_memory on the step across processes + for p in self.param_group["params"]: + self.state[p] = { + "sum": core.full_like(p.data, initial_accumulator_value), + "step": core.tensor(0.0), + } + + def step(self, gradients: List[Optional[Tensor]]): + params = self.param_group["params"] + params_with_grad = [] + grads = [] + state_sums = [] + state_steps: List[Tensor] = [] + + if len(params) != len(gradients): + raise ValueError( + "the gradients passed in does not equal to the size of the parameters!" + + f"Params length: {len(params)}. " + + f"Gradients length: {len(gradients)}" + ) + + has_sparse_grad, has_complex = False, False + for param, gradient in zip(self.param_group["params"], gradients): + if gradient is not None: + has_sparse_grad |= gradient.is_sparse + has_complex |= core.is_complex(param) + params_with_grad.append(param) + grads.append(gradient) + state = self.state[param] + state_sums.append(state["sum"]) + state_steps.append(state["step"]) + + with core.no_grad(): + F.adagrad( + params, + grads, + state_sums, + state_steps, + lr=self.defaults["lr"], + weight_decay=self.defaults["weight_decay"], + lr_decay=self.defaults["lr_decay"], + eps=self.defaults["eps"], + has_sparse_grad=has_sparse_grad, + foreach=self.foreach, + maximize=self.maximize, + has_complex=has_complex, + fused=self.fused, + grad_scale=None, + found_inf=None, + ) diff --git a/mindnlp/core/distributed/optim/functional_adam.py b/mindnlp/core/distributed/optim/functional_adam.py new file mode 100644 index 000000000..65467833d --- /dev/null +++ b/mindnlp/core/distributed/optim/functional_adam.py @@ -0,0 +1,202 @@ +# mypy: allow-untyped-defs +from typing import Dict, List, Optional, Tuple + +from mindnlp import core +from mindnlp import core.optim._functional as F +from mindnlp.core import Tensor +from core.distributed.optim._deprecation_warning import ( + _scripted_functional_optimizer_deprecation_warning, +) + + +__all__: List[str] = [] + + +# Define a TorchScript compatible Functional Adam Optimizer +# where we use these optimizer in a functional way. +# Instead of using the `param.grad` when updating parameters, +# we explicitly allow the distributed optimizer pass gradients to +# the `step` function. In this way, we could separate the gradients +# and parameters and allow multithreaded trainer to update the +# parameters without data traces on accumulating to the same .grad. +# NOTE: This should be only used by distributed optimizer internals +# and not meant to expose to the user. +@core.jit.script +class _FunctionalAdam: + def __init__( + self, + params: List[Tensor], + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.0, + amsgrad: bool = False, + maximize: bool = False, + foreach: bool = False, + fused: bool = False, + _allow_empty_param_list: bool = False, + ): + _scripted_functional_optimizer_deprecation_warning(stacklevel=2) + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + self.defaults = { + "lr": lr, + "eps": eps, + "beta1": betas[0], + "beta2": betas[1], + "weight_decay": weight_decay, + } + self.amsgrad = amsgrad + self.maximize = maximize + self.foreach = foreach + self.fused = fused + self.state = core.jit.annotate(Dict[core.Tensor, Dict[str, core.Tensor]], {}) + + if len(params) == 0 and not _allow_empty_param_list: + raise ValueError("optimizer got an empty parameter list") + + # NOTE: we only have one param_group and don't allow user to add additional + # param group as it's not a common use case. + self.param_group = {"params": params} + + def step_param(self, param: Tensor, grad: Optional[Tensor]): + """ + Similar to step, but operates on a single parameter and optionally a + gradient tensor. + """ + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps: List[Tensor] = [] + has_complex = core.is_complex(param) + if grad is not None: + params_with_grad.append(param) + grads.append(grad) + if param not in self.state: + self.state[param] = {} + state = self.state[param] + state["step"] = core.tensor(0.0) + state["exp_avg"] = core.zeros_like( + param, memory_format=core.preserve_format + ) + state["exp_avg_sq"] = core.zeros_like( + param, memory_format=core.preserve_format + ) + if self.amsgrad: + state["max_exp_avg_sq"] = core.zeros_like( + param, memory_format=core.preserve_format + ) + + state = self.state[param] + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + + if self.amsgrad: + max_exp_avg_sqs.append(state["max_exp_avg_sq"]) + + state_steps.append(state["step"]) + with core.no_grad(): + F.adam( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=self.amsgrad, + has_complex=has_complex, + maximize=self.maximize, + beta1=self.defaults["beta1"], + beta2=self.defaults["beta2"], + lr=self.defaults["lr"], + weight_decay=self.defaults["weight_decay"], + eps=self.defaults["eps"], + foreach=self.foreach, + fused=self.fused, + grad_scale=None, + found_inf=None, + ) + + def step(self, gradients: List[Optional[Tensor]]): + params = self.param_group["params"] + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps: List[Tensor] = [] + has_complex = False + + if len(params) != len(gradients): + raise ValueError( + "the gradients passed in does not equal to the size of the parameters!" + + f"Params length: {len(params)}. " + + f"Gradients length: {len(gradients)}" + ) + + for param, gradient in zip(self.param_group["params"], gradients): + if gradient is not None: + has_complex |= core.is_complex(param) + params_with_grad.append(param) + grads.append(gradient) + # Lazy state initialization + if param not in self.state: + self.state[param] = {} + state = self.state[param] + state["step"] = core.tensor(0.0) + # Exponential moving average of gradient values + state["exp_avg"] = core.zeros_like( + param, memory_format=core.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = core.zeros_like( + param, memory_format=core.preserve_format + ) + if self.amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state["max_exp_avg_sq"] = core.zeros_like( + param, memory_format=core.preserve_format + ) + + state = self.state[param] + + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + + if self.amsgrad: + max_exp_avg_sqs.append(state["max_exp_avg_sq"]) + + state_steps.append(state["step"]) + + with core.no_grad(): + F.adam( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=self.amsgrad, + has_complex=has_complex, + maximize=self.maximize, + beta1=self.defaults["beta1"], + beta2=self.defaults["beta2"], + lr=self.defaults["lr"], + weight_decay=self.defaults["weight_decay"], + eps=self.defaults["eps"], + foreach=self.foreach, + fused=self.fused, + grad_scale=None, + found_inf=None, + ) diff --git a/mindnlp/core/distributed/optim/functional_adamax.py b/mindnlp/core/distributed/optim/functional_adamax.py new file mode 100644 index 000000000..4b645c4a3 --- /dev/null +++ b/mindnlp/core/distributed/optim/functional_adamax.py @@ -0,0 +1,123 @@ +# mypy: allow-untyped-defs +from typing import Dict, List, Optional, Tuple + +from mindnlp import core +from mindnlp import core.optim._functional as F +from mindnlp.core import Tensor +from core.distributed.optim._deprecation_warning import ( + _scripted_functional_optimizer_deprecation_warning, +) + + +__all__: List[str] = [] + + +# Define a TorchScript compatible Functional Adamax Optimizer +# where we use these optimizer in a functional way. +# Instead of using the `param.grad` when updating parameters, +# we explicitly allow the distributed optimizer pass gradients to +# the `step` function. In this way, we could separate the gradients +# and parameters and allow multithreaded trainer to update the +# parameters without data traces on accumulating to the same .grad. +# NOTE: This should be only used by distributed optimizer internals +# and not meant to expose to the user. +@core.jit.script +class _FunctionalAdamax: + def __init__( + self, + params: List[Tensor], + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.0, + foreach: bool = False, + maximize: bool = False, + _allow_empty_param_list: bool = False, + ): + _scripted_functional_optimizer_deprecation_warning(stacklevel=2) + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + self.defaults = { + "lr": lr, + "eps": eps, + "beta1": betas[0], + "beta2": betas[1], + "weight_decay": weight_decay, + } + self.foreach = foreach + self.maximize = maximize + self.state = core.jit.annotate(Dict[core.Tensor, Dict[str, core.Tensor]], {}) + + if len(params) == 0 and not _allow_empty_param_list: + raise ValueError("optimizer got an empty parameter list") + + # NOTE: we only have one param_group and don't allow user to add additional + # param group as it's not a common use case. + self.param_group = {"params": params} + + def step(self, gradients: List[Optional[Tensor]]): + params = self.param_group["params"] + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_infs = [] + state_steps: List[Tensor] = [] + + if len(params) != len(gradients): + raise ValueError( + "the gradients passed in does not equal to the size of the parameters!" + + f"Params length: {len(params)}. " + + f"Gradients length: {len(gradients)}" + ) + + has_complex = False + for param, gradient in zip(self.param_group["params"], gradients): + if gradient is not None: + has_complex |= core.is_complex(param) + params_with_grad.append(param) + grads.append(gradient) + # Lazy state initialization + if param not in self.state: + self.state[param] = {} + state = self.state[param] + state["step"] = core.tensor(0.0) + # Exponential moving average of gradient values + state["exp_avg"] = core.zeros_like( + param, memory_format=core.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_inf"] = core.zeros_like( + param, memory_format=core.preserve_format + ) + + state = self.state[param] + + exp_avgs.append(state["exp_avg"]) + exp_infs.append(state["exp_inf"]) + state_steps.append(state["step"]) + + with core.no_grad(): + F.adamax( + params_with_grad, + grads, + exp_avgs, + exp_infs, + state_steps, + eps=self.defaults["eps"], + beta1=self.defaults["beta1"], + beta2=self.defaults["beta2"], + lr=self.defaults["lr"], + weight_decay=self.defaults["weight_decay"], + foreach=self.foreach, + maximize=self.maximize, + has_complex=has_complex, + ) diff --git a/mindnlp/core/distributed/optim/functional_adamw.py b/mindnlp/core/distributed/optim/functional_adamw.py new file mode 100644 index 000000000..e19e8298f --- /dev/null +++ b/mindnlp/core/distributed/optim/functional_adamw.py @@ -0,0 +1,203 @@ +# mypy: allow-untyped-defs +from typing import Dict, List, Optional, Tuple + +from mindnlp import core +from mindnlp import core.optim._functional as F +from mindnlp.core import Tensor +from core.distributed.optim._deprecation_warning import ( + _scripted_functional_optimizer_deprecation_warning, +) + + +__all__: List[str] = [] + + +# Define a TorchScript compatible Functional AdamW Optimizer +# where we use these optimizer in a functional way. +# Instead of using the `param.grad` when updating parameters, +# we explicitly allow the distributed optimizer pass gradients to +# the `step` function. In this way, we could separate the gradients +# and parameters and allow multithreaded trainer to update the +# parameters without data traces on accumulating to the same .grad. +# NOTE: This should be only used by distributed optimizer internals +# and not meant to expose to the user. +@core.jit.script +class _FunctionalAdamW: + def __init__( + self, + params: List[Tensor], + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 1e-2, + amsgrad: bool = False, + maximize: bool = False, + foreach: bool = False, + fused: bool = False, + _allow_empty_param_list: bool = False, + ): + _scripted_functional_optimizer_deprecation_warning(stacklevel=2) + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + self.defaults = { + "lr": lr, + "eps": eps, + "beta1": betas[0], + "beta2": betas[1], + "weight_decay": weight_decay, + } + self.amsgrad = amsgrad + self.maximize = maximize + self.foreach = foreach + self.fused = fused + self.state = core.jit.annotate(Dict[core.Tensor, Dict[str, core.Tensor]], {}) + + if len(params) == 0 and not _allow_empty_param_list: + raise ValueError("optimizer got an empty parameter list") + + # NOTE: we only have one param_group and don't allow user to add additional + # param group as it's not a common use case. + self.param_group = {"params": params} + + def step_param(self, param: Tensor, grad: Optional[Tensor]): + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps: List[Tensor] = [] + has_complex = core.is_complex(param) + if grad is not None: + params_with_grad.append(param) + grads.append(grad) + # Lazy state initialization + if param not in self.state: + self.state[param] = {} + state = self.state[param] + state["step"] = core.tensor(0.0) + # Exponential moving average of gradient values + state["exp_avg"] = core.zeros_like( + param, memory_format=core.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = core.zeros_like( + param, memory_format=core.preserve_format + ) + if self.amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state["max_exp_avg_sq"] = core.zeros_like( + param, memory_format=core.preserve_format + ) + + state = self.state[param] + + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + + if self.amsgrad: + max_exp_avg_sqs.append(state["max_exp_avg_sq"]) + + state_steps.append(state["step"]) + with core.no_grad(): + F.adamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=self.amsgrad, + maximize=self.maximize, + beta1=self.defaults["beta1"], + beta2=self.defaults["beta2"], + lr=self.defaults["lr"], + weight_decay=self.defaults["weight_decay"], + eps=self.defaults["eps"], + foreach=self.foreach, + fused=self.fused, + grad_scale=None, + found_inf=None, + has_complex=has_complex, + ) + + def step(self, gradients: List[Optional[Tensor]]): + params = self.param_group["params"] + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps: List[Tensor] = [] + + if len(params) != len(gradients): + raise ValueError( + "the gradients passed in does not equal to the size of the parameters!" + + f"Params length: {len(params)}. " + + f"Gradients length: {len(gradients)}" + ) + + has_complex = False + for param, gradient in zip(self.param_group["params"], gradients): + if gradient is not None: + has_complex |= core.is_complex(param) + params_with_grad.append(param) + grads.append(gradient) + # Lazy state initialization + if param not in self.state: + self.state[param] = {} + state = self.state[param] + state["step"] = core.tensor(0.0) + # Exponential moving average of gradient values + state["exp_avg"] = core.zeros_like( + param, memory_format=core.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = core.zeros_like( + param, memory_format=core.preserve_format + ) + if self.amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state["max_exp_avg_sq"] = core.zeros_like( + param, memory_format=core.preserve_format + ) + + state = self.state[param] + + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + + if self.amsgrad: + max_exp_avg_sqs.append(state["max_exp_avg_sq"]) + + state_steps.append(state["step"]) + + with core.no_grad(): + F.adamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=self.amsgrad, + maximize=self.maximize, + beta1=self.defaults["beta1"], + beta2=self.defaults["beta2"], + lr=self.defaults["lr"], + weight_decay=self.defaults["weight_decay"], + eps=self.defaults["eps"], + foreach=self.foreach, + fused=self.fused, + grad_scale=None, + found_inf=None, + has_complex=has_complex, + ) diff --git a/mindnlp/core/distributed/optim/functional_rmsprop.py b/mindnlp/core/distributed/optim/functional_rmsprop.py new file mode 100644 index 000000000..45d5dc697 --- /dev/null +++ b/mindnlp/core/distributed/optim/functional_rmsprop.py @@ -0,0 +1,130 @@ +# mypy: allow-untyped-defs +from typing import Dict, List, Optional + +from mindnlp import core +from mindnlp import core.optim._functional as F +from mindnlp.core import Tensor +from core.distributed.optim._deprecation_warning import ( + _scripted_functional_optimizer_deprecation_warning, +) + + +__all__: List[str] = [] + + +# Define a TorchScript compatible Functional RMSprop Optimizer +# where we use these optimizer in a functional way. +# Instead of using the `param.grad` when updating parameters, +# we explicitly allow the distributed optimizer pass gradients to +# the `step` function. In this way, we could separate the gradients +# and parameters and allow multithreaded trainer to update the +# parameters without data traces on accumulating to the same .grad. +# NOTE: This should be only used by distributed optimizer internals +# and not meant to expose to the user. +@core.jit.script +class _FunctionalRMSprop: + def __init__( + self, + params: List[Tensor], + lr: float = 1e-2, + alpha: float = 0.99, + eps: float = 1e-8, + weight_decay: float = 0.0, + momentum: float = 0.0, + centered: bool = False, + foreach: bool = False, + maximize: bool = False, + _allow_empty_param_list: bool = False, + ): + _scripted_functional_optimizer_deprecation_warning(stacklevel=2) + self.defaults = { + "lr": lr, + "alpha": alpha, + "eps": eps, + "weight_decay": weight_decay, + "momentum": momentum, + } + self.centered = centered + self.foreach = foreach + self.maximize = maximize + + if len(params) == 0 and not _allow_empty_param_list: + raise ValueError("optimizer got an empty parameter list") + + # NOTE: we only have one param_group and don't allow user to add additional + # param group as it's not a common use case. + self.param_group = {"params": params} + + self.state = core.jit.annotate(Dict[core.Tensor, Dict[str, core.Tensor]], {}) + + def step(self, gradients: List[Optional[Tensor]]): + params = self.param_group["params"] + params_with_grad = [] + grads = [] + square_avgs = [] + grad_avgs = [] + momentum_buffer_list = [] + state_steps = [] + lr = self.defaults["lr"] + alpha = self.defaults["alpha"] + eps = self.defaults["eps"] + momentum = self.defaults["momentum"] + weight_decay = self.defaults["weight_decay"] + + if len(params) != len(gradients): + raise ValueError( + "the gradients passed in does not equal to the size of the parameters!" + + f"Params length: {len(params)}. " + + f"Gradients length: {len(gradients)}" + ) + + has_complex = False + for param, gradient in zip(params, gradients): + if gradient is not None: + has_complex |= core.is_complex(param) + params_with_grad.append(param) + grads.append(gradient) + # Lazy state initialization + if param not in self.state: + self.state[param] = {} + state = self.state[param] + state["step"] = core.tensor(0.0) + state["square_avg"] = core.zeros_like( + param, memory_format=core.preserve_format + ) + if momentum > 0: + state["momentum_buffer"] = core.zeros_like( + param, memory_format=core.preserve_format + ) + if self.centered: + state["grad_avg"] = core.zeros_like( + param, memory_format=core.preserve_format + ) + + state = self.state[param] + square_avgs.append(state["square_avg"]) + if momentum > 0: + momentum_buffer_list.append(state["momentum_buffer"]) + if self.centered: + grad_avgs.append(state["grad_avg"]) + + state_steps.append(state["step"]) + + with core.no_grad(): + F.rmsprop( + params_with_grad, + grads, + square_avgs, + grad_avgs, + momentum_buffer_list, + state_steps, + lr=lr, + alpha=alpha, + eps=eps, + weight_decay=weight_decay, + momentum=momentum, + centered=self.centered, + foreach=self.foreach, + maximize=self.maximize, + has_complex=has_complex, + ) diff --git a/mindnlp/core/distributed/optim/functional_rprop.py b/mindnlp/core/distributed/optim/functional_rprop.py new file mode 100644 index 000000000..eaa3daa0a --- /dev/null +++ b/mindnlp/core/distributed/optim/functional_rprop.py @@ -0,0 +1,107 @@ +# mypy: allow-untyped-defs +from typing import Dict, List, Optional, Tuple + +from mindnlp import core +from mindnlp import core.optim._functional as F +from mindnlp.core import Tensor +from core.distributed.optim._deprecation_warning import ( + _scripted_functional_optimizer_deprecation_warning, +) + + +__all__: List[str] = [] + + +# Define a TorchScript compatible Functional Rprop Optimizer +# where we use these optimizer in a functional way. +# Instead of using the `param.grad` when updating parameters, +# we explicitly allow the distributed optimizer pass gradients to +# the `step` function. In this way, we could separate the gradients +# and parameters and allow multithreaded trainer to update the +# parameters without data traces on accumulating to the same .grad. +# NOTE: This should be only used by distributed optimizer internals +# and not meant to expose to the user. +@core.jit.script +class _FunctionalRprop: + def __init__( + self, + params: List[Tensor], + lr: float = 1e-2, + etas: Tuple[float, float] = (0.5, 1.2), + step_sizes: Tuple[float, float] = (1e-6, 50), + foreach: bool = False, + maximize: bool = False, + _allow_empty_param_list: bool = False, + ): + _scripted_functional_optimizer_deprecation_warning(stacklevel=2) + self.defaults = { + "lr": lr, + } + self.etas = etas + self.step_sizes = step_sizes + self.foreach = foreach + self.maximize = maximize + + if len(params) == 0 and not _allow_empty_param_list: + raise ValueError("optimizer got an empty parameter list") + + # NOTE: we only have one param_group and don't allow user to add additional + # param group as it's not a common use case. + self.param_group = {"params": params} + + self.state = core.jit.annotate(Dict[core.Tensor, Dict[str, core.Tensor]], {}) + + def step(self, gradients: List[Optional[Tensor]]): + params = self.param_group["params"] + params_with_grad = [] + grads = [] + prevs = [] + step_sizes = [] + state_steps = [] + lr = self.defaults["lr"] + etaminus, etaplus = self.etas + step_size_min, step_size_max = self.step_sizes + + if len(params) != len(gradients): + raise ValueError( + "the gradients passed in does not equal to the size of the parameters!" + + f"Params length: {len(params)}. " + + f"Gradients length: {len(gradients)}" + ) + + has_complex = False + for param, gradient in zip(params, gradients): + if gradient is not None: + has_complex |= core.is_complex(param) + params_with_grad.append(param) + grads.append(gradient) + # Lazy state initialization + if param not in self.state: + self.state[param] = {} + state = self.state[param] + state["step"] = core.tensor(0.0) + state["prev"] = core.zeros_like( + param, memory_format=core.preserve_format + ) + state["step_size"] = core.full_like(gradient, lr) + + state = self.state[param] + prevs.append(state["prev"]) + step_sizes.append(state["step_size"]) + state_steps.append(state["step"]) + + with core.no_grad(): + F.rprop( + params_with_grad, + grads, + prevs, + step_sizes, + state_steps, + step_size_min=step_size_min, + step_size_max=step_size_max, + etaminus=etaminus, + etaplus=etaplus, + foreach=self.foreach, + maximize=self.maximize, + has_complex=has_complex, + ) diff --git a/mindnlp/core/distributed/optim/functional_sgd.py b/mindnlp/core/distributed/optim/functional_sgd.py new file mode 100644 index 000000000..b094abf81 --- /dev/null +++ b/mindnlp/core/distributed/optim/functional_sgd.py @@ -0,0 +1,166 @@ +# mypy: allow-untyped-defs +from typing import Dict, List, Optional + +from mindnlp import core +from mindnlp import core.optim._functional as F +from mindnlp.core import Tensor +from core.distributed.optim._deprecation_warning import ( + _scripted_functional_optimizer_deprecation_warning, +) + + +__all__: List[str] = [] + + +# Define a TorchScript compatible Functional SGD Optimizer +# where we use these optimizer in a functional way. +# Instead of using the `param.grad` when updating parameters, +# we explicitly allow the distributed optimizer pass gradients to +# the `step` function. In this way, we could separate the gradients +# and parameters and allow multithreaded trainer to update the +# parameters without data traces on accumulating to the same .grad. +# NOTE: This should be only used by distributed optimizer internals +# and not meant to expose to the user. +@core.jit.script +class _FunctionalSGD: + def __init__( + self, + params: List[Tensor], + lr: float = 1e-2, + momentum: float = 0.0, + dampening: float = 0.0, + weight_decay: float = 0.0, + nesterov: bool = False, + maximize: bool = False, + foreach: bool = False, + fused: bool = False, + _allow_empty_param_list: bool = False, + ): + _scripted_functional_optimizer_deprecation_warning(stacklevel=2) + self.defaults = { + "lr": lr, + "momentum": momentum, + "dampening": dampening, + "weight_decay": weight_decay, + } + self.nesterov = nesterov + self.maximize = maximize + self.foreach = foreach + self.fused = fused + self.state = core.jit.annotate(Dict[core.Tensor, Dict[str, core.Tensor]], {}) + + if len(params) == 0 and not _allow_empty_param_list: + raise ValueError("optimizer got an empty parameter list") + + # NOTE: we only have one param_group and don't allow user to add additional + # param group as it's not a common use case. + self.param_group = {"params": params} + + def step_param(self, param: Tensor, grad: Optional[Tensor]): + """Similar to self.step, but operates on a single parameter and + its gradient. + """ + # TODO: Once step_param interface is robust, refactor step to call + # step param on each param. + weight_decay = self.defaults["weight_decay"] + momentum = self.defaults["momentum"] + dampening = self.defaults["dampening"] + lr = self.defaults["lr"] + params = [param] + momentum_buffer_list: List[Optional[Tensor]] = [] + grads = [] + + has_sparse_grad = False + if grad is not None: + grads.append(grad) + if grad.is_sparse: + has_sparse_grad = True + if param not in self.state: + self.state[param] = {} + state = self.state[param] + if "momentum_buffer" not in state: + momentum_buffer_list.append(None) + else: + momentum_buffer_list.append(state["momentum_buffer"]) + + with core.no_grad(): + F.sgd( + params, + grads, + momentum_buffer_list, + weight_decay=weight_decay, + momentum=momentum, + lr=lr, + dampening=dampening, + nesterov=self.nesterov, + maximize=self.maximize, + has_sparse_grad=has_sparse_grad, + foreach=self.foreach, + fused=self.fused, + grad_scale=None, + found_inf=None, + ) + # update momentum_buffer in state + state = self.state[param] + momentum_buffer = momentum_buffer_list[0] + if momentum_buffer is not None: + state["momentum_buffer"] = momentum_buffer + + def step(self, gradients: List[Optional[Tensor]]): + params = self.param_group["params"] + params_with_grad = [] + grads = [] + momentum_buffer_list: List[Optional[Tensor]] = [] + lr = self.defaults["lr"] + weight_decay = self.defaults["weight_decay"] + momentum = self.defaults["momentum"] + dampening = self.defaults["dampening"] + + if len(params) != len(gradients): + raise ValueError( + "the gradients passed in does not equal to the size of the parameters!" + + f"Params length: {len(params)}. " + + f"Gradients length: {len(gradients)}" + ) + + has_sparse_grad = False + for param, gradient in zip(params, gradients): + if gradient is not None: + params_with_grad.append(param) + grads.append(gradient) + if gradient.is_sparse: + has_sparse_grad = True + + if param not in self.state: + self.state[param] = {} + + state = self.state[param] + if "momentum_buffer" not in state: + momentum_buffer_list.append(None) + else: + momentum_buffer_list.append(state["momentum_buffer"]) + + with core.no_grad(): + F.sgd( + params_with_grad, + grads, + momentum_buffer_list, + weight_decay=weight_decay, + momentum=momentum, + lr=lr, + dampening=dampening, + nesterov=self.nesterov, + maximize=self.maximize, + has_sparse_grad=has_sparse_grad, + foreach=self.foreach, + fused=self.fused, + grad_scale=None, + found_inf=None, + ) + + # update momentum_buffers in state + for i, p in enumerate(params_with_grad): + state = self.state[p] + momentum_buffer = momentum_buffer_list[i] + if momentum_buffer is not None: + state["momentum_buffer"] = momentum_buffer diff --git a/mindnlp/core/distributed/optim/named_optimizer.py b/mindnlp/core/distributed/optim/named_optimizer.py new file mode 100644 index 000000000..0f3b9fccd --- /dev/null +++ b/mindnlp/core/distributed/optim/named_optimizer.py @@ -0,0 +1,339 @@ +# mypy: allow-untyped-defs +import logging +import warnings +from copy import deepcopy +from typing import ( + Any, + Callable, + Collection, + Dict, + List, + Mapping, + Optional, + overload, + Union, +) + +from mindnlp import core +from mindnlp import core.nn as nn +from mindnlp.core import optim +from core.distributed._shard.sharded_tensor import ShardedTensor +from core.distributed.fsdp import FullyShardedDataParallel as FSDP + + +__all__: List[str] = [] + +logger = logging.getLogger(__name__) + + +class _NamedOptimizer(optim.Optimizer): + """ + ``_NamedOptimizer`` takes a dict of parameters and exposes ``state_dict`` by parameter key. + + We replace the original key (number) in an optim to the + fully qualified name (FQN) string. User can initialize the optim as they + initialize a PyTorch optim, the only difference is that they also need to + pass in the FQN of each parameters. + + Args: + named_parameters (Mapping[str, Union[core.Tensor, ShardedTensor]]): + Mapping from FQN to parameter. + optimizer_class (optim.Optimizer): + The class of optimizer to instantiate. + param_groups (Collection[Mapping[str, Any]]): + `param_groups` to pass to optimizer if specified. + The key of the inner map needs to be FQNs. + Default: None + module (nn.Module): the module whose parameters to updated + by the optimizer. + args: arguments to pass to the optimizer constructor. + kwargs: arguments to pass to the optimizer constructor. + + Example:: + >>> # xdoctest: +SKIP("distributed") + >>> from mindnlp.core import optim + >>> from core.distributed.optim import _NamedOptimizer + >>> + >>> # Define the named optimizer. + >>> m = Model(...) + >>> named_optim = _NamedOptimizer(m.named_parameters(), optim.SGD) + >>> # Forward pass + backward pass. + >>> named_optim.step() + >>> ... + >>> # Call state_dict for the named optimizer returns a FQN state_dict. + >>> named_optim.state_dict() + + Warning: This API is still in development and subject to change. + + TODO: Add tutorial for _NamedOptimizer. + TODO: Add documentation in the docstring for the public attributes + like self.param_groups and self.named_parameters. + """ + + def __init__( + self, + named_parameters: Mapping[str, Union[core.Tensor, ShardedTensor]], + optimizer_class: optim.Optimizer, + param_groups: Optional[Collection[Mapping[str, Any]]] = None, + module: Optional[nn.Module] = None, + *args, + **kwargs, + ) -> None: + core._C._log_api_usage_once("core.distributed.optim._NamedOptimizer") + self.param_groups: Collection[Mapping[str, Any]] = param_groups # type: ignore[assignment] + self._param_groups_check() + self.named_parameters = dict(named_parameters) + params_for_optimizer = ( + self.named_parameters.values() if param_groups is None else param_groups + ) + self._optimizer = optimizer_class( # type: ignore[operator] + params_for_optimizer, + *args, + **kwargs, + ) + self.module = module + if param_groups is None: + self.ordered_param_keys = list(self.named_parameters.keys()) + else: + warnings.warn( + "Since we pass in param_groups, we will use param_groups to " + "initialize the optimizer, not all parameters of the module." + ) + param_to_key = {param: key for key, param in self.named_parameters.items()} # type: ignore[misc, has-type] + ordered_param_keys = [] + for group in param_groups: + for param in group["params"]: + if param not in param_to_key: + raise ValueError( + f"Expect param name {param} found in param group but is missing." + ) + ordered_param_keys.append(param_to_key[param]) + self.ordered_param_keys = ordered_param_keys + # Update param_groups from optimizer. + self.param_groups = self._optimizer.param_groups + + def _param_groups_check(self): + if self.param_groups is not None: + for param_group in self.param_groups: + assert isinstance(param_group, dict), "param group must be a dict" + assert "params" in param_group, "param group must contain key params" + params = param_group["params"] + if isinstance(params, core.Tensor): + params = [params] + params = list(params) + for param in params: + if not isinstance(param, core.Tensor): + raise TypeError( + "optimizer can only optimize Tensors, " + "but one of the params is " + core.typename(param) + ) + param_group["params"] = params + + def state_dict(self) -> Dict[str, Any]: + """ + Return the ``state_dict`` of the optimizer. + + Instead of using number to index + parameters, we will use module fully qualified name (FQN) as the key. + """ + state_dict = self._optimizer.state_dict() + param_groups = state_dict["param_groups"] + + ret_state = { + self.ordered_param_keys[st_key]: state_val + for st_key, state_val in state_dict["state"].items() + } + + ret_groups = [] + for group in param_groups: + param_keys = [self.ordered_param_keys[param] for param in group["params"]] + ret_group = {"params": sorted(param_keys)} + for k, v in group.items(): + if k != "params": + ret_group[k] = deepcopy(v) + ret_groups.append(ret_group) + + return self._post_state_dict({"state": ret_state, "param_groups": ret_groups}) + + @overload + def step(self, closure: None = ...) -> None: + ... + + @overload + def step(self, closure: Callable[[], float]) -> float: + ... + + def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: + """ + Perform a single optimization step. + + This will call :meth:`core.optim.Optimizer.step` on the wrapped + optimizer. + """ + return self._optimizer.step(closure=closure) + + @property + def state(self) -> Mapping[core.Tensor, Any]: # type: ignore[override] + return self._optimizer.state + + def load_state_dict(self, state_dict: Mapping[str, Any]) -> None: + """ + Define the default behavior to load a state_dict for ``_NamedOptimizer``. + + Sample Code + ``` + my_model = MyModule() + optimizer = _NamedOptimizer(my_model.named_parameters(), Adagrad) + ... + + optim_state_dict = optimizer.state_dict() + ... + ... + + optimizer.load_state_dict(optim_state_dict) + ... + ``` + Args: + state_dict (Dict[str, Any]) : A ``state_dict`` to load into the optimizer. + Note that this state dict update is performed in place. + + .. note:: PyTorch is using lazy init to initialize the optim states. + So it is possible that there is no optim state when user call + ``load_state_dict`` and for ``_NamedOptimizer`` we make it stricter + that users can only call ``load_state_dict`` after the state is initialized. + By doing this, we can validate the optim ``state_dict`` to be loaded. + """ + new_state_dict = self._optimizer.state_dict() + state_dict = self._pre_load_state_dict(state_dict) + state = state_dict["state"] + new_state = new_state_dict["state"] + if len(new_state) == 0: + raise ValueError( + "Expects the optim to be initialized before load but found not initialized." + ) + + for idx, param_key in enumerate(self.ordered_param_keys): + # When the conditional training is performed, not all parameters are updated in the optim. + if param_key not in state.keys(): + continue + if len(state[param_key]) != len(new_state[idx]): + raise ValueError( + f"Expects equal length as {len(new_state[idx])} for parameter {param_key} but found: {len(state[param_key])}" + ) + # Iterate through all optimizer states. + for state_key, state_val in new_state[idx].items(): + if state_key not in state[param_key]: + raise ValueError( + f"Expects state {state_key} for parameter {param_key} but not found." + ) + + src_state_val = state[param_key][state_key] + if isinstance(state_val, ShardedTensor): + assert isinstance(src_state_val, ShardedTensor) + num_shards = len(state_val.local_shards()) + num_new_shards = len(src_state_val.local_shards()) + if num_shards != num_new_shards: + raise ValueError( + f"Expects equal number of shards as {num_new_shards} but found {num_shards} for {param_key}/{state_key}" + ) + for shard, src_shard in zip( + state_val.local_shards(), src_state_val.local_shards() + ): + shard.tensor.detach().copy_(src_shard.tensor) + elif isinstance(state_val, core.Tensor): + assert isinstance(src_state_val, core.Tensor) + state_val.detach().copy_(src_state_val) + else: + new_state[idx][state_key] = deepcopy(src_state_val) + + # Load param_groups of state_dict + src_param_groups = state_dict["param_groups"] + new_param_groups = new_state_dict["param_groups"] + + src_group_map = {} + for group in src_param_groups: + param_keys = list(group["params"]) + src_group_map[_gen_param_group_key(param_keys)] = group + new_group_map = {} + for new_group in new_param_groups: + param_keys = [] + for param_key in new_group["params"]: + param_keys.append(self.ordered_param_keys[param_key]) # type: ignore[call-overload] + new_group_map[_gen_param_group_key(param_keys)] = new_group + for group_key, new_group in new_group_map.items(): + # When not all parameters are used in training or receive gradient, aka., not all parameters + # would be in the param_group. Thus we skip the group_key here. + if group_key not in src_group_map: + continue + src_group = src_group_map[group_key] + if len(src_group) != len(new_group): + raise ValueError( + f"Expects equal param_group size as {len(new_group)} for group {group_key} but found {len(src_group)}." + ) + for k in src_group: + if k not in new_group: + raise ValueError( + f"Expects group key {k} to be in group {group_key} in `state_dict` but is missing." + ) + if k != "params": + new_group[k] = deepcopy(src_group[k]) + + self._optimizer.load_state_dict(new_state_dict) + + def add_param_group(self, param_group: Mapping[str, Any]) -> None: + """ + Add a param group to the :class:`_NamedOptimizer` s `param_groups`. + + Warning: This API is still in development and subject to change. + """ + assert isinstance(param_group, dict), "param group must be a dict" + + params = param_group["params"] + if isinstance(params, core.Tensor): + param_group["params"] = [params] + else: + param_group["params"] = list(params) + + param_to_key = {param: key for key, param in self.named_parameters.items()} # type: ignore[misc, has-type] + for param in param_group["params"]: + if param not in param_to_key: + raise ValueError("some parameters are not in the module") + self.ordered_param_keys.append(param_to_key[param]) + + self._optimizer.add_param_group(param_group) + # Update param_groups from optimizer. + self.param_groups = self._optimizer.param_groups + + def init_state(self) -> None: + """ + Run a dummy optimizer step, which allows to initialize optimizer state because we do lazy init for most optimizers. + + This allows doing in-place loading of optimizer state from a checkpoint. + """ + for param in self.named_parameters.values(): + if param.requires_grad: + t = core.zeros_like(param) + param.grad = core.autograd.Variable(t) + # Calling ``step`` will load the initial state for optimizer states. + self.step(closure=None) + + def _pre_load_state_dict(self, state_dict) -> Dict[str, Any]: + # TODO(chienchin): This API should be FSDP agnostic and should support + # general user hooks. + if isinstance(self.module, FSDP): + return FSDP.optim_state_dict_to_load( + self.module, self._optimizer, state_dict, is_named_optimizer=True + ) + return state_dict + + def _post_state_dict(self, state_dict) -> Dict[str, Any]: + # TODO(chienchin): This API should be FSDP agnostic and should support + # general user hooks. + if isinstance(self.module, FSDP): + FSDP.optim_state_dict(self.module, self._optimizer, state_dict) + return state_dict + + +def _gen_param_group_key(param_keys: List[str]) -> str: + """Concatenate all param keys as a unique indentifier for one param group.""" + return "/".join(sorted(param_keys)) diff --git a/mindnlp/core/distributed/optim/optimizer.py b/mindnlp/core/distributed/optim/optimizer.py new file mode 100644 index 000000000..4eaeb77e8 --- /dev/null +++ b/mindnlp/core/distributed/optim/optimizer.py @@ -0,0 +1,256 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import logging +from collections import defaultdict +from threading import Lock +from typing import List, Optional + +from mindnlp import core +from mindnlp import core.distributed.autograd as dist_autograd +from mindnlp import core.distributed.rpc as rpc +from mindnlp import core.jit as jit +from mindnlp import core.nn as nn +from mindnlp.core import Tensor +from core.distributed.rpc import RRef + +from .utils import functional_optim_map + + +__all__ = ["DistributedOptimizer"] + +logger = logging.getLogger(__name__) + + +# XXX: we define a _ScriptModuleOptimizer here to explicitly +# compile the FunctionalOptimizer class into TorchScript +# This is because ScriptClass instance still lives in +# python unless you explicitly compile it as an attribute +# in ScriptModule or pass it to a ScriptFunction +# _ScriptLocalOptimizerInterface serves as a common +# interface type for Optimizer ScriptModules. +# +# TODO (wanchaol): remove this once we added TorchScript +# class reference semantics +@jit.interface +class _ScriptLocalOptimizerInterface: + def step(self, autograd_ctx_id: int) -> None: + pass + + +class _ScriptLocalOptimizer(nn.Module): + # TorchScript does not support multithread concurrent compiling. + # request_callback might invoke concurrent compiling, so we + # serialize the compiling with a lock + compile_lock = Lock() + + def __init__(self, optim_cls, local_params_rref, *args, **kwargs): + super().__init__() + self._local_params = [rref.local_value() for rref in local_params_rref] + self.optim = optim_cls(self._local_params, *args, **kwargs) + + @jit.export + def step(self, autograd_ctx_id: int): + all_local_grads = dist_autograd.get_gradients(autograd_ctx_id) + # apply functional optimizer step with a list of gradients + grads: List[Optional[Tensor]] = [ + all_local_grads[p] if p in all_local_grads else None + for p in self._local_params + ] + + self.optim.step(grads) + + +# TODO (wanchaol): remove/merge this with ScriptLocalOptimizer once +# we have converted all to functional optimizer in distributed.optim +class _LocalOptimizer: + # Ideally we would only need to share a lock for instances of + # _LocalOptimizer that deal with the same parameters. We are + # making a simplifying assumption here that if there is more + # than one instance of _LocalOptimizer per worker, they will + # be optimizing the same parameters (e.g. each data parallel + # trainer will create its own instance of _LocalOptimizer but + # they will all optimize the same parameters on each worker) + global_lock = Lock() + + def __init__(self, optim_cls, local_params_rref, *args, **kwargs): + self._local_params = [rref.local_value() for rref in local_params_rref] + self.optim = optim_cls(self._local_params, *args, **kwargs) + + def step(self, autograd_ctx_id): + all_local_grads = dist_autograd.get_gradients(autograd_ctx_id) + + with _LocalOptimizer.global_lock: + for param, grad in all_local_grads.items(): + param.grad = grad + self.optim.step() + + +def _new_local_optimizer(optim_cls, local_params_rref, *args, **kwargs): + return rpc.RRef(_LocalOptimizer(optim_cls, local_params_rref, *args, **kwargs)) + + +def _local_optimizer_step(local_optim_rref, autograd_ctx_id): + local_optim = local_optim_rref.local_value() + local_optim.step(autograd_ctx_id) + + +# new/step functions combined with _ScriptLocalOptimizer to provide GIL-free optimizer +def _new_script_local_optimizer(optim_cls, local_params_rref, *args, **kwargs): + optim = _ScriptLocalOptimizer(optim_cls, local_params_rref, *args, **kwargs) + + with _ScriptLocalOptimizer.compile_lock: + script_optim = jit.script(optim) + return rpc.RRef(script_optim, _ScriptLocalOptimizerInterface) + + +@jit.script +def _script_local_optimizer_step( + local_optim_rref: RRef[_ScriptLocalOptimizerInterface], autograd_ctx_id: int +) -> None: + local_optim = local_optim_rref.local_value() + local_optim.step(autograd_ctx_id) + + +def _wait_for_all(rpc_futs): + # TODO: improve error propagation + exception = None + results = [] + for fut in rpc_futs: + try: + results.append(fut.wait()) + except Exception as e: + results.append(e) + exception = e + if exception is not None: + raise exception + return results + + +class DistributedOptimizer: + """ + DistributedOptimizer takes remote references to parameters scattered + across workers and applies the given optimizer locally for each parameter. + + This class uses :meth:`~core.distributed.autograd.get_gradients` in order + to retrieve the gradients for specific parameters. + + Concurrent calls to + :meth:`~core.distributed.optim.DistributedOptimizer.step`, + either from the same or different clients, will + be serialized on each worker -- as each worker's optimizer can only work + on one set of gradients at a time. However, there is no guarantee that + the full forward-backward-optimizer sequence will execute for one client + at a time. This means that the gradients being applied may not correspond + to the latest forward pass executed on a given worker. Also, there is no + guaranteed ordering across workers. + + `DistributedOptimizer` creates the local optimizer with TorchScript enabled + by default, so that optimizer updates are not blocked by the Python Global + Interpreter Lock (GIL) in the case of multithreaded training (e.g. Distributed + Model Parallel). This feature is currently enabled for most optimizers. You + can also follow `the recipe`__ in PyTorch tutorials to enable TorchScript support + for your own custom optimizers. + + Args: + optimizer_class (optim.Optimizer): the class of optimizer to + instantiate on each worker. + params_rref (list[RRef]): list of RRefs to local or remote parameters + to optimize. + args: arguments to pass to the optimizer constructor on each worker. + kwargs: arguments to pass to the optimizer constructor on each worker. + + Example:: + >>> # xdoctest: +SKIP("distributed") + >>> from mindnlp import core.distributed.autograd as dist_autograd + >>> from mindnlp import core.distributed.rpc as rpc + >>> from mindnlp.core import optim + >>> from core.distributed.optim import DistributedOptimizer + >>> + >>> with dist_autograd.context() as context_id: + >>> # Forward pass. + >>> rref1 = rpc.remote("worker1", core.add, args=(core.ones(2), 3)) + >>> rref2 = rpc.remote("worker1", core.add, args=(core.ones(2), 1)) + >>> loss = rref1.to_here() + rref2.to_here() + >>> + >>> # Backward pass. + >>> dist_autograd.backward(context_id, [loss.sum()]) + >>> + >>> # Optimizer. + >>> dist_optim = DistributedOptimizer( + >>> optim.SGD, + >>> [rref1, rref2], + >>> lr=0.05, + >>> ) + >>> dist_optim.step(context_id) + + __ https://github.com/pytorch/tutorials/pull/1465 + """ + + def __init__(self, optimizer_class, params_rref, *args, **kwargs): + core._C._log_api_usage_once("core.distributed.optim.DistributedOptimizer") + per_worker_params_rref = defaultdict(list) + for param in params_rref: + per_worker_params_rref[param.owner()].append(param) + + if optimizer_class in functional_optim_map and jit._state._enabled: + optim_ctor = functional_optim_map.get(optimizer_class) + else: + optim_ctor = optimizer_class + self.is_functional_optim = optim_ctor != optimizer_class + + if self.is_functional_optim: + optimizer_new_func = _new_script_local_optimizer + else: + logger.warning( + "Creating the optimizer %s without TorchScript support, " + "this might result in slow computation time in multithreading environment" + "(i.e. Distributed Model Parallel training on CPU) due to the Python's " + "Global Interpreter Lock (GIL). Please file an issue if you need this " + "optimizer in TorchScript. ", + optimizer_class, + ) + optimizer_new_func = _new_local_optimizer + + remote_optim_futs = [] + for worker, param_rrefs in per_worker_params_rref.items(): + remote_optim_rref_fut = rpc.rpc_async( + worker, + optimizer_new_func, + args=(optim_ctor, param_rrefs) + args, + kwargs=kwargs, + ) + remote_optim_futs.append(remote_optim_rref_fut) + + self.remote_optimizers = _wait_for_all(remote_optim_futs) + + def step(self, context_id): + """ + Performs a single optimization step. + + This will call :meth:`core.optim.Optimizer.step` on each worker + containing parameters to be optimized, and will block until all workers + return. The provided ``context_id`` will be used to retrieve the + corresponding :class:`~core.distributed.autograd.context` that + contains the gradients that should be applied to the parameters. + + Args: + context_id: the autograd context id for which we should run the + optimizer step. + """ + dist_autograd._is_valid_context(context_id) + + optimizer_step_func = ( + _script_local_optimizer_step + if self.is_functional_optim + else _local_optimizer_step + ) + + rpc_futs = [ + rpc.rpc_async( + optimizer.owner(), + optimizer_step_func, + args=(optimizer, context_id), + ) + for optimizer in self.remote_optimizers + ] + _wait_for_all(rpc_futs) diff --git a/mindnlp/core/distributed/optim/post_localSGD_optimizer.py b/mindnlp/core/distributed/optim/post_localSGD_optimizer.py new file mode 100644 index 000000000..71ed61418 --- /dev/null +++ b/mindnlp/core/distributed/optim/post_localSGD_optimizer.py @@ -0,0 +1,110 @@ +# mypy: allow-untyped-defs +import warnings + +from mindnlp import core +from mindnlp import core.distributed.algorithms.model_averaging.averagers as averagers + + +class PostLocalSGDOptimizer(core.optim.Optimizer): + r""" + Wraps an arbitrary :class:`core.optim.Optimizer` and runs `post-local SGD `_, + This optimizer runs local optimizer at every step. + After the warm-up stage, it averages parameters periodically afer the local optimizer is applied. + + Args: + optim: The local optimizer. + averager: A model averager instance to run post-localSGD algorithm. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> from mindnlp import core + >>> from mindnlp import core.distributed as dist + >>> from mindnlp import core.distributed.algorithms.model_averaging.averagers as averagers + >>> from mindnlp import core.nn as nn + >>> from core.distributed.optim import PostLocalSGDOptimizer + >>> from core.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import ( + >>> PostLocalSGDState, + >>> post_localSGD_hook, + >>> ) + >>> + >>> model = nn.parallel.DistributedDataParallel( + >>> module, device_ids=[rank], output_device=rank + >>> ) + >>> + >>> # Register a post-localSGD communication hook. + >>> state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100) + >>> model.register_comm_hook(state, post_localSGD_hook) + >>> + >>> # Create a post-localSGD optimizer that wraps a local optimizer. + >>> # Note that ``warmup_steps`` used in ``PostLocalSGDOptimizer`` must be the same as + >>> # ``start_localSGD_iter`` used in ``PostLocalSGDState``. + >>> local_optim = core.optim.SGD(params=model.parameters(), lr=0.01) + >>> opt = PostLocalSGDOptimizer( + >>> optim=local_optim, + >>> averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100) + >>> ) + >>> + >>> # In the first 100 steps, DDP runs global gradient averaging at every step. + >>> # After 100 steps, DDP runs gradient averaging within each subgroup (intra-node by default), + >>> # and post-localSGD optimizer runs global model averaging every 4 steps after applying the local optimizer. + >>> for step in range(0, 200): + >>> opt.zero_grad() + >>> loss = loss_fn(output, labels) + >>> loss.backward() + >>> opt.step() + """ + + def __init__(self, optim: core.optim.Optimizer, averager: averagers.ModelAverager): + self.optim = optim + self.param_groups = self.optim.param_groups + self.averager = averager + + @property + def state(self): + return self.optim.state + + def __repr__(self): + return self.optim.__repr__() + + def state_dict(self): + r""" + This is the same as :class:`core.optim.Optimizer` :meth:`state_dict`, + but adds an extra entry to record model averager's step to the checkpoint + to ensure reload does not cause unnecessary warm up again. + """ + optim_state_dict = self.optim.state_dict() + optim_state_dict["step"] = self.averager.step + return optim_state_dict + + def load_state_dict(self, state_dict): + r""" + This is the same as :class:`core.optim.Optimizer` :meth:`load_state_dict`, + but also restores model averager's step value to the one + saved in the provided ``state_dict``. + + If there is no ``"step"`` entry in ``state_dict``, + it will raise a warning and initialize the model averager's step to 0. + """ + self.optim.load_state_dict(state_dict) + if "step" in state_dict: + self.averager.step = state_dict["step"] + else: + warnings.warn( + "Loaded state dict does not contain a step counter for an averager. " + "Setting step counter to 0." + ) + self.averager.step = 0 + + def step(self): # type: ignore[override] + r""" + Performs a single optimization step (parameter update). + """ + self.optim.step() + self.averager.average_parameters(params=self.param_groups) + + def zero_grad(self, set_to_none: bool = True): # type: ignore[override] + self.optim.zero_grad(set_to_none=set_to_none) + + def add_param_group(self, param_group): + self.optim.add_param_group(param_group) diff --git a/mindnlp/core/distributed/optim/utils.py b/mindnlp/core/distributed/optim/utils.py new file mode 100644 index 000000000..c78100c17 --- /dev/null +++ b/mindnlp/core/distributed/optim/utils.py @@ -0,0 +1,66 @@ +# mypy: allow-untyped-defs +from typing import Type + +from mindnlp.core import optim + +from .functional_adadelta import _FunctionalAdadelta +from .functional_adagrad import _FunctionalAdagrad +from .functional_adam import _FunctionalAdam +from .functional_adamax import _FunctionalAdamax +from .functional_adamw import _FunctionalAdamW +from .functional_rmsprop import _FunctionalRMSprop +from .functional_rprop import _FunctionalRprop +from .functional_sgd import _FunctionalSGD + + +# dict to map a user passed in optimizer_class to a functional +# optimizer class if we have already defined inside the +# distributed.optim package, this is so that we hide the +# functional optimizer to user and still provide the same API. +functional_optim_map = { + optim.Adagrad: _FunctionalAdagrad, + optim.Adam: _FunctionalAdam, + optim.AdamW: _FunctionalAdamW, + optim.SGD: _FunctionalSGD, + optim.Adadelta: _FunctionalAdadelta, + optim.RMSprop: _FunctionalRMSprop, + optim.Rprop: _FunctionalRprop, + optim.Adamax: _FunctionalAdamax, +} + + +def register_functional_optim(key, optim): + """ + Interface to insert a new functional optimizer to functional_optim_map + ``fn_optim_key`` and ``fn_optimizer`` are user defined. The optimizer and key + need not be of :class:`core.optim.Optimizer` (e.g. for custom optimizers) + Example:: + >>> # import the new functional optimizer + >>> # xdoctest: +SKIP + >>> from xyz import fn_optimizer + >>> from core.distributed.optim.utils import register_functional_optim + >>> fn_optim_key = "XYZ_optim" + >>> register_functional_optim(fn_optim_key, fn_optimizer) + """ + if key not in functional_optim_map: + functional_optim_map[key] = optim + + +def as_functional_optim(optim_cls: Type, *args, **kwargs): + try: + functional_cls = functional_optim_map[optim_cls] + except KeyError as e: + raise ValueError( + f"Optimizer {optim_cls} does not have a functional " f"counterpart!" + ) from e + + return _create_functional_optim(functional_cls, *args, **kwargs) + + +def _create_functional_optim(functional_optim_cls: Type, *args, **kwargs): + return functional_optim_cls( + [], + *args, + **kwargs, + _allow_empty_param_list=True, + ) diff --git a/mindnlp/core/distributed/optim/zero_redundancy_optimizer.py b/mindnlp/core/distributed/optim/zero_redundancy_optimizer.py new file mode 100644 index 000000000..47418e6c0 --- /dev/null +++ b/mindnlp/core/distributed/optim/zero_redundancy_optimizer.py @@ -0,0 +1,1652 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +r"""Zero Redundancy Optimizer.""" +import collections +import copy +import enum +import inspect +import io +import logging +from itertools import chain +from typing import Any, Callable, Dict, List, Optional, Set, Type, Union + +from mindnlp import core +from mindnlp import core.distributed as dist +from core.distributed.algorithms.join import Join, Joinable, JoinHook +from core.distributed.optim.utils import functional_optim_map +from core.optim import Optimizer + + +__all__ = ["ZeroRedundancyOptimizer"] + + +logger = logging.getLogger(__name__) + + +# Credits: classy_vision/generic/distributed_util.py +def _recursive_copy_to_device( + value: Any, + non_blocking: bool, + device: core.device, +) -> Any: + r""" + Recursively searches lists, tuples, dicts and copies tensors to device if possible. + + Non-tensor values are passed as-is in the result. + + .. note: These are all copies, so if there are two objects that reference + the same object, then after this call, there will be two different objects + referenced on the device. + """ + if isinstance(value, core.Tensor): + return value.to(device, non_blocking=non_blocking) + + if isinstance(value, (list, tuple)): + values = [ + _recursive_copy_to_device(val, non_blocking=non_blocking, device=device) + for val in value + ] + return values if isinstance(value, list) else tuple(values) + + if isinstance(value, collections.abc.Mapping): + return { + key: _recursive_copy_to_device( + val, non_blocking=non_blocking, device=device + ) + for key, val in value.items() + } + + return value + + +def _is_trainable(param: core.Tensor) -> bool: + r"""Return if a parameter is trainable, where trainability is equivalent to requiring a gradient.""" + return param.requires_grad + + +def _broadcast_object( + obj: Any, + src_rank: int, + group: object = dist.group.WORLD, + device: core.device = core.device("cpu"), +) -> Any: + r""" + Broadcasts an object to the given group. + + It will be sending the object if called from the source rank and receiving + the object otherwise. + + Arguments: + obj: object to broadcast; only used if called on the source rank. + src_rank (int): source rank. + group (``ProcessGroup``, optional): group used for the broadcast + (default: ``dist.group.WORLD``). + device (``core.device``, optional): device to send from or receive + to (default: ``core.device("cpu")``). + + Returns: + The broadcasted object. + """ + if dist.get_rank() == src_rank: + # Send the object + buffer = io.BytesIO() + core.save(obj, buffer) + data = bytearray(buffer.getbuffer()) + length_tensor = core.LongTensor([len(data)]).to(device) + data_send_tensor = core.ByteTensor(data).to(device) + dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False) + dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False) + else: + # Receive the object + length_tensor = core.LongTensor([0]).to(device) + dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False) + data_recv_tensor = core.empty( + [int(length_tensor.item())], dtype=core.uint8, device=device + ) + dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False) + buffer = io.BytesIO(data_recv_tensor.cpu().numpy()) + obj = core.load(buffer, map_location=device, weights_only=False) + return obj + + +class _ZeROJoinHook(JoinHook): + def __init__(self, zero): + assert isinstance(zero, ZeroRedundancyOptimizer), ( + "ZeRO join hook requires passing in a ZeroRedundancyOptimizer " + "instance as the state" + ) + self.zero = zero + super().__init__() + + def main_hook(self): + """ + Perform an optimizer step. + + This step updates the joined process's shard of + the parameters and broadcasts those parameters. + """ + self.zero.step() + + +class _DDPBucketAssignment: + r""" + Represent a :class:`DistributedDataParallel` bucket assignment. + + This means that a (possibly non-strict) subset of the parameters corresponding to + a DDP bucket assigned to a rank to update. + + Attributes: + bucket_index (int): index of the bucket determined by the DDP gradient + bucket all-reduce order. + parameters (List[core.Tensor]): model parameters in the bucket + assigned to this rank. + offset (int): offset into the :class:`GradBucket` 's :meth:`parameters` + giving the index of the first element in the passed-in + ``parameters``; this equivalently indexes into the + :class:`GradBucket` 's :meth:`gradients`. + device (core.device): device on which the parameters are stored. + tensor (core.Tensor): flattened tensor giving the data of the + parameter subset assigned to the rank. + """ + + def __init__( + self, + bucket_index: int, + parameters: List[core.Tensor], + offset: int, + ): + self.bucket_index = bucket_index + self.parameters = parameters + self.offset = offset + if len(self.parameters) == 0: + raise ValueError("Empty bucket assignment") + # DDP guarantees all parameters in the bucket have the same device + self.device: core.device = self.parameters[0].device + self.tensor: Optional[core.Tensor] = None + + +class _OverlapStatus(enum.IntEnum): + r""" + Define possible statuses that :class:`ZeroRedundancyOptimizer` can be in when overlapping with :class:`DistributedDataParallel`. + + Attributes: + ``UNINITIALIZED``: The ZeRO instance is effectively uninitialized and + is waiting for DDP to finalize its bucketing. + ``DDP_HAS_REBUILT_BUCKETS``: DDP has rebuilt its buckets, meaning that + its bucketing is finalized. The ZeRO instance can now collect the + necessary information about the DDP bucketing. + ``INITIALIZED``: The ZeRO instance is fully initialized and can now + optimize parameters. + """ + + UNINITIALIZED = 0 + DDP_HAS_REBUILT_BUCKETS = 1 + INITIALIZED = 2 + + +class _OverlapInfo: + r""" + Information needed by :class:`ZeroRedundancyOptimizer` to overlap with :class:`DistributedDataParallel`. + + Arguments: + world_size (int): world size of the process group being used. + + Attributes: + shard_buckets (bool): if ``True``, then the assignment of each + :class:`DistributedDataParallel` bucket is partitioned across + possibly multiple :class:`ZeroRedundancyOptimizer` instances (i.e. + across possibly multiple ranks) to approximate uniformity following + a threshold given by the total parameter size divided by the world + size; if ``False``, then each bucket is wholly assigned to a single + :class:`ZeroRedundancyOptimizer` instance (i.e. to a single rank); + this should be set to the value passed into the hook constructor. + status (_OverlapStatus): current status; see :class:`_OverlapStatus` + for more information. + params_per_bucket (List[List[core.Tensor]]): ``params_per_bucket[i]`` + gives the model parameters in the ``i``th bucket. + params_per_rank (List[List[core.Tensor]]): ``params_per_rank[i]`` + gives the model parameters assigned to the ``i``th rank, where the + parameters are grouped by increasing bucket indices. + offsets (Dict[int, int]): maps from bucket index to the offset in + ``self.params_per_rank[rank]`` giving the index of the first + parameter in that bucket, where ``rank`` is this process's own + rank; the keys of this :class:`dict` are the bucket indices + assigned to this rank. + num_bucket_assignments (int): total number of bucket assignments across + all ranks; this is equal to the number of + :class:`DistributedDataParallel` gradient buckets if + ``shard_buckets=False`` and possibly greater otherwise. + total_size (int, optional): total size of all buckets (i.e. sum of + ``param.numel()`` for all ``param`` across all buckets) if + ``shard_buckets=True``; otherwise, ``None``. + broadcast_handles (List[Work]): :class:`list` of async work handles for + the parameter broadcasts. + bucket_index_to_future (Dict[int, core.futures.Future]): + :class:`dict` mapping bucket index to the corresponding all-reduce + future. + bucket_index_to_bucket (Dict[int, dist.GradBucket]): :class:`dict` + mapping bucket index to the corresponding bucket. + bucket_indices_seen (List[int]): :class:`list` of the bucket indices + seen on this iteration. + """ + + def __init__(self, world_size) -> None: + self.status: _OverlapStatus = _OverlapStatus.UNINITIALIZED + self.shard_buckets: bool = False + + # Modified per bucket reconstruction + self.params_per_bucket: List[List[core.Tensor]] = [] + self.params_per_rank: List[List[core.Tensor]] = [[] for _ in range(world_size)] + self.offsets: Dict[int, int] = {} + # Group Ranks + self.assigned_ranks_per_bucket: List[Set[int]] = [] + self.num_bucket_assignments: int = 0 + self.total_size: Optional[int] = None + + # Modified per iteration + self.broadcast_handles: List[Any] = [] + self.bucket_indices_seen: List[int] = [] + # Used by `hook_with_zero_step()` + self.bucket_index_to_future: Dict[int, core.futures.Future] = {} + self.bucket_index_to_bucket: Dict[int, dist.GradBucket] = {} + + def wait_for_broadcasts(self) -> None: + r""" + Wait for all parameter broadcasts. + + This function should be called once all broadcasts have been scheduled, + meaning ``self.broadcast_handles`` is filled. This clears ``self.broadcast_handles`` + in preparation for the next iteration. + """ + assert ( + len(self.broadcast_handles) == self.num_bucket_assignments + ), f"Missing at least one broadcast handle on rank {dist.get_rank()}" + _ = [x.wait() for x in self.broadcast_handles] + self.broadcast_handles.clear() + + def clear_per_iter_info(self) -> None: + r""" + Clear the data structures that are modified per-iteration. + + This function should be called at the end of an iteration. + """ + self.bucket_indices_seen.clear() + self.bucket_index_to_future.clear() + self.bucket_index_to_bucket.clear() + + +class ZeroRedundancyOptimizer(Optimizer, Joinable): + r""" + Wrap an arbitrary :class:`optim.Optimizer ` and shards its states across ranks in the group. + + The sharing is done as described by ZeRO_. + + The local optimizer instance in each rank is only + responsible for updating approximately ``1 / world_size`` parameters and + hence only needs to keep ``1 / world_size`` optimizer states. After + parameters are updated locally, each rank will broadcast its parameters to + all other peers to keep all model replicas in the same state. + ``ZeroRedundancyOptimizer`` can be used in conjunction with + :class:`core.nn.parallel.DistributedDataParallel` to reduce per-rank peak + memory consumption. + + ``ZeroRedundancyOptimizer`` uses a sorted-greedy algorithm to pack a number + of parameters at each rank. Each parameter belongs to a single rank and is + not divided among ranks. The partition is arbitrary and might not match the + the parameter registration or usage order. + + Arguments: + params (``Iterable``): an ``Iterable`` of :class:`core.Tensor` s + or :class:`dict` s giving all parameters, which will be sharded + across ranks. + + Keyword Args: + optimizer_class (:class:`core.nn.Optimizer`): the class of the local + optimizer. + process_group (``ProcessGroup``, optional): ``core.distributed`` + ``ProcessGroup`` (default: ``dist.group.WORLD`` initialized by + :meth:`core.distributed.init_process_group`). + parameters_as_bucket_view (bool, optional): if ``True``, parameters are + packed into buckets to speed up communication, and ``param.data`` + fields point to bucket views at different offsets; if ``False``, + each individual parameter is communicated separately, and each + ``params.data`` stays intact (default: ``False``). + overlap_with_ddp (bool, optional): if ``True``, :meth:`step` is + overlapped with :class:`DistributedDataParallel` 's gradient + synchronization; this requires (1) either a functional optimizer + for the ``optimizer_class`` argument or one with a functional + equivalent and (2) registering a DDP communication hook + constructed from one of the functions in ``ddp_zero_hook.py``; + parameters are packed into buckets matching those in + :class:`DistributedDataParallel`, meaning that the + ``parameters_as_bucket_view`` argument is ignored. + If ``False``, :meth:`step` runs disjointly after the backward pass + (per normal). + (default: ``False``) + **defaults: any trailing arguments, which are forwarded to the local + optimizer. + + Example:: + + >>> # xdoctest: +SKIP + >>> from mindnlp import core.nn as nn + >>> from core.distributed.optim import ZeroRedundancyOptimizer + >>> from core.nn.parallel import DistributedDataParallel as DDP + >>> model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)]) + >>> ddp = DDP(model, device_ids=[rank]) + >>> opt = ZeroRedundancyOptimizer( + >>> ddp.parameters(), + >>> optimizer_class=core.optim.Adam, + >>> lr=0.01 + >>> ) + >>> ddp(inputs).sum().backward() + >>> opt.step() + + .. warning:: + Currently, ``ZeroRedundancyOptimizer`` requires that all of the + passed-in parameters are the same dense type. + + .. warning:: + If you pass ``overlap_with_ddp=True``, be wary of the following: Given + the way that overlapping :class:`DistributedDataParallel` with + :class:`ZeroRedundancyOptimizer` is currently implemented, the first + two or three training iterations do not perform parameter updates in + the optimizer step, depending on if ``static_graph=False`` or + ``static_graph=True``, respectively. This is because it needs + information about the gradient bucketing strategy used by + :class:`DistributedDataParallel`, which is not finalized until the + second forward pass if ``static_graph=False`` or until the third + forward pass if ``static_graph=True``. To adjust for this, one option + is to prepend dummy inputs. + + .. warning:: ZeroRedundancyOptimizer is experimental and subject to change. + + .. _ZeRO: https://arxiv.org/abs/1910.02054 + + """ + + def __init__( + self, + params, + optimizer_class: Type[Optimizer], + process_group: Optional[Any] = None, + parameters_as_bucket_view: bool = False, + overlap_with_ddp: bool = False, + **defaults: Any, + ): + r"""Init.""" + # Perform type and assumption checks on the input parameters + params = self._verify_and_init_params(params) + self._verify_same_dense_param_type() + + # NOTE: The parent constructor uses `add_param_group()` which is + # partially overloaded in ZeroRedundancyOptimizer, so we use the + # `initialized` flag to dissociate the behaviour of `add_param_group()` + # between the parent and child. + self.initialized = False + + Optimizer.__init__(self, params, defaults) + Joinable.__init__(self) + # Now, all parameters are held in both `self._all_params` and + # `self.param_groups` + + # Internal data structures (`_cache` indicates lazily evaluated) + self._param_to_rank_cache: Dict[core.Tensor, int] = {} + self._param_to_index_cache: Dict[core.Tensor, int] = {} + self._partition_parameters_cache: List[List[Dict]] = [] + self._index_to_param_cache: List[core.Tensor] = [] + self._device_to_params_per_rank_cache: Dict[ + core.device, List[List[core.Tensor]] + ] = {} + self._bucket_assignments_per_rank_cache: List[ + Dict[int, _DDPBucketAssignment] + ] = [] + self._is_trainable_mask = self._get_is_trainable_mask() + + # Default device for collective communication and buckets + self._default_device = self._all_params[0].device + + self.process_group = ( + process_group if process_group is not None else dist.group.WORLD + ) + self.world_size: int = dist.get_world_size(self.process_group) + self.rank: int = dist.get_rank(self.process_group) + self.global_rank: int = dist.distributed_c10d.get_global_rank( + self.process_group, self.rank + ) + + self._overlap_with_ddp: bool = overlap_with_ddp + self._optim_defaults = defaults + self._optim_constructor = self._get_optimizer_constructor(optimizer_class) + + # If `overlap_with_ddp=True`, local optimizer initialization is delayed + # to run time after the necessary information has been collected + if not overlap_with_ddp: + self._init_local_optimizer() + else: + self._overlap_info: _OverlapInfo = _OverlapInfo(self.world_size) + if parameters_as_bucket_view: + logger.warning( + "`parameters_as_bucket_view=True` will be ignored since " + "`overlap_with_ddp=True`; instead, a different bucketing " + "strategy will be used" + ) + + # `self._buckets` is used if `parameters_as_bucket_view=True`, in + # which case parameter data is flattened into contiguous bucket tensors + self.parameters_as_bucket_view = parameters_as_bucket_view + self._buckets: List[List[core.Tensor]] = [] + self._build_param_buckets() + + # Optional consolidated optimizer state, only populated if this rank + # is the target in `consolidate_state_dict()` + self._all_state_dicts: List[Dict[str, Any]] = [] + + self.initialized = True + + def _clear_cache(self) -> None: + r"""Clear the cached data structures giving partition information.""" + self._partition_parameters_cache.clear() + self._param_to_rank_cache.clear() + self._index_to_param_cache.clear() + self._param_to_index_cache.clear() + self._device_to_params_per_rank_cache.clear() + self._bucket_assignments_per_rank_cache.clear() + + def add_param_group(self, param_group: Dict[str, Any]) -> None: + r""" + Add a parameter group to the :class:`Optimizer` 's ``param_groups``. + + This can be useful when fine tuning a pre-trained network, as frozen + layers can be made trainable and added to the :class:`Optimizer` as + training progresses. + + Arguments: + param_group (dict): specifies the parameters to be optimized and + group-specific optimization options. + + .. warning:: This method handles updating the shards on all partitions + but needs to be called on all ranks. Calling this on a subset of + the ranks will cause the training to hang because communication + primitives are called depending on the managed parameters and + expect all the ranks to participate on the same set of parameters. + """ + if self.initialized and self._overlap_with_ddp: + raise RuntimeError( + "ZeroRedundancyOptimizer with `overlap_with_ddp=True` only " + "supports a single parameter group" + ) + + super().add_param_group(param_group) + # NOTE: The rest of the method assumes that the call to the parent's + # `add_param_group()` appends the new parameter group and preserves + # the previous parameter-group ordering + + if self.initialized: + # Force a re-partitioning of the parameters + self._clear_cache() + param_groups = self._partition_parameters()[self.rank] + # NOTE: All parameters in the old parameter groups should be + # assigned to the same ranks so that the local optimizers do not + # need to be reinitialized + + # Add the parameters assigned to this rank from the new parameter + # group to the local optimizer, if any + if len(param_groups) == len(self.optim.param_groups) + 1: + self.optim.add_param_group(param_groups[-1]) + + # Update the bucketing strategy accordingly + if self.parameters_as_bucket_view: + self._build_param_buckets() + + def consolidate_state_dict(self, to: int = 0) -> None: + r""" + Consolidate a list of ``state_dict`` s (one per rank) on the target rank. + + Arguments: + to (int): the rank that receives the optimizer states (default: 0). + + Raises: + RuntimeError: if ``overlap_with_ddp=True`` and this method is + called before this :class:`ZeroRedundancyOptimizer` instance + has been fully initialized, which happens once + :class:`DistributedDataParallel` gradient buckets have been + rebuilt. + + .. warning:: This needs to be called on all ranks. + """ + self._check_overlap_initialized() + + # Sync the exposed `param_groups` attributes to the local optimizer in + # case they have been updated + self._sync_param_groups(self.param_groups, self.optim.param_groups) + + # Pull the sharded state from all ranks and store them in rank order + empty_messenger = core.tensor( + [0], dtype=core.uint8, device=self._default_device + ) + + # NOTE: We wastefully use `broadcast()` (e.g. instead of `gather()`) + # due to compatibility issues with NCCL backend; a possible follow-up + # is to move all sharded state management to RPC RRef + self._all_state_dicts = [] + for rank in range(self.world_size): + global_rank = dist.distributed_c10d.get_global_rank( + self.process_group, rank + ) + if self.rank == to: + # Consolidate all local `state_dict`s on this rank, storing on + # CPU to save GPU memory + if rank == self.rank: + # Directly append own optimizer state + self._all_state_dicts.append( + _recursive_copy_to_device( + self.optim.state_dict(), + non_blocking=True, + device=core.device("cpu"), + ) + ) + else: + # Receive the optimizer state from the source rank + local_state_dict = _broadcast_object( + empty_messenger, + src_rank=global_rank, + group=self.process_group, + device=self._default_device, + ) + self._all_state_dicts.append( + _recursive_copy_to_device( + local_state_dict, + non_blocking=True, + device=core.device("cpu"), + ) + ) + else: + if rank == self.rank: + # Send the optimizer state to the target rank + _ = _broadcast_object( + self.optim.state_dict(), + src_rank=self.global_rank, + group=self.process_group, + device=self._default_device, + ) + elif rank != to: + # Discard the received object; `broadcast()` is used for + # compatibility reasons + _ = _broadcast_object( + empty_messenger, + src_rank=global_rank, + group=self.process_group, + device=self._default_device, + ) + + def _verify_params_per_rank( + self, + params_per_rank: List[List[core.Tensor]], + ) -> None: + r""" + Verify ``params_per_rank`` for :meth:`_partition_parameters`. + + The verification is done by checking that ``params_per_rank`` has length equal + to the world size and that it does not contain any parameters not passed into the + :class:`ZeroRedundancyOptimizer` constructor. + + The parameters in ``params_per_rank`` being a strict subset of those + passed into the constructor is valid since some parameters may be + frozen. + + Raises: + ValueError: if ``params_per_rank`` does not have length equal to + the world size or if it contains a parameter that was not + passed into the :class:`ZeroRedundancyOptimizer` constructor. + """ + if len(params_per_rank) != self.world_size: + raise ValueError( + "`params_per_rank` must have length equal to the world size" + ) + all_params_set = set(self._all_params) + for params in params_per_rank: + for param in params: + if param not in all_params_set: + raise ValueError( + "Passing a new parameter in `params_per_rank` that " + "was not passed into the ZeroRedundancyOptimizer " + "constructor" + ) + + def _partition_param_group( + self, param_group: Dict[str, Any], params_per_rank: List[List[core.Tensor]] + ) -> None: + r""" + Partition the parameter group ``param_group`` according to ``params_per_rank``. + + The partition will modify the ``self._partition_parameters_cache``. This method should + only be used as a subroutine for :meth:`_partition_parameters`. + + Arguments: + param_group (dict[str, Any]): a parameter group as normally defined + in an optimizer state. + params_per_rank (list[list[core.Tensor]]): a :class:`list` of + length world size containing :class:`list` s of parameters to + assign to each rank. + """ + for rank, params in enumerate(params_per_rank): + rank_param_group = copy.copy(param_group) + rank_param_group["params"] = params + self._partition_parameters_cache[rank].append(rank_param_group) + + def _partition_parameters( + self, + params_per_rank: Optional[List[List[core.Tensor]]] = None, + ) -> List[List[Dict]]: + r""" + Partitions parameters across distributed data parallel ranks. + + Arguments: + params_per_rank (list[list[core.Tensor]], optional): a + :class:`list` of length world size containing :class:`list` s + of parameters to assign to each rank; this provides a way to + specify a partition manually. + If ``None``, the parameters are partitioned according to an + internal algorithm. + (default: ``None``) + + Returns: + A :class:`list` where each element of the list contains the + ``param_groups`` for a rank (which itself is a :class:`list` of + :class:`dict`); element 0 corresponds to rank 0, etc.; each rank + stores the ``param_groups`` for all ranks for the collective + communication in :meth:`step`. + + Raises: + ValueError: see :meth:`_validate_params_per_rank`. + RuntimeError: if ``params_per_rank`` is not ``None`` and this + :class:`ZeroRedundancyOptimizer` instance is using more than + one parameter group. + """ + if params_per_rank is None: + # Partition the parameters optimizing for uniformity + if len(self._partition_parameters_cache) == 0: + self._partition_parameters_cache = [[] for _ in range(self.world_size)] + sizes = [0] * self.world_size + for param_group in self.param_groups: + param_group_params_per_rank: List[List] = [ + [] for _ in range(self.world_size) + ] + # Sort the parameters by size (largest first) + params_sorted = sorted( + param_group["params"], key=lambda t: t.numel(), reverse=True + ) + for param in params_sorted: + # Greedily add the parameter to rank with smallest size so far + rank = self._get_min_index(sizes) + param_group_params_per_rank[rank].append(param) + sizes[rank] += param.numel() + # Apply the constructed partition of the parameter group + self._partition_param_group( + param_group, param_group_params_per_rank + ) + + return self._partition_parameters_cache + + # Partition the parameters according to `params_per_rank` + assert len(self._partition_parameters_cache) == 0, ( + "Specifying `params_per_rank` should only be done when the " + "parameters have not been partitioned yet" + ) + if len(self.param_groups) != 1: + raise RuntimeError( + "Specifying `params_per_rank` only supports a single parameter group" + ) + self._verify_params_per_rank(params_per_rank) + self._partition_parameters_cache = [[] for _ in range(self.world_size)] + + # Apply the passed-in partition of the parameter group + param_group = self.param_groups[0] + self._partition_param_group(param_group, params_per_rank) + + return self._partition_parameters_cache + + @property + def _param_to_rank(self) -> Dict[core.Tensor, int]: + r""":class:`dict` mapping parameters to their assigned data parallel rank in the partition.""" + if len(self._param_to_rank_cache) == 0: + for rank, param_groups in enumerate(self._partition_parameters()): + for param_group in param_groups: + for param in param_group["params"]: + self._param_to_rank_cache[param] = rank + return self._param_to_rank_cache + + @property + def _param_to_index(self) -> Dict[core.Tensor, int]: + r""" + :class:`dict` mapping parameters to their indices in the global optimizer state. + + NOTE: This assumes that the global optimizer state's indexing (in + ``state_dict``) follows a linear ordering over the parameter groups. + """ + if len(self._param_to_index_cache) == 0: + self._param_to_index_cache = { + p: i + for i, p in enumerate(chain(*(g["params"] for g in self.param_groups))) + } + return self._param_to_index_cache + + @property + def _index_to_param(self) -> List[core.Tensor]: + r"""List mapping parameter indices in the global optimizer scheme to the actual params.""" + if len(self._index_to_param_cache) == 0: + self._index_to_param_cache = list( + chain(*(g["params"] for g in self.param_groups)) + ) + return self._index_to_param_cache + + def _broadcast_params_from_rank(self, rank: int): + r""" + Broadcast the shard of parameters from a given rank to all other ranks asynchronously. + + Arguments: + rank (int): the source rank. + + Returns: + A :class:`list` of async work handles for the ``broadcast()`` s + performed to synchronize the parameters. + """ + assert not self._overlap_with_ddp, ( + "`_broadcast_params_from_rank()` should not be used if " + "`overlap_with_ddp=True`; instead, the broadcasting should " + "happen in the DDP communication hook" + ) + handles = [] + if self.parameters_as_bucket_view: + for dev_i_buckets in self._buckets: + bucket = dev_i_buckets[rank] + global_rank = dist.distributed_c10d.get_global_rank( + self.process_group, rank + ) + handles.append( + dist.broadcast( + tensor=bucket, + src=global_rank, + group=self.process_group, + async_op=True, + ) + ) + else: + param_groups = self._partition_parameters()[rank] + global_rank = dist.distributed_c10d.get_global_rank( + self.process_group, rank + ) + for param_group in param_groups: + handles.extend( + dist.broadcast( + tensor=param.data, + src=global_rank, + group=self.process_group, + async_op=True, + ) + for param in param_group["params"] + ) + return handles + + def _sync_params(self): + r""" + Sync all parameter shards across the ranks. + + This rank sends its shard of the parameters to all other ranks and + receives a shard from each other rank. This is done using + ``broadcast()``. Parameters are sent bucket-by-bucket if + ``parameters_as_bucket_view=True``and sent parameter-by-parameter + otherwise. + """ + handles = [] + for rank in range(self.world_size): + handles.extend(self._broadcast_params_from_rank(rank)) + _ = [x.wait() for x in handles] + + @property + def _device_to_params_per_rank( + self, + ) -> Dict[core.device, List[List[core.Tensor]]]: + r""" + Return device parameters assigned per rank. + + :class:`dict` mapping each device to a :class:`list` of the per-rank parameter + lists filtered to only include the parameters stored on that device. + Each per-rank parameter list gives the parameters assigned to that rank + to update. + + This is used for constructing the parameter buckets if + ``parameters_as_bucket_view=True``. + + Let ``dev_i`` denote the ``i``th device for this rank. Then: + ``dev_0`` maps to a list containing: + rank 0's assigned parameters stored on ``dev_0``, + rank 1's assigned parameters stored on ``dev_0``, + ... + ``dev_1`` maps to a list containing: + rank 0's assigned parameters stored on ``dev_1``, + rank 1's assigned parameters stored on ``dev_1``, + ... + ... + """ + assert self.parameters_as_bucket_view, ( + "`_device_to_params_per_rank` should only be used if " + "`parameters_as_bucket_view=True`" + ) + if len(self._device_to_params_per_rank_cache) == 0: + for rank, param_groups in enumerate(self._partition_parameters()): + for param_group in param_groups: + for param in param_group["params"]: + device = param.device + if device not in self._device_to_params_per_rank_cache: + self._device_to_params_per_rank_cache[device] = [ + [] for _ in range(self.world_size) + ] + self._device_to_params_per_rank_cache[device][rank].append( + param + ) + return self._device_to_params_per_rank_cache + + def _get_min_index( + self, + values: List[int], + disallowed_indices: Optional[Set[int]] = None, + ) -> int: + r""" + Return ``values.index(min(values))``, except only uses one pass. + + It also excludes any indices in ``disallowed_indices`` if provided. + + Arguments: + values: (List[int]): :class:`list` of values. + disallowed_indices (Optional[Set[int]]): indices that are + disallowed from being the returned min index. + """ + min_index = -1 + min_value = float("inf") + for i, value in enumerate(values): + if disallowed_indices and i in disallowed_indices: + continue + if value < min_value: + min_value = value + min_index = i + assert min_index >= 0, "All indices are disallowed" + return min_index + + def _assign_bucket_subset_to_rank( + self, + bucket_index: int, + bucket_params: List[core.Tensor], + bucket_offset: int, + assigned_rank: int, + assigned_ranks_per_bucket: List[Set[int]], + ) -> None: + r""" + Assign ``bucket_params`` to the rank with the least size assigned so far and collects relevant information. + + The model parameters given by ``bucket_params`` represents a (possibly non-strict) + subset of the parameters corresponding to a :class:`DistributedDataParallel` bucket. + + Arguments: + bucket_index (int): index of the :class:`DistributedDataParallel` + gradient bucket. + bucket_params (List[core.Tensor]): subset of the parameters + corresponding to the bucket to assign. + bucket_offset (int): offset giving the index of the first element + in ``bucket_params`` in the bucket's full parameter list. + assigned_rank (int): group rank to assign to. + assigned_ranks_per_bucket (List[Set[int]]): :class:`set` of group ranks + assigned to each bucket. + """ + overlap_info = self._overlap_info + if len(bucket_params) == 0: + raise ValueError("Empty bucket assignment") + params_per_rank = overlap_info.params_per_rank + offsets = overlap_info.offsets + + self._bucket_assignments_per_rank_cache[assigned_rank][ + bucket_index + ] = _DDPBucketAssignment(bucket_index, bucket_params, bucket_offset) + if self.global_rank == assigned_rank: + offsets[bucket_index] = len(params_per_rank[assigned_rank]) + params_per_rank[assigned_rank].extend(bucket_params) + assigned_ranks_per_bucket[bucket_index].add(assigned_rank) + self._overlap_info.num_bucket_assignments += 1 + + @property + def _bucket_assignments_per_rank(self) -> List[Dict[int, _DDPBucketAssignment]]: + r""" + Return DDP bucket parameters assigned per rank. + + :class:`list` of length world size consisting of :class:`dict` s + mapping bucket indices to :class:`_DDPBucketAssignment` s for each + rank. + """ + assert ( + self._overlap_with_ddp + ), "`_bucket_assignments_per_rank` only be used if `overlap_with_ddp=True`" + if len(self._bucket_assignments_per_rank_cache) > 0: + return self._bucket_assignments_per_rank_cache + + overlap_info = self._overlap_info + assert overlap_info.status == _OverlapStatus.INITIALIZED + + self._bucket_assignments_per_rank_cache = [{} for _ in range(self.world_size)] + params_per_bucket = overlap_info.params_per_bucket + + if overlap_info.shard_buckets: + # Define the assignment threshold to approximate uniformity + assert overlap_info.total_size is not None, "`total_size` was not computed" + threshold = overlap_info.total_size / self.world_size # type: ignore[operator] + size_per_rank = [0 for _ in range(self.world_size)] + + num_buckets = len(params_per_bucket) + overlap_info.assigned_ranks_per_bucket = [set() for _ in range(num_buckets)] + assigned_ranks_per_bucket = overlap_info.assigned_ranks_per_bucket + if not overlap_info.shard_buckets: + # Assign each DDP bucket entirely to a single rank + for bucket_index, bucket_params in enumerate(params_per_bucket): + assert len(bucket_params) > 0, "Empty bucket" + assigned_rank = self._get_assigned_rank(bucket_index) + self._assign_bucket_subset_to_rank( + bucket_index, + bucket_params, + 0, + assigned_rank, + assigned_ranks_per_bucket, + ) + else: + # Assign each DDP bucket to possibly multiple ranks + # Specifically, sort the DDP buckets by increasing size, and for + # each bucket, iteratively assign the maximal unassigned subset + # with size less than `threshold` to the rank with the least total + # size so far -- each such assignment is represented by a + # `_DDPBucketAssignment` instance and only contains parameters from + # a single DDP bucket + params_per_bucket_enum = sorted( + enumerate(params_per_bucket), key=lambda x: sum(p.numel() for p in x[1]) + ) + for bucket_index, bucket_params in params_per_bucket_enum: + assert len(bucket_params) > 0, "Empty bucket" + bucket_offset = 0 + assignment_size = 0 + for param_index, param in enumerate(bucket_params): + param_numel = param.numel() + if ( + assignment_size + param_numel >= threshold + and param_index > bucket_offset + ): + assigned_rank = self._get_min_index( + size_per_rank, assigned_ranks_per_bucket[bucket_index] + ) + # Include up to but not including the parameter that + # exceeded the threshold + self._assign_bucket_subset_to_rank( + bucket_index, + bucket_params[bucket_offset:param_index], + bucket_offset, + assigned_rank, + assigned_ranks_per_bucket, + ) + size_per_rank[assigned_rank] += assignment_size + bucket_offset = param_index + assignment_size = 0 + assignment_size += param_numel + # Assign the remainder of the bucket so that no assignment + # spans across two buckets + assigned_rank = self._get_min_index( + size_per_rank, assigned_ranks_per_bucket[bucket_index] + ) + self._assign_bucket_subset_to_rank( + bucket_index, + bucket_params[bucket_offset:], + bucket_offset, + assigned_rank, + assigned_ranks_per_bucket, + ) + size_per_rank[assigned_rank] += assignment_size + + return self._bucket_assignments_per_rank_cache + + def _local_step( + self, + gradients: Optional[List[Optional[core.Tensor]]] = None, + closure: Optional[Callable[[], float]] = None, + **kwargs: Any, + ) -> Optional[float]: + r""" + Perform a single optimizer step without syncing parameters across ranks. + + Arguments: + gradients (list[Optional[core.Tensor]], optional): a :class:`list` + of length equal to the number of parameters assigned to this + rank containing gradient tensors or ``None`` as its elements; + a ``None`` in the :class:`list` indicates that the + corresponding parameter should not be updated. + If the argument itself is ``None``, then all parameters are + updated, and the gradients are assumed to be already populated. + (default: ``None``) + closure (Callable): a closure that re-evaluates the model and + returns the loss; optional for most optimizers and should be + ``None`` if ``gradients`` is not ``None``; (default: ``None``) + Returns: + Optional loss depending on the underlying local optimizer. + + .. warning:: + The argument ``gradients`` should only be specified (i.e. not + ``None``) if ``overlap_with_ddp=True``, in which case + :class:`ZeroRedundancyOptimizer` wraps a functional optimizer. + """ + Join.notify_join_context(self) + # Check if the model trainability has changed + is_trainable_mask = self._get_is_trainable_mask() + if is_trainable_mask != self._is_trainable_mask: + if self._overlap_with_ddp: + raise RuntimeError( + "ZeroRedundancyOptimizer with `overlap_with_ddp=True` " + "does not support changing parameter trainability at run " + "time" + ) + logger.warning( + "ZeroRedundancyOptimizer detected that the trainable " + "parameters changed; rebuilding the parameter buckets if " + "enabled" + ) + self._build_param_buckets() + self._is_trainable_mask = is_trainable_mask + + # Sync the exposed `param_groups` attributes to the local optimizer in + # case they have been updated + self._sync_param_groups(self.param_groups, self.optim.param_groups) + + # Run the optimizer step on this shard only + if gradients is None: + loss = ( + self.optim.step(**kwargs) + if closure is None + else self.optim.step(closure=closure, **kwargs) + ) + else: + assert self._overlap_with_ddp, ( + "Specifying `gradients` should not " + "be used when `overlap_with_ddp=False`" + ) + assert ( + closure is None + ), "`closure` is not supported when using a local functional optimizer" + loss = self.optim.step(gradients=gradients) + + # Sync any updated attributes in the local optimizer to the exposed + # `param_groups` + self._sync_param_groups(self.optim.param_groups, self.param_groups) + + return loss + + def step( + self, + closure: Optional[Callable[[], float]] = None, + **kwargs: Any, + ) -> Optional[float]: + r""" + Perform a single optimizer step and syncs parameters across all ranks. + + Arguments: + closure (Callable): a closure that re-evaluates the model and + returns the loss; optional for most optimizers. + Returns: + Optional loss depending on the underlying local optimizer. + + .. note: Any extra parameters are passed to the base optimizer as-is. + """ + if self._overlap_with_ddp: + logger.warning( + "`step()` should not be included in the training loop when " + "`overlap_with_ddp=True`" + ) + return None + + # Perform the local optimizer step + loss = self._local_step(closure=closure, **kwargs) + + # Sync all of the updated parameter shards across the ranks + self._sync_params() + + return loss + + def join_hook(self, **kwargs): + r""" + Return the ZeRO join hook. + + It enables training on uneven inputs by + shadowing the collective communications in the optimizer step. + + Gradients must be properly set before this hook is called. + + Arguments: + kwargs (dict): a :class:`dict` containing any keyword arguments + to modify the behavior of the join hook at run time; all + :class:`Joinable` instances sharing the same join context + manager are forwarded the same value for ``kwargs``. + + This hook does not support any keyword arguments; i.e. ``kwargs`` is + unused. + """ + return _ZeROJoinHook(self) + + @property + def join_device(self) -> core.device: + r"""Return default device.""" + return self._default_device + + @property + def join_process_group(self) -> Any: + r"""Return process group.""" + return self.process_group + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + r""" + Load the state pertaining to the given rank from the input ``state_dict``, updating the local optimizer as needed. + + Arguments: + state_dict (dict): optimizer state; should be an object returned + from a call to :meth:`state_dict`. + + Raises: + RuntimeError: if ``overlap_with_ddp=True`` and this method is + called before this :class:`ZeroRedundancyOptimizer` instance + has been fully initialized, which happens once + :class:`DistributedDataParallel` gradient buckets have been + rebuilt. + """ + self._check_overlap_initialized() + + for index, value in state_dict["state"].items(): + param = self._index_to_param[index] + if self._param_to_rank[param] != self.rank: + # Clear any state irrelevant to this rank + state_dict["state"][index] = None + else: + # Load the parameter state to the local optimizer + self.optim.state[param] = _recursive_copy_to_device( + value, non_blocking=True, device=param.device + ) + # Force zero-dimensional tensors (like Adam "step") on CPU + for state_name, state_value in self.optim.state[param].items(): + if core.is_tensor(state_value) and state_value.dim() == 0: + self.optim.state[param][state_name] = state_value.cpu() + + super().load_state_dict(state_dict) + + # Sync the input state with the exposed and local optimizer states + self._sync_param_groups(state_dict["param_groups"], self.param_groups) + self._sync_param_groups(self.param_groups, self.optim.param_groups) + + def state_dict(self) -> Dict[str, Any]: + r""" + Return the last global optimizer state known to this rank. + + .. warning: + If the state has not been consolidated to this rank, this raises a + runtime error, and even if it has, the state may not be up-to-date, + depending on when :meth:`consolidate_state_dict` was last called. + + Raises: + RuntimeError: if ``overlap_with_ddp=True`` and this method is + called before this :class:`ZeroRedundancyOptimizer` instance + has been fully initialized, which happens once + :class:`DistributedDataParallel` gradient buckets have been + rebuilt; or if this method is called without a preceding call + to :meth:`consolidate_state_dict`. + """ + self._check_overlap_initialized() + + if len(self._all_state_dicts) == 0: + raise RuntimeError( + "Optimizer state has not been consolidated on this rank. " + f"Please call `consolidate_state_dict(to={self.rank})` on " + "all ranks beforehand if you meant to save the global state." + ) + + # Get the possibly-stale global optimizer state that uses global + # parameter indexing + state_dict = super().state_dict() + + # Update the global optimizer state with local state information, + # factoring in the translation from local to global indexing + for rank, local_state_dict in enumerate(self._all_state_dicts): + local_param_groups = local_state_dict["param_groups"] + global_param_groups = self._partition_parameters()[rank] + assert len(local_param_groups) == len( + global_param_groups + ), "Mismatch between number of local and global parameter groups" + + for local_param_group, global_param_group in zip( + local_param_groups, global_param_groups + ): + # `local_param_group` stores local indices, while + # `global_param_group` stores the tensors directly + local_param_indices = local_param_group["params"] + global_params = global_param_group["params"] + + assert len(local_param_indices) == len( + global_params + ), "Mismatch between number of local and global parameters in parameter group" + for local_param_index, global_param in zip( + local_param_indices, global_params + ): + # Update the global parameter state, if any + if local_param_index in local_state_dict["state"]: + global_param_index = self._param_to_index[global_param] + state_dict["state"][global_param_index] = local_state_dict[ + "state" + ][local_param_index] + + # Sort the parameters in the state + state_dict["state"] = dict(sorted(state_dict["state"].items())) + return state_dict + + @staticmethod + def _sync_param_groups( + src_param_groups: List[Dict[Any, Any]], + dst_param_groups: List[Dict[Any, Any]], + ) -> None: + r""" + Sync the attributes from the source parameter groups to the destination parameter groups. + + Example attributes include learning rate or scheduler attributes. The + two parameter groups should have the same length (i.e. same number of + parameter groups). + + Arguments: + src_param_groups (list[dict]): parameter groups giving the + attribute settings to copy. + dst_param_groups (list[dict]): parameter groups giving the + attribute settings to set. + """ + assert len(src_param_groups) == len( + dst_param_groups + ), "Mismatch between number of source and destination parameter groups" + for src_param_group, dst_param_group in zip(src_param_groups, dst_param_groups): + # Sync all attributes except the parameters + for attr in filter(lambda x: x != "params", src_param_group.keys()): + dst_param_group[attr] = src_param_group[attr] + + def _build_param_buckets(self) -> None: + r""" + Build parameter buckets if ``parameters_as_bucket_view=True``. + + For each device that stores this rank's parameters, there is a + bucket (represented as a tensor) containing all of the parameters on + that device that are assigned to a given rank in the parameter update + partition. + + This method is called in the constructor and any time parameter + trainability is changed. + + .. warning:: + The current implementation assumes that all of the parameters in a + bucket are of the same dense type when allocating the bucket's + tensor. + + .. warning:: + If the model parameters are stored across more than one device, + then the storage partitioning must be the same across all + processes in order for parameter synchronization to work. + """ + if not self.parameters_as_bucket_view or self._overlap_with_ddp: + return + + # `self._buckets[i][j]` are the parameters stored on device i and + # assigned to rank j + num_devices = len(self._device_to_params_per_rank) + self._buckets = [[] for _ in range(num_devices)] # type: ignore[assignment] + + for dev_i, (device, params_per_rank) in enumerate( + self._device_to_params_per_rank.items() + ): + for params in params_per_rank: + bucket_size = 0 + dtype = None + trainable_params = [] + for param in params: + if not _is_trainable(param): + # Clone in case the parameter was previously part of + # a bucket to avoid the data from being destroyed + param.data = param.data.detach().clone() + else: + bucket_size += param.numel() + trainable_params.append(param) + dtype = param.dtype # assumes all same dtype + + if bucket_size == 0: + # Create a dummy bucket if there are no parameters + bucket = core.zeros(1, device=device) + else: + # Construct the bucket (assuming all dense and same dtype) + bucket = core.empty(bucket_size, dtype=dtype, device=device) + offset = 0 + for param in trainable_params: + offset_next = offset + param.numel() + bucket[offset:offset_next].copy_(param.data.flatten()) + param.data = bucket[offset:offset_next].view_as(param.data) + offset = offset_next + self._buckets[dev_i].append(bucket) # type: ignore[arg-type] + + def _build_ddp_param_buckets(self) -> None: + r""" + Build the DDP bucket with parameters assigned to this rank. + + For each DDP bucket with parameters assigned to this rank, flattens the + data of those parameters into a single tensor and saves the tensor to + the ``tensor`` attribute in the corresponding + :class:`_DDPBucketAssignment` instance stored in + ``self._bucket_assignments_per_rank``. + + :class:`DistributedDataParallel` guarantees that the parameters + corresponding to a gradient bucket have the same device and the same + dtype. + """ + for bucket_assignments in self._bucket_assignments_per_rank: + for bucket_assignment in bucket_assignments.values(): + params = bucket_assignment.parameters + bucket_size = 0 + dtype = None + for param in params: + assert _is_trainable(param), ( + "Model parameter " + "corresponding to a gradient in a DDP bucket should " + "require a gradient" + ) + bucket_size += param.numel() + dtype = param.dtype # assumes all same dtype + assert bucket_size > 0, "Empty bucket" + + # Construct the bucket tensor (assuming all dense and same dtype) + tensor = core.empty( + bucket_size, dtype=dtype, device=bucket_assignment.device + ) + offset = 0 + for param in params: + offset_next = offset + param.numel() + tensor[offset:offset_next].copy_(param.data.flatten()) + param.data = tensor[offset:offset_next].view_as(param.data) + offset = offset_next + bucket_assignment.tensor = tensor + + def _verify_and_init_params( + self, + params: Any, + ) -> Union[List[core.Tensor], List[dict]]: + r""" + Verify the type of ``params`` and initializes ``self._all_params`` as a :class:`list` of all parameters. + + The initializagtion will first make sure that provided ``params`` is valid. + + Arguments: + params (Any): Candidate parameter list or parameter groups to verify. + + Raises: + TypeError: ``params`` has an invalid type. + ValueError: ``params`` is empty. + + Returns: + The persistent form of ``params`` to be passed into the parent + :class:`Optimizer` constructor -- i.e. returns ``params`` as a + :class:`list` to ensure that it can be iterated over again. + """ + if isinstance(params, core.Tensor): + raise TypeError( + "`params` argument should be an iterable of " + f"Tensors, but got {core.typename(params)}" + ) + try: + all_params = list(params) + except TypeError as e: + raise TypeError( + "`params` argument should be an iterable of Tensors" + f" or dicts, but got {core.typename(params)}" + ) from e + if len(all_params) == 0: + raise ValueError("ZeroRedundancyOptimizer got an empty parameter list") + all_tensors = True + all_dicts = True + for param in all_params: + all_tensors &= isinstance(param, core.Tensor) + all_dicts &= isinstance(param, dict) + if not all_tensors and not all_dicts: + raise TypeError( + "`params` argument should be an iterable of Tensors or dicts" + ) + # Ensure that `self._all_params` contains a list of all parameters + if all_tensors: + self._all_params = all_params + elif all_dicts: + self._all_params = [] + # `all_params` contains parameter groups (not parameters) + for param_group in all_params: + if "params" not in param_group: + raise ValueError( + "Each parameter group passed-in via `params` must " + "have a 'params' key mapping to the parameters in " + "the group" + ) + self._all_params.extend(param_group["params"]) + return all_params + + def _verify_same_dense_param_type(self) -> None: + r""" + Verify that all parameters are of the same dense type. + + The method assumes that ``self._all_params`` has been initialized + and is non-empty. + + Raises: + ValueError: ``params`` contains sparse parameters or parameters + of varying dense types. + + NOTE: This method can be removed once support for sparse parameters + and varying parameter types is added. + """ + typename = core.typename(self._all_params[0]) + if self._all_params[0].is_sparse: + raise ValueError( + "ZeroRedundancyOptimizer only supports using " + "the same dense type for all parameters but got " + f"{typename}" + ) + for param in self._all_params[1:]: + other_typename = core.typename(param) + if other_typename != typename: + raise ValueError( + "ZeroRedundancyOptimizer only supports " + "using the same dense type for all " + f"parameters but got both {typename} and " + f"{other_typename}" + ) + + def _get_is_trainable_mask(self) -> List[bool]: + r"""Return a boolean mask indicating if each parameter is trainable (``requires_grad``) or not.""" + return list(map(_is_trainable, self._all_params)) + + def _init_local_optimizer(self) -> None: + r""" + Initialize this rank's local optimizer, responsible for its subset of the parameters. + + The local optimizer is saved in ``self.optim``. + """ + assert ( + self._optim_constructor is not None + ), "The local optimizer class has not been set" + + param_groups = self._partition_parameters()[self.rank] + # `overlap_with_ddp=True` requires a local functional optimizer + if self._overlap_with_ddp: + # Functional optimizers only support a single parameter group and + # require passing in the parameters as a list + assert len(param_groups) == 1, ( + "Initializing the local " + "functional optimizer with more than one parameter group" + ) + params = param_groups[0]["params"] + # Try to pass `_allow_empty_param_list=True` to avoid erroring + if ( + "_allow_empty_param_list" + in inspect.signature(self._optim_constructor).parameters + ): + self.optim: Any = self._optim_constructor( + params, **self._optim_defaults, _allow_empty_param_list=True + ) + else: + logger.warning( + "%s does not support the argument " + "`_allow_empty_param_list`; ZeroRedundancyOptimizer may " + "error due to an empty parameter list", + self._optim_constructor, + ) + self.optim: Any = self._optim_constructor(params, **self._optim_defaults) # type: ignore[no-redef] + + # Log information about the DDP and ZeRO bucketing + if dist.get_debug_level() != dist.DebugLevel.OFF: + local_numel = sum(p.numel() for p in params) + num_assigned_buckets = len( + self._bucket_assignments_per_rank[self.global_rank] + ) + logger.info( + "rank %s with %s parameters " "across %s buckets", + self.global_rank, + local_numel, + num_assigned_buckets, + ) + if self.global_rank == 0: + logger.info( + "%s DDP " "buckets and " "%s bucket " "assignments", + len(self._overlap_info.params_per_bucket), + self._overlap_info.num_bucket_assignments, + ) + else: + # NOTE: Passing `param_groups` into the local optimizer constructor + # bypasses the empty parameter list check + self.optim: Optimizer = self._optim_constructor(param_groups, **self._optim_defaults) # type: ignore[no-redef] + + # TODO: Manually add `self.param_groups` if using a functional + # optimizer; remove this if/when the functional optimizers support + # multiple parameter groups + if self._overlap_with_ddp and not hasattr(self.optim, "param_groups"): + assert hasattr(self.optim, "param_group"), ( + "The functional optimizer should set at least one of the " + "attributes `param_group` or `param_groups`" + ) + self.optim.param_groups = [self.optim.param_group] # type: ignore[attr-defined] + + self._sync_param_groups(self.optim.param_groups, self.param_groups) + + def _init_zero_for_overlap(self) -> None: + r"""Perform a delayed initialization of the local optimizer and the supporting data structures.""" + assert self._overlap_with_ddp, ( + "`_init_zero_for_overlap()` should only be called when " + "`overlap_with_ddp=True`" + ) + self._overlap_info.status = _OverlapStatus.INITIALIZED + self._clear_cache() + self._partition_parameters(self._overlap_info.params_per_rank) + self._build_ddp_param_buckets() + self._init_local_optimizer() + + def _get_assigned_rank(self, bucket_index: int) -> int: + r""" + Return the single rank assigned to a :class:`DistributedDataParallel` gradient bucket. + + Arguments: + bucket_index (int): index of the :class:`DistributedDataParallel` + bucket for which to get the assigned rank. + """ + assert not self._overlap_info.shard_buckets, ( + "The bucket assignment requires global bucket information and " + "will be computed later; there should be no need to use this " + "method" + ) + return bucket_index % self.world_size + + def _check_overlap_initialized(self): + r""" + Check the delayed initialization depending on the value of ``overlap_with_ddp``. + + The delayed initialization has occurred (see + :meth:`_init_zero_for_overlap`) if ``overlap_with_ddp=True``, and + raises a ``RuntimeError`` if not. This should preface methods that + should not be run before that delayed initialization. + + Raises: + RuntimeError: if ``overlap_with_ddp=True`` and + :meth:`_init_zero_for_overlap` has not been called. + """ + if ( + self._overlap_with_ddp + and self._overlap_info.status != _OverlapStatus.INITIALIZED + ): + raise RuntimeError( + "This method should not be called until this " + "ZeroRedundancyOptimizer instance has been fully " + "initialized" + ) + + def _get_optimizer_constructor(self, optimizer_class: Any) -> Any: + r""" + Return the optimizer constructor using validation and transformation depending on ``overlap_with_ddp``. + + Returns: + - ``optimizer_class`` if ``overlap_with_ddp=False`` and + ``optimizer_class`` is not a functional optimizer. + - ``optimizer_class`` if ``overlap_with_ddp=True`` and + ``optimizer_class`` is already a functional optimizer. + - The functional equivalent of ``optimizer_class`` if + ``overlap_with_ddp=True`` and ``optimizer_class`` is not + already a functional optimizer (assuming the equivalent + exists). + + Raises: + ValueError: + + - if ``overlap_with_ddp=True`` but ``optimizer_class`` is + neither a functional optimizer nor translatable to a + functional optimizer. + - if ``overlap_with_ddp=False`` and ``optimizer_class`` is a + functional optimizer. + """ + functional_optims = functional_optim_map.values() + if not self._overlap_with_ddp: + if optimizer_class in functional_optims: + # Using a functional optimizer is only supported when + # `overlap_with_ddp=True` + raise ValueError( + f"Passing in a functional optimizer {optimizer_class} " + "when `overlap_with_ddp=False`" + ) + else: + return optimizer_class + else: + if optimizer_class in functional_optims: + # Already a functional optimizer + return optimizer_class + elif optimizer_class in functional_optim_map: + # Translate the passed-in optimizer class to its functional + # equivalent if `overlap_with_ddp=True` + optim_constructor = functional_optim_map[optimizer_class] + logger.info( + "Using the functional optimizer %s " + "instead of %s since " + "`overlap_with_ddp=True`", + optim_constructor, + optimizer_class, + ) + return optim_constructor + else: + raise ValueError( + "Using `ddp_with_overlap=True` requires using a " + "functional optimizer, but there is no supported functional " + f"optimizer equivalent for {optimizer_class}" + ) diff --git a/mindnlp/core/distributed/optim/zero_redundancy_optimizer.pyi b/mindnlp/core/distributed/optim/zero_redundancy_optimizer.pyi new file mode 100644 index 000000000..704f93db9 --- /dev/null +++ b/mindnlp/core/distributed/optim/zero_redundancy_optimizer.pyi @@ -0,0 +1,84 @@ +# mypy: allow-untyped-defs +import enum +from typing import Any, Callable, overload + +from mindnlp import core +from core.distributed.algorithms.join import Joinable, JoinHook +from core.optim import Optimizer + +class _ZeROJoinHook(JoinHook): + zero: Any = ... + def __init__(self, zero: Any) -> None: ... + def main_hook(self) -> None: ... + +class _DDPBucketAssignment: + bucket_index: int + parameters: list[core.Tensor] + offset: int + device: core.device + tensor: core.Tensor | None + +class _OverlapStatus(enum.IntEnum): + UNINITIALIZED: int = ... + DDP_HAS_REBUILT_BUCKETS: int = ... + INITIALIZED: int = ... + +class _OverlapInfo: + status: Any = ... + params_per_bucket: Any = ... + params_per_rank: Any = ... + offsets: Any = ... + broadcast_handles: Any = ... + bucket_index_to_future: Any = ... + bucket_index_to_bucket: Any = ... + bucket_indices_seen: Any = ... + assigned_ranks_per_bucket: list[set[int]] = ... + total_size: int = ... + shard_buckets: bool = ... + def __init__(self) -> None: ... + def wait_for_broadcasts(self) -> None: ... + def clear_per_iter_info(self) -> None: ... + +class ZeroRedundancyOptimizer(Optimizer, Joinable): + functional_optim_map: Any = ... + initialized: bool = ... + process_group: Any = ... + world_size: int = ... + rank: int = ... + global_rank: int = ... + parameters_as_bucket_view: bool = ... + optim: Any = ... + _device_to_device_index: dict[core.device, int] = ... + _overlap_with_ddp: bool = ... + _overlap_info: _OverlapInfo = ... + _buckets: list[list[core.Tensor]] = ... + _bucket_assignments_per_rank: list[dict[int, _DDPBucketAssignment]] = ... + def __init__( + self, + params: Any, + optimizer_class: type[Optimizer], + process_group: Any | None = ..., + parameters_as_bucket_view: bool = ..., + overlap_with_ddp: bool = ..., + **defaults: Any, + ) -> None: ... + def add_param_group(self, param_group: dict[str, Any]) -> None: ... + def consolidate_state_dict(self, to: int = ...) -> None: ... + @overload + def step(self, closure: None = ..., **kwargs: Any) -> None: ... + @overload + def step(self, closure: Callable[[], float], **kwargs: Any) -> float: ... + def load_state_dict(self, state_dict: dict[str, Any]) -> None: ... + def state_dict(self) -> dict[str, Any]: ... + def _local_step( + self, + gradients: list[core.Tensor | None] | None = None, + closure: Callable[[], float] | None = None, + **kwargs: Any, + ) -> float | None: ... + def _get_assigned_rank(self, bucket_index: int) -> int: ... + def _init_zero_for_overlap(self) -> None: ... + def join_hook(self, **kwargs): ... + @property + def join_device(self) -> core.device: ... + def join_process_group(self) -> Any: ... diff --git a/mindnlp/core/distributed/pipelining/README.md b/mindnlp/core/distributed/pipelining/README.md new file mode 100644 index 000000000..11938bf35 --- /dev/null +++ b/mindnlp/core/distributed/pipelining/README.md @@ -0,0 +1,7 @@ +# Pipeline Parallelism for PyTorch + +`core.distributed.pipelining` is a package for implementing pipeline parallelism on your model. + +Our documentation is available [here](https://pycore.org/docs/main/distributed.pipelining.html). + +![pipeline_diagram_web](https://github.com/pytorch/PiPPy/assets/6676466/c93e2fe7-1cd4-49a2-9fd8-231ec9905e0c) diff --git a/mindnlp/core/distributed/pipelining/_IR.py b/mindnlp/core/distributed/pipelining/_IR.py new file mode 100644 index 000000000..f433a82a3 --- /dev/null +++ b/mindnlp/core/distributed/pipelining/_IR.py @@ -0,0 +1,1242 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import copy +import logging +import operator +from collections import defaultdict +from enum import Enum +from inspect import Parameter, Signature, signature +from types import MethodType +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union + +from mindnlp import core +from mindnlp import core.fx as fx +from core.distributed import ProcessGroup +from core.export import ExportedProgram +from core.export.unflatten import ( + _assign_attr, + _AttrKind, + _sink_params, + InterpreterModule, +) +from core.fx.node import map_aggregate +from core.fx.passes.split_module import split_module + +from ._backward import _null_coalesce_accumulate, stage_backward +from ._unflatten import _outline_submodules +from ._utils import PipeInfo +from .stage import _PipelineStage + + +logger = logging.getLogger(__name__) + +# TODO: +# 1. investigate gradient sync for shared parameters. how does DDP do it? +# 2. Add parameter movement to split_module + + +def _find_loss_from_output_and_spec(output_val, spec_val): + if spec_val is False: + return None + if spec_val is True: + if not isinstance(output_val, fx.Node): + raise RuntimeError( + f"Loss spec must specify a dynamic value but got {output_val}" + ) + return output_val + + if isinstance(spec_val, (tuple, list)): + if not isinstance(output_val, (tuple, list)): + raise RuntimeError( + f"Output value {output_val} must match type of loss specification " + f"{spec_val}" + ) + if len(output_val) != len(spec_val): + raise RuntimeError( + f"Output value {output_val} must match length of loss specification " + f"{spec_val}" + ) + for out, spec in zip(output_val, spec_val): + loss_val = _find_loss_from_output_and_spec(out, spec) + if loss_val is not None: + return loss_val + raise RuntimeError(f"Did not find loss value in specification {spec_val}") + + if isinstance(spec_val, dict): + if not isinstance(output_val, dict): + raise RuntimeError( + f"Output value {output_val} must match type of loss specification " + f"{spec_val}" + ) + if set(output_val.keys()) != set(spec_val.keys()): + raise RuntimeError( + f"Output value {output_val} must match keys of loss specification " + f"{spec_val}" + ) + for k in spec_val: + loss_val = _find_loss_from_output_and_spec(output_val[k], spec_val[k]) + if loss_val is not None: + return loss_val + raise RuntimeError(f"Did not find loss value in specification {spec_val}") + + raise RuntimeError(f"Unsupported type {type(spec_val)} in loss specification") + + +def _find_loss_output(mod: core.nn.Module, g: fx.Graph, output_loss_value_spec): + output_nodes = [n for n in g.nodes if n.op == "output"] + assert len(output_nodes) == 1 + output_node = output_nodes[0] + output_val = output_node.args[0] + generated_spec: Any = None + + if isinstance(mod, TrivialLossWrapper): + # TrivialLossWrapper is pre-defined by PiPPy. + # It has loss as the only output so we can safely assume the first output arg is the loss. + assert len(output_node.args) == 1 + loss_node = output_val + generated_spec = TrivialLossWrapper.loss_spec + elif output_loss_value_spec is None: + # Use default spec, i.e. search for "loss" in output values + if isinstance(output_val, dict) and "loss" in output_val.keys(): + loss_node = output_val["loss"] + generated_spec = {k: k == "loss" for k in output_val} + else: + loss_node = None + generated_spec = None + else: + loss_node = _find_loss_from_output_and_spec(output_val, output_loss_value_spec) + generated_spec = output_loss_value_spec + + return loss_node, output_node, generated_spec + + +def _insert_stage_symbolic_backward( + g: fx.Graph, + loss_node: fx.Node, + output_node: fx.Node, +): + # Collect metadata about tuple output values. TODO: move this to split_module or FX IR + tuples: Dict[fx.Node, Tuple] = {} + for node in reversed(g.nodes): + if node.op == "call_function": + # In the forward pass, only emit placeholder, module calls, and + # getitem calls. If we have a target other than getitem in this + # (forward-only) code, there is a bug. + assert node.target == operator.getitem, ( + "Found non-getitem call in forward pass. " + "Please report a bug to PiPPy" + ) + assert ( + len(node.args) == 2 + ), "Found malformed getitem call. Please report a bug to PiPPy" + indexed_value, node_idx = tuple(node.args) + + # indexed_value is a collection that we are indexing into. It could + # exist in the tuples map if we've processed another `getitem` + # already. + existing_list_size = ( + len(tuples[indexed_value]) if indexed_value in tuples else -1 + ) + new_list_size = max(node_idx + 1, existing_list_size) + + reconstructed_list = [None for _ in range(new_list_size)] + + # Copy over existing elements if present + if indexed_value in tuples: + for i, val in enumerate(tuples[indexed_value]): + reconstructed_list[i] = val + + # Populate value represented by this node + reconstructed_list[node_idx] = node + + tuples[indexed_value] = tuple(reconstructed_list) + + # Keep track of nodes that dominate the loss node. + # We will only emit backward operations for nodes that can contribute + # to the specified loss value. + live_nodes = {loss_node: None} + val_to_grad: Dict[fx.Node, Optional[fx.Node]] = {loss_node: None} + + def assign_or_accumulate_grad(forward_node, grad_value): + if forward_node in val_to_grad and forward_node.op != "placeholder": + grad_value = g.call_function( + _null_coalesce_accumulate, + (val_to_grad[forward_node], grad_value), + ) + val_to_grad[forward_node] = grad_value + + with g.inserting_before(output_node): + for node in reversed(g.nodes): + if node not in live_nodes: + continue + + def add_to_live_nodes(n): + live_nodes.setdefault(n, None) + + fx.node.map_arg(node.args, add_to_live_nodes) + fx.node.map_arg(node.kwargs, add_to_live_nodes) + if node.op == "call_module": + output_grads: Union[Tuple[Optional[fx.Node], ...], Optional[fx.Node]] + if node in tuples: + stage_output = tuples[node] + output_grads = tuple(val_to_grad.get(n, None) for n in tuples[node]) + outputs_with_grads_idxs = [ + i for i, n in enumerate(tuples[node]) if n in live_nodes + ] + else: + stage_output = (node,) + output_grads = val_to_grad[node] + outputs_with_grads_idxs = [0] + + output_grads = ( + (output_grads,) + if not isinstance(output_grads, tuple) + else output_grads + ) + + grad_call = g.call_function( + stage_backward, + kwargs={ + "stage_output": stage_output, + "output_grads": output_grads, + "input_values": list(node.all_input_nodes), + "outputs_with_grads_idxs": outputs_with_grads_idxs, + }, + ) + # Insert backward stage debug info + kwargs_copy = dict(grad_call.kwargs) + grad_call.kwargs = kwargs_copy + + grad_call_proxy = fx.Proxy(grad_call) + grads = grad_call_proxy.node + + input_nodes = list(node.all_input_nodes) + grads_proxy = fx.Proxy(grads) + for i, input_node in enumerate(input_nodes): + assign_or_accumulate_grad(input_node, grads_proxy[i].node) # type: ignore[index] + + return g + + +class PipeSequential(core.nn.Sequential): + @staticmethod + def from_sequential(sequential_instance: core.nn.Sequential): + return PipeSequential(*[copy.copy(m) for m in sequential_instance]) + + def forward(self, input): + for i, module in enumerate(self): + input = module(input) + if i != len(self) - 1: + pipe_split() + return input + + +class LossWrapper(core.nn.Module): + """ + LossWrapper is a convenient abstract class that allows you to wrap up both + your model as well as its loss function and specify the connectivity between + the inputs, model, loss function, and output value. Example:: + + class MyModelWrapper(LossWrapper): + def forward(self, x, targets): + model_out = self.module(x) + loss_value = self.loss_fn(model_out, targets) + return loss_value + + The above example defines a connectivity where we expect the forward/loss/backward + training procedure to take two arguments (x and targets), pass x into the module + to get the output of the feedforward computation, pass the model output and the + targets value into the loss function, and get and return the loss value, which will + be backpropagated by PiPPy. The above class would then be instantiated like:: + + model = ... # instantiate the model + loss_fn = core.nn.MSELoss() # for the sake of demonstration + + wrapper = MyModelWrapper(model, loss_fn) + pipe = Pipe.from_tracing(wrapper, ...) + + """ + + def __init__(self, module, loss_fn): + super().__init__() + self.module = module + self.loss_fn = loss_fn + + def forward(self, *args, **kwargs): + raise NotImplementedError( + "This instance of LossWrapper does not have an overridden" + "forward(). Please implement forward() to specify the arguments, " + "connection between the module and loss, and loss output " + "value." + ) + + +class TrivialLossWrapper(LossWrapper): + def forward(self, x, targets): + model_out = self.module(x) + return self.loss_fn(model_out, targets) + + loss_spec = True + + +# Pipe model representation +# +# Pipe can be thought of as an `nn.Sequential++`. That is to say: it specifies +# a single topological ordering of pipeline "stages" that, when run in series, +# constitutes all of the operations of the program. However, unlike `nn.Sequential`, +# Pipe allows non-local usages of values, so long as those uses still respect +# topological ordering. In particular: +# +# 1. Non-local activations. This type of usage can appear in, for example, skip +# connections. These values will be directly transmitted from the "def" stage +# to all stages that use them skipping intermediate stages. During autograd, +# gradients will be propagated back through this skip connection reverse +# to how activations propagated in the forward pass. +# 2. Non-local parameter/module invocations. This occurs when a parameter is used +# in a stage downstream of where it is resident. These values can be carried +# forward similarly to (1), but in addition one might want to replicate the +# value on multiple stages. Gradients for these shared parameters will be +# accumulated separately on each stage, but there will be an additional +# gradient accumulation before the optimizer step. + + +# Register `_pipe_split()` as an ATen operator. This is required for Export to +# preserve this marker in the graph. +core.library.define("pippy::_pipe_split", "() -> ()") + + +@core.library.impl("pippy::_pipe_split", "BackendSelect") +def _pipe_split(): + return None + + +@core.library.register_fake("pippy::_pipe_split") # type: ignore[no-redef] +def _pipe_split(): # noqa: F811 + return None + + +# Add an alias for convenience +aten_pipe_split_alias = core.ops.pippy._pipe_split.default + +# Ask Export to preserve the `_pipe_split` op. +# See examples in pytorch/torch/fx/node.py +fx.node._side_effectful_functions.add(aten_pipe_split_alias) + + +# User facing API +def pipe_split(): + """ + pipe_split is a special operator that is used to mark the boundary between + stages in a module. It is used to split the module into stages. It is a + no-op if your annotated module is run eagerly. + + Example: + >>> # xdoctest: +SKIP + >>> def forward(self, x): + >>> x = core.mm(x, self.mm_param) + >>> x = core.relu(x) + >>> pipe_split() + >>> x = self.lin(x) + >>> return x + + The above example will be split into two stages. + """ + return core.ops.pippy._pipe_split() + + +class MultiUseParameterConfig(Enum): + TRANSMIT = 1 + REPLICATE = 2 + + +MultiUseParamSpec = Union[MultiUseParameterConfig, Dict[str, MultiUseParameterConfig]] + + +class DetachExecutor(fx.Interpreter): + """ + Special interpreter to run the split_gm in testing that detaches all inputs to + a module invocation. This is needed so that the values at the boundary are + leaf modules in autograd execution. + """ + + def __init__(self, module, garbage_collect_values=True): + garbage_collect_values = False + super().__init__(module, garbage_collect_values) + self.value_remap = {} + + def run(self, *args, initial_env=None): # type: ignore[override] + self.value_remap = {} + return super().run(*args, initial_env=initial_env) + + def call_module(self, target, args, kwargs): + def detach_tensors(a): + if isinstance(a, core.Tensor) and a.requires_grad: + if a not in self.value_remap: + new_val = a.detach().requires_grad_(True) + self.value_remap[a] = new_val + return self.value_remap[a] + else: + return a + + """ + def dont_traverse_size(a): + return type(a) != core.Size + """ + + args = map_aggregate( + args, + detach_tensors, # dont_traverse_size + ) + kwargs = map_aggregate( + kwargs, + detach_tensors, # dont_traverse_size + ) + + return super().call_module(target, args, kwargs) + + def call_function(self, target, args, kwargs): + # HACK to reroute saved input tensors to point to the detach()ed version + if target == stage_backward: + kwargs = dict(kwargs) + kwargs["input_values"] = [ + self.value_remap.get(v, v) for v in kwargs["input_values"] + ] + return super().call_function(target, args, kwargs) + + +class _NodeReference: + def __init__(self, name): + self.name = name + + name: str + + +class _LinearNodeList: + def __init__(self, node_list): + self.serialize_node_list = [] + for node in node_list: + node_args = fx.node.map_arg(node.args, lambda n: _NodeReference(n.name)) # type: ignore[arg-type,return-value] + node_kwargs = fx.node.map_arg(node.kwargs, lambda n: _NodeReference(n.name)) # type: ignore[arg-type,return-value] + serialize_node = fx.Node( + graph=None, # type: ignore[arg-type] + name=node.name, + op=node.op, + target=node.target, + args=node_args, # type: ignore[arg-type] + kwargs=node_kwargs, # type: ignore[arg-type] + return_type=node.type, + ) + serialize_node.meta = copy.copy(node.meta) + self.serialize_node_list.append(serialize_node) + + def to_graph(self): + graph = fx.Graph() + + ref_str_to_node: Dict[str, fx.Node] = {} + + def ref_to_node(arg): + if isinstance(arg, _NodeReference): + return ref_str_to_node[arg.name] + else: + return arg + + for node in self.serialize_node_list: + node_args = map_aggregate(node.args, ref_to_node) + node_kwargs = map_aggregate(node.kwargs, ref_to_node) + deser_node = graph.create_node( + op=node.op, + target=node.target, + args=node_args, # type: ignore[arg-type] + kwargs=node_kwargs, # type: ignore[arg-type] + name=node.name, + type_expr=node.type, + ) + ref_str_to_node[node.name] = deser_node + + return graph + + +def _direct_serialization_deserialize(body, nodes): + """ + Custom `__reduce__` method for serialization. + DO AS I SAY -- NOT AS I DO. This violates the principle that + GraphModules serialize via code export & re-tracing. We allow + for this here because **PIPE STAGES SHOULD NOT BE PERSISTED + TO DISK -- THIS IS ONLY FOR TRANSMISSION VIA RPC**. Persisting + these instances to disk will expose internal implementation + details of `fx.Graph` and related data structures and is + NOT advised. + """ + + class DummyModule(core.nn.Module): + def __init__(self, body): + super().__init__() + self.__dict__.update(body) + + dummy = DummyModule(body) + + return fx.GraphModule(dummy, nodes.to_graph()) + + +def _direct_serialization_reduce(self): + serialization_dict = dict(self.__dict__) + serialization_dict.pop("_graph") + return ( + _direct_serialization_deserialize, + (serialization_dict, _LinearNodeList(self.graph.nodes)), + ) + + +def _modify_graph_op_device( + gm: core.fx.GraphModule, + new_device: core.device, +): + """ + Modify the device argument of all "call_function" nodes in the graph. This + is useful for moving the graph to a different device. In particular for + generator ops, like core.ones. + """ + modified = False + for node in gm.graph.nodes: + if node.op == "call_function": + if "device" in node.kwargs and node.kwargs["device"] != new_device: + logger.debug( + f"Changing device of Node {node.name} from {node.kwargs['device']} to {new_device}" # noqa: G004 + ) + node.update_kwarg("device", new_device) + modified = True + elif node.op == "call_module": + # Recursively modify "device" in submodules + submod = gm.get_submodule(node.target) + if isinstance(submod, core.fx.GraphModule): + _modify_graph_op_device(submod, new_device) + elif isinstance(submod, InterpreterModule): + # If unflattening has been performed, we need to access its graph module by `.graph_module` + _modify_graph_op_device(submod.graph_module, new_device) # type: ignore[arg-type] + else: + logger.warning( + f"Skipping device modification for submodule {node.target} because it is a {type(submod)}" # noqa: G004 + ) + + if modified: + gm.recompile() + + +class Pipe(core.nn.Module): + def __init__( + self, + split_gm: fx.GraphModule, + num_stages: int, + has_loss_and_backward: bool, + loss_spec, + ): + # TODO: is there a way not to hard wire init? + core.nn.Module.__init__(self) + self.split_gm: fx.GraphModule = split_gm + self.executor: DetachExecutor = DetachExecutor(self.split_gm) + self.num_stages: int = num_stages + self.has_loss_and_backward = has_loss_and_backward + self.loss_spec = loss_spec + + for node in split_gm.graph.nodes: + assert ( + node.op in {"call_module", "placeholder", "output"} + or (node.op, node.target) == ("call_function", operator.getitem) + or (node.op, node.target) == ("call_method", "backward") + or (node.op, node.target) == ("call_function", stage_backward) + or (node.op, node.target) + == ("call_function", _null_coalesce_accumulate) + ), node + + # Detect replicated parameters so we know that we have to do an additional allreduce + # before applying the optimizer + # + # Note that this also handles the case where there were multiple calls to a single + # module from different stages, regardless of whether that module invocation + # was handled by the logic above. + + # Map parameter value to a dictionary that maps the user pipeline module + # to the local qualname within that module + params_to_users: Dict[core.nn.Parameter, Dict[str, str]] = {} + + for m_qualname, mod in self.split_gm.named_children(): + for p_qualname, param in mod.named_parameters(): + params_to_users.setdefault(param, {}) + params_to_users[param][m_qualname] = p_qualname + + self.replicated_params: List[Dict[str, str]] = [ + use_mapping + for _, use_mapping in params_to_users.items() + if len(use_mapping) > 1 + ] + + # We must break the aliasing relationship between the replicated parameters for correct + # numerics in reference runs. If we do not do this, the autograd tape in separate stages + # will have a reference to the same tensor value and will erroneously apply gradient + # updates multiple times. Therefore, for each replicated parameter set, we deepcopy the + # values so that we have separate instances. + for param_mapping in self.replicated_params: + for submod_name, param_qualname in param_mapping.items(): + submod = getattr(self.split_gm, submod_name) + atoms = param_qualname.split(".") + for atom in atoms[:-1]: + submod = getattr(submod, atom) + setattr(submod, atoms[-1], copy.deepcopy(getattr(submod, atoms[-1]))) + + def throw(self, *args, **kwargs): + raise RuntimeError( + "To run pipeline locally, invoke the Pipe object directly, not `split_gm`" + ) + + self.split_gm.forward = throw + + # Make submodules use custom direct-serialized GraphModule + i = 0 + while True: + try: + name = f"submod_{i}" + submod = getattr(self.split_gm, name) + submod.__class__.__reduce__ = _direct_serialization_reduce + i += 1 + except AttributeError: + break + + def forward(self, *args, **kwargs): + executor_args = args + if len(kwargs) > 0: + parameters = [] + for node in self.split_gm.graph.nodes: + if node.op == "placeholder": + if node.args and len(node.args) > 0: + parameters.append( + Parameter( + node.target, + Parameter.POSITIONAL_OR_KEYWORD, + default=node.args[0], + ) + ) + else: + parameter_kind = Parameter.POSITIONAL_OR_KEYWORD + param_name = node.target + if node.target.startswith("**"): + parameter_kind = Parameter.VAR_KEYWORD # type: ignore[assignment] + param_name = param_name[2:] + elif node.target.startswith("*"): + parameter_kind = Parameter.VAR_POSITIONAL # type: ignore[assignment] + param_name = param_name[1:] + parameters.append(Parameter(param_name, parameter_kind)) + signature = Signature(parameters) + ba = signature.bind(*args, **kwargs) + ba.apply_defaults() + executor_args = ba.arguments.values() # type: ignore[assignment] + + res = self.executor.run(*executor_args) + + return res + + def get_stage_module(self, stage_idx: int) -> core.nn.Module: + """ + Return a stage module corresponding to `stage_idx` of the `pipe`. + """ + if stage_idx < 0 or stage_idx >= self.num_stages: + raise ValueError(f"Invalid stage index {stage_idx}!") + return getattr(self.split_gm, f"submod_{stage_idx}") + + @staticmethod + def _number_and_count_forward_stages(gm: fx.GraphModule): + num_stages = 0 + found_idxs: Dict[int, None] = {} + for node in gm.graph.nodes: + if node.op == "call_module" and node.target.startswith("submod_"): + node.meta["stage_idx"] = int(node.target[len("submod_") :]) + found_idxs.setdefault(node.meta["stage_idx"]) + num_stages += 1 + + # this assert will fail if a split point is inserted before the first layer, which creates empty first submodule + # Update: the following assert may fail against some torch versions >= + # 2.2.0, as: + # submod_0, submod_1, submod_2, ... + # may be named as + # submod_0, submod_2, submod_4, ... + # TODO: investigate + # assert all(i in found_idxs for i in range(num_stages)) + + return num_stages + + @staticmethod + def _from_traced( + mod: core.nn.Module, + exported_program: ExportedProgram, + multi_use_param_spec: Optional[MultiUseParamSpec] = None, + output_loss_value_spec=None, + split_policy: Optional[ + Callable[[core.fx.GraphModule], core.fx.GraphModule] + ] = None, + ): + """ + Additionally, the ``output_loss_value_spec`` value can be specified to disambiguate + which value in the output of `forward` is the loss value on which PiPPy should apply + backpropagation. For example, if your ``forward`` returns a tuple ``(loss, model_out)``, + you can specify ``output_loss_value_spec=(True, False)``. Or, if your ``forward`` returns + a dict ``{'loss': loss_value, 'model_out': model_out}``, you can specify + ``output_loss_value_spec={'loss': True, 'model_out': False}`` + """ + + traced = exported_program.module() + + if split_policy is not None: + logger.info("Auto-splitting model") + traced = split_policy(traced) # type: ignore[arg-type] + + logger.debug(traced.print_readable(print_output=False)) # type: ignore[operator] + + # Deduplicate `get_attr` nodes that refer to the same parameter . Downstream code for moving + # parameters relies on the invariant that parameter accesses happen once. This is not necessarily + # the case (especially with custom tracers), so fix that up here. + get_attr_nodes: Dict[str, fx.Node] = {} + for node in traced.graph.nodes: # type: ignore[union-attr] + if node.op == "get_attr": + get_attr_nodes.setdefault(node.target, node) + + if get_attr_nodes[node.target] != node: + node.replace_all_uses_with(get_attr_nodes[node.target]) + traced.graph.erase_node(node) # type: ignore[operator, union-attr] + + # avoid looking at next node by keeping track of previous pipe_split + prev_pipe_split_idx = -1 + pipe_split_nodes_to_erase = set() + for i, node in enumerate(traced.graph.nodes): # type: ignore[arg-type, union-attr] + if (node.op, node.target) == ("call_function", pipe_split): + if prev_pipe_split_idx == i - 1: + pipe_split_nodes_to_erase.add(node) + prev_pipe_split_idx = i + + for node in pipe_split_nodes_to_erase: + traced.graph.erase_node(node) # type: ignore[operator, union-attr] + + traced.recompile() # type: ignore[operator] + + part_idx = 0 + + def split_callback(n: fx.Node): + nonlocal part_idx + if (n.op, n.target) == ( + "call_function", + aten_pipe_split_alias, + ): + logger.debug(f"Found pipe_split {part_idx}") # noqa: G004 + part_idx += 1 + return part_idx + + # TODO: what does split do with module invocations? does it move the modules + # into the submodules? + split = split_module(traced, mod, split_callback) # type: ignore[arg-type] + # a (custom) tracer can produce dead code like orphan get_attr nodes + split.graph.eliminate_dead_code() + + # peephole to remove pipe_split + for submodule in split.modules(): + if isinstance(submodule, fx.GraphModule): + for node in submodule.graph.nodes: + if (node.op, node.target) == ( + "call_function", + aten_pipe_split_alias, + ): + submodule.graph.erase_node(node) + submodule.recompile() + + for name, submodule in split.named_children(): + if isinstance(submodule, fx.GraphModule): + new_submod = _outline_submodules(submodule.graph) + # Replace old submod + split.register_module(name, new_submod) + + # TODO: backport this into split_module + def delete_user_reference(node, user): + """ + Delete reference of `node` from `user`'s arg list. + Args: + - node: a `get_attr` node at root. + - user: a submodule node that uses `node`. + """ + assert len(user.kwargs) == 0 + use_idxs = [i for i, arg in enumerate(user.args) if arg == node] + assert len(use_idxs) == 1 + args_copy = list(user.args) + args_copy.pop(use_idxs[0]) + user.args = tuple(args_copy) + logger.debug( + f"Deleted {node} from user {user}, arg index = {use_idxs[0]}" # noqa: G004 + ) + + # A list of param referrals for deferred deletion. + # To be accumulated in `move_param_to_callee`. + to_delete = [] + + def _recursive_getattr_with_parent(mod, fqn): + # Returns getattr call given a nested FQN, and the last parent + atoms = fqn.split(".") + for atom in atoms[:-1]: + if not hasattr(mod, atom): + return None, None + mod = getattr(mod, atom) + if not hasattr(mod, atoms[-1]): + return mod, None + attr = getattr(mod, atoms[-1]) + return mod, attr + + def move_param_to_callee( + root, + callee_name, + param_fqn, + ): + """ + Move a parameter from the root module to a submodule. + Args: + root: The root module. + callee_name: The name of the submodule to move the parameter to. + param_fqn: The fully qualified name of the parameter to move. + """ + # `atoms` is a list of strings representing the path to the + # parameter in the original model + atoms = param_fqn.split(".") + mod_itr, param_val = _recursive_getattr_with_parent(split, param_fqn) + # Check whether the parameter is a buffer or a parameter + is_buffer = atoms[-1] in mod_itr._buffers + + # Check whether the parameter is a tensor + assert isinstance(param_val, core.Tensor), ( + f"Expected '{param_fqn}' to be {core.Tensor} but got {type(param_val)}." + + ( + f" It might happen if module '{param_fqn}' was passed to some 'leaf function'" + f"(see https://pycore.org/docs/stable/fx.html#fx.wrap). Please inspect " + f"usages of '{param_fqn}' in the traced graph." + if isinstance(param_val, core.nn.Module) + else "" + ) + ) + + # Get submodule + callee = root.get_submodule(callee_name) + assert not hasattr( + callee, param_fqn + ), f"Module {callee_name} already has a parameter named {param_fqn}" + + # Assign the parameter to the submodule + if is_buffer: + _assign_attr( + param_val, + callee, + param_fqn, + attr_kind=_AttrKind.BUFFER, + persistent=True, # TODO: handle non-persistent buffer + ) + else: + _assign_attr( + param_val, + callee, + param_fqn, + attr_kind=_AttrKind.PARAMETER, + ) + logger.debug(f"Moved parameter {param_fqn} to {callee_name}") # noqa: G004 + + # Next step is to replace placeholder of submodule with a get_attr. + # Those placeholders are created by `split_module` inside each + # submodule. + # Update: this step is now moved to `_sink_params` because + # `_sink_params` can do it recursively (i.e. for modules inside + # submodule) + + to_delete.append((mod_itr, atoms[-1])) + + # Get the list of all parameters in the root module + attr_nodes = list(filter(lambda n: n.op == "get_attr", split.graph.nodes)) + for node in attr_nodes: + # Check whether the parameter is used in only one submodule + if len(node.users) > 1: + logger.info( + f"Parameter {node.target} used in multiple stages: {node.users}." # noqa: G004 + ) + for user in node.users: + assert user.op == "call_module" + # Move parameter into submodule + move_param_to_callee( + split, + user.target, + node.target, + ) + + # [aliasing] store tensor id -> list of FQNs, built from state dict + # Also assign non-persistent buffers + id_to_fqns: Dict[int, Set[str]] = defaultdict(set) + for fqn, tensor in mod.state_dict(keep_vars=True).items(): + id_to_fqns[id(tensor)].add(fqn) + for fqn, tensor in mod.named_buffers(): + id_to_fqns[id(tensor)].add(fqn) + + # After moving the params to their corresponding hierarchies, we also + # need to move the `get_attr` nodes from the root of the graph to those + # hierarchies. + # [aliasing] use id -> fqn mapping to list out all valid FQNs + inputs_to_state: Dict[str, List[str]] = {} + for attr in attr_nodes: + _, tensor = _recursive_getattr_with_parent(mod, attr.target) + fqns = list(id_to_fqns[id(tensor)]) + if fqns: + inputs_to_state[attr.name] = fqns + elif attr.target in exported_program.constants: # lifted constants + inputs_to_state[attr.name] = [attr.target] + + # [aliasing] for each submodule split, assign attributes on FQNs that may be used. + # We determine this based on whether or not the FQN attribute parent exists. + # i.e. if the last submodule exists, assign the attribute. + added_attributes: Dict[str, List[str]] = defaultdict(list) + for fqn, tensor in mod.state_dict(keep_vars=True).items(): + for name, submod in split.named_children(): + if isinstance(submod, fx.GraphModule): + parent, child = _recursive_getattr_with_parent(submod, fqn) + if ( + parent and child is None + ): # parent exists, attribute doesn't -> assign + added_attributes[name].append(fqn) + setattr(parent, fqn.split(".")[-1], tensor) + + # Deferral deletion: Remove the original attributes (to params) from the + # root GraphModule + for mod_itr, last_atom in to_delete: + try: + delattr(mod_itr, last_atom) + except AttributeError: + # This is expected if the parameter is used in multiple stages + pass + + # This is done by (1) `_sink_params` at each submodule; + for name, submod in split.named_children(): + if isinstance(submod, fx.GraphModule): + _sink_params(submod, inputs_to_state, []) + submod.graph.lint() + submod.recompile() + + # [aliasing] This step is not super necessary, but helps reduce parameter usage/memory. + # After _sink_params() routine has run, clean up unused attributes that we previously added. + # Determine this based on the get_attr nodes - if not used, remove it. + for name, attributes in added_attributes.items(): + submod = getattr(split, name) + unused_attributes = set(attributes) + # track used attributes in the submodule, running DFS on subgraph hierarchy + stack = [("", submod)] # (scope, submodule) + while stack: + scope, _mod = stack.pop() + if isinstance(_mod, (fx.GraphModule, InterpreterModule)): + for node in _mod.graph.nodes: + if node.op == "get_attr": + # get_attr might get access deeper level attribute + fqn = scope + "." + node.target if scope else node.target + unused_attributes.discard(fqn) + for _name, _submod in _mod.named_children(): + stack.append((scope + "." + _name if scope else _name, _submod)) + # delete unused attributes + for attr in unused_attributes: + mod_itr, atoms = submod, attr.split(".") + for atom in atoms[:-1]: + mod_itr = getattr(mod_itr, atom) + delattr(mod_itr, atoms[-1]) + + for node in attr_nodes: + # And (2): remove `get_attr` node from submod's arg list + for user in copy.copy(node.users): + assert user.op == "call_module" + delete_user_reference(node, user) + # And (3): remove the `get_attr` node from the root graph. + split.graph.erase_node(node) + + split.delete_all_unused_submodules() + split.graph.lint() + split.recompile() + + num_stages = Pipe._number_and_count_forward_stages(split) + + has_loss_and_backward = False + generated_loss_spec = output_loss_value_spec + + if output_loss_value_spec is not None: + loss_node, output_node, generated_loss_spec = _find_loss_output( + mod, split.graph, output_loss_value_spec + ) + if loss_node is not None: + _insert_stage_symbolic_backward( + split.graph, + loss_node, + output_node, + ) + split.recompile() + has_loss_and_backward = True + logger.debug("Pipeline is in training mode, backward pass generated") + else: + raise RuntimeError( + f"Did not find any loss value according to {output_loss_value_spec=}" + ) + else: + logger.debug("Pipeline is in inference mode, backward pass not generated") + + logger.debug("Full pipe model:\n" f"{split}") # noqa: G004 + + return Pipe( + split, + num_stages, + has_loss_and_backward, + generated_loss_spec, + ) + + def print_readable(self): + """ + Print the pipe in a human-readable format. + This will print both the root pipe and each stage module. + """ + self.split_gm.print_readable() + + @staticmethod + def _trace_with_export( + mod: core.nn.Module, + example_args: Tuple[Any, ...], + example_kwargs: Optional[Dict[str, Any]] = None, + ) -> ExportedProgram: + logger.info("Tracing model ...") + try: + ep = core.export.export_for_training( + mod, + example_args, + example_kwargs, + ) + except Exception as e: + raise RuntimeError( + "It seems that we cannot capture your model as a full graph. " + "Typical reasons include graph breaks, data/shape-dependent " + "control flow, or missing meta kernels for custom operators. " + "You can use our manual pipeline interfaces, or try to fix the " + "graph breaks, see https://pycore.org/docs/stable/export.html" + ) from e + + return ep + + @staticmethod + def from_tracing( + mod: core.nn.Module, + example_args: Tuple[Any, ...], + example_kwargs: Optional[Dict[str, Any]] = None, + split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None, + ): + # If a param will be used in multiple pipeline stages, we default the strategy to REPLICATE'ing the param across + # stages instead of TRANSMIT'ting it + multi_use_param_spec = MultiUseParameterConfig.REPLICATE + + # Figure out which output is loss from output_chunk_spec + output_loss_value_spec: Any = None + # Deprecated + """ + if output_chunk_spec is not None: + output_loss_value_spec = map_aggregate( + output_chunk_spec, lambda v: isinstance(v, _LossReducer) + ) + """ + + # Trace with export + exported_program = Pipe._trace_with_export( + mod, + example_args, + example_kwargs, + ) + + pipe = Pipe._from_traced( + mod, + exported_program, + multi_use_param_spec, + output_loss_value_spec=output_loss_value_spec, + split_policy=split_policy, + ) + + # Users want the first pipeline stage to accept kwargs if the original + # program does. This is controlled by the `_codegen` field of the graph, + # so we make a copy here. Note: we only want the input spec and not the + # output spec, because the output spec is for the last stage. Maybe a + # TODO? Not sure yet. + split = pipe.split_gm + traced = exported_program.module() + submod0 = next(iter(split.children())) + submod0_sign = signature(submod0.forward) + model_sign = signature(traced.forward) + if len(model_sign.parameters) != len(submod0_sign.parameters): + # We don't change the signature of the first stage if it takes + # different number of args than original model + logger.info( + f"Original model takes {len(model_sign.parameters)} args but the " # noqa: G004 + f"first pipeline stage takes {len(submod0_sign.parameters)}. " + "Please provide args to respective pipeline stages." + ) + else: + # Support kwargs for the first stage + submod0.graph._codegen = copy.deepcopy(traced.graph._codegen) # type: ignore[union-attr] + # `_replace` is actually not "private" or internal. based on this doc: + # To prevent conflicts with field names, the method and attribute names + # start with an underscore + submod0.graph._codegen.pytree_info = ( # type: ignore[union-attr] + submod0.graph._codegen.pytree_info._replace(out_spec=None) # type: ignore[operator, union-attr] + ) + submod0.recompile() + + return pipe + + def __str__(self): + return self.split_gm.__str__() + + def __repr__(self): + return self.split_gm.__repr__() + + def info(self) -> PipeInfo: + """ + Get information about the pipe. + + Returns + ------- + PipeInfo + A dataclass containing information about the pipe. + """ + return PipeInfo( + graph=self.split_gm.graph, + num_stages=self.num_stages, + has_loss_and_backward=self.has_loss_and_backward, + ) + + def build_stage( + self, + stage_index: int, + device: core.device, + group: Optional[ProcessGroup] = None, + ) -> _PipelineStage: + """ + Create a `PipelineStage` given a stage index and distributed group. + The `PipelineStage` can run with `PipelineSchedule`s. + """ + # Find stage module + stage_module = self.get_stage_module(stage_index) + + # Move ops argument to device + # Today PT2 tracer does not treat `x.device` as a symbolic device; + # instead, the device of tracing time got burned into the generated + # code. Here we provide a workaround for users to manually modify the + # "device" kwarg of operations. Such operation may include: + # `core.ones`, `core.zeros`, `core.rand`, etc. + if isinstance(stage_module, core.fx.GraphModule): + _modify_graph_op_device(stage_module, device) + else: + logger.warning( + f"Expected a `core.fx.GraphModule` but got {type(stage_module)}" # noqa: G004 + ) + + # Detach pipe info + # Note: be careful what's included in `pipe_info`. We don't want to keep + # a reference to `Pipe` or `Pipe.split_gm` which stops python from + # recycling them. When python recycles them, other stage modules (which + # are irrelevant to current rank) can be automatically freed. + pipe_info = self.info() + return _PipelineStage(stage_module, stage_index, pipe_info, device, group) + + +class SplitPoint(Enum): + BEGINNING = 1 + END = 2 + + +# For backward compatibility, we kept the PipeSplitWrapper class because `class +# SplitPoint` used to be defined in this class. +class PipeSplitWrapper: + # Create a class alias for BC + SplitPoint = SplitPoint + + +def _split_before_forward(self, *args, **kwargs): + pipe_split() + return self._orig_forward(*args, **kwargs) + + +def _split_after_forward(self, *args, **kwargs): + try: + return self._orig_forward(*args, **kwargs) + finally: + pipe_split() + + +def annotate_split_points(mod: core.nn.Module, spec: Dict[str, SplitPoint]): + # TODO: make this implementation out-of-place? + for qualname, split_type in spec.items(): + atoms = qualname.split(".") + predecessor_module = mod + for i, atom in enumerate(atoms[:-1]): + try: + predecessor_module = getattr(predecessor_module, atom) + except AttributeError as e: + raise AttributeError( + f"Specified target {qualname} referenced " + f'nonexistent module {".".join(atoms[: i + 1])}' + ) from e + + mod_to_wrap = getattr(predecessor_module, atoms[-1]) + mod_to_wrap._orig_forward = mod_to_wrap.forward + if split_type == SplitPoint.BEGINNING: + mod_to_wrap.forward = MethodType(_split_before_forward, mod_to_wrap) + elif split_type == SplitPoint.END: + mod_to_wrap.forward = MethodType(_split_after_forward, mod_to_wrap) + else: + raise ValueError("Unknown split point type.") + + +def pipeline( + module: core.nn.Module, + mb_args: Tuple[Any, ...], + mb_kwargs: Optional[Dict[str, Any]] = None, + split_spec: Optional[Dict[str, SplitPoint]] = None, + split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None, +) -> Pipe: + """ + Split a module based on a specification. + + See `Pipe` for more details. + + Arguments + --------- + module: + The module to be splitted. + mb_args: + Example positional inputs, in micro-batch form. + mb_kwargs: + Example keyword inputs, in micro-batch form. (default: `None`) + split_spec: + A dictionary using submodule names as split marker. (default: `None`) + split_policy: + The policy to use for splitting the module. (default: `None`) + + Returns + ------- + A pipeline representation of class `Pipe`. + """ + if split_spec is not None and split_policy is not None: + raise ValueError( + "Cannot specify both `split_spec` and `split_policy`. Please use only one of them." + ) + + if split_spec is not None: + # Annotate split points in the module based on user spec + annotate_split_points(module, split_spec) + return Pipe.from_tracing( + mod=module, + example_args=mb_args, + example_kwargs=mb_kwargs, + ) + else: + # Use split policy + return Pipe.from_tracing( + mod=module, + example_args=mb_args, + example_kwargs=mb_kwargs, + split_policy=split_policy, + ) diff --git a/mindnlp/core/distributed/pipelining/__init__.py b/mindnlp/core/distributed/pipelining/__init__.py new file mode 100644 index 000000000..95a6c2b6a --- /dev/null +++ b/mindnlp/core/distributed/pipelining/__init__.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +from ._IR import Pipe, pipe_split, pipeline, SplitPoint +from .schedules import ( + _ScheduleForwardOnly, + Schedule1F1B, + ScheduleGPipe, + ScheduleInterleaved1F1B, + ScheduleInterleavedZeroBubble, + ScheduleLoopedBFS, +) +from .stage import build_stage, PipelineStage + + +__all__ = [ + "Pipe", + "pipe_split", + "SplitPoint", + "pipeline", + "PipelineStage", + "build_stage", + "Schedule1F1B", + "ScheduleGPipe", + "ScheduleInterleaved1F1B", + "ScheduleLoopedBFS", + "ScheduleInterleavedZeroBubble", +] diff --git a/mindnlp/core/distributed/pipelining/_backward.py b/mindnlp/core/distributed/pipelining/_backward.py new file mode 100644 index 000000000..5871c567a --- /dev/null +++ b/mindnlp/core/distributed/pipelining/_backward.py @@ -0,0 +1,401 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import collections +import logging +from typing import Any, Deque, Dict, Iterator, List, Optional, Set, Tuple, Union + +from mindnlp import core +from core.autograd.graph import GradientEdge, Node +from core.nn import Parameter + +from ._debug import map_debug_info + + +logger = logging.getLogger(__name__) + + +def _get_grad_fn_or_grad_acc(t: core.Tensor) -> Union[Node, None]: + """ + Get the grad function or grad accumulator for a tensor. + + Accumulate grad nodes are lazily created, so we need to a + dummy view in order to trigger its creation. + """ + if t.requires_grad and t.grad_fn is None: + # if no grad function (leaf tensors) we use view + viewed_t = t.view_as(t) + grad_fn = viewed_t.grad_fn + if grad_fn is not None: + return grad_fn.next_functions[0][0] + else: + raise RuntimeError( + "Attempted to get grad_fn, but got None." + "Is this being created in a no-grad context?" + ) + else: + return t.grad_fn + + +def reverse_closure( + roots: List[Node], target_nodes: Set[Node], reverse_edges_dict +) -> Tuple[Set[Node], Set[Node]]: + """ + This function returns the reverse closure of the given roots, + i.e. the set of nodes that can be reached from the roots by following the + reverse edges of the graph. The target_nodes are the nodes that we want to + include in the closure. + """ + # Recurse until we reach a target node + closure: Set[Node] = set() + visited_target_nodes = set() + q: Deque[Node] = collections.deque() + for node in roots: + if node is not None and node not in closure: + closure.add(node) + q.append(node) + while q: + node = q.popleft() + reverse_edges = reverse_edges_dict[node] + for fn in reverse_edges: + if fn in closure or fn is None: + continue + if fn in target_nodes: + visited_target_nodes.add(fn) + continue + closure.add(fn) + q.append(fn) + return closure, visited_target_nodes + + +def construct_reverse_graph(roots: List[Node]) -> Dict[Node, List[Node]]: + q: Deque[Node] = collections.deque() + root_seen: Set[Node] = set() + reverse_edges_dict: Dict[Node, List[Node]] = collections.defaultdict(list) + for node in roots: + if node is not None and node not in root_seen: + q.append(node) + root_seen.add(node) + while q: + node = q.popleft() + for fn, _ in node.next_functions: + if fn is not None: + if len(reverse_edges_dict[fn]) == 0: + q.append(fn) + reverse_edges_dict[fn].append(node) + return reverse_edges_dict + + +def get_param_groups( + inputs: List[Node], params: List[Node], reverse_edges_dict +) -> List[Dict[str, Any]]: + """ + Given a list of inputs and a list of parameters, return a list of parameter + groups, where each group contains the parameters and the intermediates that + are connected to the parameters. + + The returned list of parameter groups is a list of dictionaries, where each + dictionary contains the following keys: + - "params": a set of parameters + - "intermediates": a set of intermediates + + The returned list of parameter groups is a list of dictionaries, + """ + # reverse graph that starts with inputs, and goes up to the dOutput or the loss, + # but omits weights and any subgraphs connecting weights to this closure + inputs_closure, _ = reverse_closure(inputs, set(), reverse_edges_dict) + param_groups: Dict[Node, Dict[str, Set]] = dict() # keyed on intermediates + for param in params: + closure, intersected = reverse_closure( + [param], inputs_closure, reverse_edges_dict + ) + param_group: Dict[str, Set] = { + "params": {param}, + "intermediates": intersected, + } + for input_node in intersected: + existing = param_groups.get(input_node, None) + if existing is not None: + existing["params"] = existing["params"].union(param_group["params"]) + existing["intermediates"] = existing["intermediates"].union( + param_group["intermediates"] + ) + param_group = existing + else: + param_groups[input_node] = param_group + + # Sanity check: union of all param_groups params should be equal to all params + union_params: Set[Node] = set() + seen_ids: Set[int] = set() + unique_param_groups = [] + for param_group in param_groups.values(): + if id(param_group) not in seen_ids: + seen_ids.add(id(param_group)) + unique_param_groups.append(param_group) + union_params = union_params.union(param_group["params"]) + + # The assert will only be true if the input tensor requires gradients, + # otherwise the autograd graph will miss the first layer of inputs + # assert union_params == set(params) + return unique_param_groups + + +def stage_backward_input( + stage_outputs_or_loss: List[core.Tensor], + output_grads: Optional[List[core.Tensor]], + input_values: List[core.Tensor], + weights: Iterator[Parameter], +) -> Tuple[Tuple[Optional[core.Tensor], ...], List[Dict[str, Any]]]: + """ + Compute the gradients for only the stage inputs with + respect to the stage outputs (if non-last stage) or loss (if last stage) + + After computing input gradients, we save the intermediate nodes in `param_groups` + for later use in stage_backward_weight. We don't need to save any other intermediate nodes + that aren't needed for dW because when we do dW calculation, we start from saved intermediates. + Detaching the stage_outputs_or_loss at the end of this function is important as + it frees up the memory that the autograd graph is anticipating to be used later (but doesn't actually need). + """ + stage_output_grad_fns: List[Node] = list( + filter(None, map(_get_grad_fn_or_grad_acc, stage_outputs_or_loss)) + ) + stage_input_grad_fns: List[Node] = list( + filter(None, map(_get_grad_fn_or_grad_acc, input_values)) + ) + weight_grad_fns: List[Node] = list( + filter(None, map(_get_grad_fn_or_grad_acc, weights)) + ) + + reverse_edges_dict = construct_reverse_graph(stage_output_grad_fns) + param_groups = get_param_groups( + stage_input_grad_fns, weight_grad_fns, reverse_edges_dict + ) + + handles = [] + for param_group in param_groups: + for i, intermediate in enumerate(param_group["intermediates"]): + + def get_hook(param_group, i): + def hook(grad_inputs): + if param_group.get("grads", None) is None: + param_group["grads"] = [None] * len( + param_group["intermediates"] + ) + param_group["grads"][i] = grad_inputs + + return hook + + # These are always "split" nodes that we need to recompute, so + # save their inputs. + handle = intermediate.register_prehook(get_hook(param_group, i)) + handles.append(handle) + + if output_grads is None: + # In case this is the loss and there are no output_grads, then we just use 1s + output_grads = [ + core.ones_like(stage_output) for stage_output in stage_outputs_or_loss + ] + + dinputs = core.autograd.grad( + stage_outputs_or_loss, + inputs=input_values, + grad_outputs=output_grads, + retain_graph=True, + ) + + # update the gradients for inputs + for i, inp in enumerate(input_values): + if inp.grad is None: + inp.grad = dinputs[i] + else: + inp.grad += dinputs[i] + + # stage_outputs_or_loss are not used in backwards after this point, so we can safely remove it from the autograd graph + # this allows autograd to clear up the graph dedicated for this tensor and free up significant memory + for t in stage_outputs_or_loss: + t.detach_() + + # hooks are no longer necessary, clean up for consistency + for handle in handles: + handle.remove() + + return dinputs, param_groups + + +def stage_backward_weight( + weights: Iterator[Parameter], param_groups: List[Dict[str, Any]], retain_graph=False +) -> Tuple[Optional[core.Tensor], ...]: + # map weights to param_group_weights + grad_acc_to_weight = {} + weight_grads: List[Optional[core.Tensor]] = [] + for index, weight in enumerate(weights): + grad_acc = _get_grad_fn_or_grad_acc(weight) + grad_acc_to_weight[grad_acc] = weight, index + weight_grads.append(weight.grad) + + for param_group in param_groups: + # TODO: Handle case where intermediate can have multiple outputs + intermediate_edges = tuple( + GradientEdge(i, 0) for i in param_group["intermediates"] + ) + weights_edges = tuple(GradientEdge(w, 0) for w in param_group["params"]) + + # Break a reference cycle caused inside stage_backward_input->get_hook->hook + # The summarized cycle is: + # `hook` -> cell -> param_group -> intermediates -> `hook` + # becuase we install the hook function onto each of the intermediate autograd nodes. + # We need to keep intermediates alive up until backward_weight, but we can free it now. + del param_group["intermediates"] + + assert all(len(g) == 1 for g in param_group["grads"]) + # [NEW!] Able to pass a GradientEdge to autograd.grad as output + # We do not need to retain_graph because... guarantee no overlap? + # print("trying to execute: ", intermediate_edges, weights_edges) + dweights = core.autograd.grad( + intermediate_edges, + weights_edges, + grad_outputs=sum(param_group["grads"], tuple()), + retain_graph=retain_graph, + ) + # release grad memory early after use + del param_group["grads"] + + for grad_acc, dw in zip(param_group["params"], dweights): + weight, index = grad_acc_to_weight[grad_acc] + if weight.grad is None: + weight.grad = dw + else: + weight.grad += dw + # return grads in the original order weights were provided in + return tuple(weight_grads) + + +def stage_backward( + stage_output, + output_grads, + input_values, + outputs_with_grads_idxs: Optional[List[int]] = None, # deprecated, not used +) -> Tuple[Optional[core.Tensor], ...]: + """ + This is a helper function to: + 1. compute the gradients for the stage inputs, and + 2. accumulate gradients for the stage module's parameters. + + Given the input value(s) and the corresponding gradient for the output + value(s), compute and accumulate gradients for all parameter values (leaves + in the autograd trace) as well as return a list of the gradients for the + input values + """ + if outputs_with_grads_idxs is not None: + # Deprecated, not used in runtime calls, only exists in compiler + stage_output = [stage_output[i] for i in outputs_with_grads_idxs] + output_grads = [output_grads[i] for i in outputs_with_grads_idxs] + + try: + # stage_output may be a composite datatype like dict. Extract all individual + # tensor values here + stage_output_tensors: List[core.Tensor] = [] + output_grad_tensors: List[Optional[core.Tensor]] = [] + + def extract_tensors_with_grads( + output_val, + grad_val, + # Don't delete me- see [Note: ref cycle] + extract_tensors_with_grads, + ): + if isinstance(output_val, core.Tensor): + if not output_val.requires_grad and output_val.grad_fn is None: + return + assert isinstance( + grad_val, (core.Tensor, type(None)) + ), f"Expected Tensor or None gradient but got {type(grad_val)}" + stage_output_tensors.append(output_val) + output_grad_tensors.append(grad_val) + elif isinstance(output_val, (tuple, list)): + if grad_val is None: + return + assert isinstance( + grad_val, (tuple, list) + ), f"grad_value expected to have type {type(output_val)} but got {type(grad_val)}" + assert len(output_val) == len(grad_val) + for ov, gv in zip(output_val, grad_val): + extract_tensors_with_grads( + ov, + gv, + extract_tensors_with_grads, + ) + elif isinstance(output_val, dict): + if grad_val is None: + return + assert isinstance(grad_val, dict) + assert set(output_val.keys()) == set(grad_val.keys()) + for k in output_val.keys(): + extract_tensors_with_grads( + output_val[k], grad_val[k], extract_tensors_with_grads + ) + else: + # Output is a non-tensor type; just ignore it + pass + + # Note: ref cycle + # break a ref cycle that would keep tensors alive until GC runs + # 1. extract_tensors_with_grads refers to a cell that holds refs to any vars defined in stage_backward + # and used in extract_tensors_with_grads + # 2. extract_tensors_with_grads referred to both stage_output_tensors, output_grad_tensors, + # and to itself (extract_tensors_with_grads) since it makes a recursive call + # 3. stage_output_tensors was kept alive by the above refcycle, and it holds activation tensors, which is bad + # fix -> explictly pass in the ref to the fn, so there is no gc cycle anymore + extract_tensors_with_grads( + stage_output, output_grads, extract_tensors_with_grads + ) + + core.autograd.backward( + stage_output_tensors, grad_tensors=output_grad_tensors # type: ignore[arg-type] + ) + + # Extract gradients wrt the input values + grad_inputs: List[Optional[core.Tensor]] = [] + for val in input_values: + if isinstance(val, core.Tensor): + grad_inputs.append(val.grad) + else: + grad_inputs.append(None) + + # Alternative impl: `core.autograd.grad`. + # Note that `core.autograd.grad` will not accumulate gradients into the + # model's parameters. + """ + inputs_with_grad = [] + for val in input_values: + if isinstance(val, core.Tensor) and val.requires_grad: + inputs_with_grad.append(val) + + grad_inputs = core.autograd.grad( + stage_output_tensors, inputs_with_grad, output_grad_tensors, # type: ignore[arg-type] + ) + """ + + except Exception as e: + exc_msg = f""" + Failed to run stage backward: + Stage output: {map_debug_info(stage_output)} + Output gradient: {map_debug_info(output_grads)} + Input: {map_debug_info(input_values)} + """ + raise RuntimeError(exc_msg) from e + + return tuple(grad_inputs) + + +# TODO: handling requires_grad=False dynamically. Can we analyze this during initial +# IR emission? +def _null_coalesce_accumulate(lhs, rhs): + """ + Coalesce two values, even if one of them is null, returning the non-null + value. + """ + if lhs is None: + return rhs + elif rhs is None: + return lhs + else: + return core.add(lhs, rhs) diff --git a/mindnlp/core/distributed/pipelining/_debug.py b/mindnlp/core/distributed/pipelining/_debug.py new file mode 100644 index 000000000..daf6dbaa7 --- /dev/null +++ b/mindnlp/core/distributed/pipelining/_debug.py @@ -0,0 +1,21 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +from mindnlp import core + + +def friendly_debug_info(v): + """ + Helper function to print out debug info in a friendly way. + """ + if isinstance(v, core.Tensor): + return f"Tensor({v.shape}, grad={v.requires_grad}, dtype={v.dtype})" + else: + return str(v) + + +def map_debug_info(a): + """ + Helper function to apply `friendly_debug_info` to items in `a`. + `a` may be a list, tuple, or dict. + """ + return core.fx.node.map_aggregate(a, friendly_debug_info) diff --git a/mindnlp/core/distributed/pipelining/_unflatten.py b/mindnlp/core/distributed/pipelining/_unflatten.py new file mode 100644 index 000000000..a5a3e232a --- /dev/null +++ b/mindnlp/core/distributed/pipelining/_unflatten.py @@ -0,0 +1,30 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +from collections import defaultdict +from typing import Dict, List, Set + +from mindnlp import core +from core.export.unflatten import _ModuleFrame, _SubmoduleEntry + + +def _outline_submodules(orig_graph: core.fx.Graph): + # Create an empty GraphModule to hold the outlined modules + new_module = core.fx.GraphModule(core.nn.Module(), core.fx.Graph()) + seen_nodes: Dict[str, core.fx.Node] = {} + seen_modules: Dict[int, List[_SubmoduleEntry]] = defaultdict(list) + seen_attrs: Dict[str, Set[str]] = defaultdict(set) + _ModuleFrame( + orig_graph, + tuple(orig_graph.nodes), + seen_nodes, + seen_modules, + seen_attrs, + None, + [("", 0)], + "", + {}, + module=new_module, + ).run_outer() + new_module.graph.lint() + new_module.recompile() + return new_module diff --git a/mindnlp/core/distributed/pipelining/_utils.py b/mindnlp/core/distributed/pipelining/_utils.py new file mode 100644 index 000000000..4307cf0e9 --- /dev/null +++ b/mindnlp/core/distributed/pipelining/_utils.py @@ -0,0 +1,99 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import logging +from dataclasses import dataclass +from typing import List, Tuple, Union + +from mindnlp import core +from mindnlp.core import fx + + +logger = logging.getLogger(__name__) + + +def flatten_args_detach(args): + """ + Flatten the args into a list form and detach the tensors from computational graph. + """ + flat_detached_args = [] + + def extract_tensor_args(a): + nonlocal flat_detached_args + if isinstance(a, core.Tensor): + val = a.detach().requires_grad_(a.requires_grad) + flat_detached_args.append(val) + return val + else: + flat_detached_args.append(a) + return a + + new_args = fx.node.map_aggregate( + args, + extract_tensor_args, + ) + + return new_args, flat_detached_args + + +def flatten_args(args): + """ + Flatten the args into a list form. + """ + flat_args = [] + + def extract_tensor_args(a): + nonlocal flat_args + flat_args.append(a) + return a + + fx.node.map_aggregate( + args, + extract_tensor_args, + ) + + return flat_args + + +class PipeliningShapeError(RuntimeError): + """Shape mismatch between configured and runtime values.""" + + +def validate_tensor_metadata(desc, expected, given): + if not expected.shape == given.shape: + raise PipeliningShapeError( + f"{desc} has a shape mismatch: expected {expected.shape} actual {given.shape}" + ) + if not expected.dtype == given.dtype: + raise PipeliningShapeError( + f"{desc} has a dtype mismatch: expected {expected.dtype} actual {given.dtype}" + ) + if not expected.stride() == given.stride(): + raise PipeliningShapeError( + f"{desc} has a stride mismatch: expected {expected.stride()} actual {given.stride()}" + ) + + +def validate_tensors_metadata( + desc, + expected_tensors: Union[List[core.Tensor], Tuple[core.Tensor, ...]], + actual_tensors: Union[List[core.Tensor], Tuple[core.Tensor, ...]], +): + if len(expected_tensors) != len(actual_tensors): + raise PipeliningShapeError( + f"{desc}: Number of values ({len(actual_tensors)}) does not match expected number ({len(expected_tensors)})" + ) + for i in range(len(expected_tensors)): + validate_tensor_metadata( + f"{desc}: value {i}", expected_tensors[i], actual_tensors[i] + ) + + +@dataclass +class PipeInfo: + """ + Captures information for a pipeline (`Pipe` object). + """ + + graph: fx.Graph + num_stages: int + has_loss_and_backward: bool diff --git a/mindnlp/core/distributed/pipelining/microbatch.py b/mindnlp/core/distributed/pipelining/microbatch.py new file mode 100644 index 000000000..ef86d645b --- /dev/null +++ b/mindnlp/core/distributed/pipelining/microbatch.py @@ -0,0 +1,468 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import logging +from typing import Any, Dict, List, Optional, Tuple + +from mindnlp import core +from core.fx.node import map_aggregate +from core.utils._pytree import tree_flatten, tree_unflatten + + +__all__ = [ + "TensorChunkSpec", + "split_args_kwargs_into_chunks", + "merge_chunks", +] + +logger = logging.getLogger(__name__) + +""" +_debug_mask_minibatches specifies to send masked versions of the mini-batch +through instead of micro-batch slices--this can be used for more stable +numerical testing (see [A Note About Correctness Testing]) +""" +_debug_mask_minibatches = False + + +class _CustomReducer: + """ + Custom reducer class that can be used to specify a custom operation that + reduces losses of multiple microbatches into one value. + + Example: + >>> # xdoctest: +SKIP + >>> sum_reducer = _CustomReducer( + >>> core.tensor(0.0), + >>> lambda a, b: a + b + >>> ) + """ + + def __init__(self, init_value, reduce_fn): + self.init_value = init_value + self.reduce_fn = reduce_fn + + +class _LossReducer(_CustomReducer): + pass + + +sum_reducer = _LossReducer(core.tensor(0.0), lambda a, b: a + b) + +# Default chunking dimension is 0. This is used for the case where the user did +# not specify a chunking dimension. +DEFAULT_CHUNK_DIM = 0 + + +class TensorChunkSpec: + """ + Class used to specify chunking of inputs + """ + + def __init__(self, split_dim): + self.split_dim = split_dim + + split_dim: int + + def __repr__(self): + return ( + f"{self.__class__.__module__}.{self.__class__.__name__}({self.split_dim})" + ) + + def __str__(self): + return f"TensorChunkSpec({self.split_dim})" + + @staticmethod + def from_tuple( + chunk_dims: Tuple[int, ...], + ): + """ + A helper for creating a tuple of `TensorChunkSpec` from a tuple of chunk + dimensions (int's). + Example: + >>> # xdoctest: +SKIP + >>> # There are three positional arguments to the model, and + >>> # we are chunking them along dimension 0, 0 and 1, respectively + >>> args_chunk_spec = TensorChunkSpec.from_tuple((0, 0, 1)) + """ + args_chunk_spec = map_aggregate( + chunk_dims, + lambda dim: TensorChunkSpec(dim), # type: ignore[arg-type,return-value] + ) + return args_chunk_spec + + @staticmethod + def from_dict( + chunk_dims: Dict[str, int], + ): + """ + A helper for creating a dictionary of `TensorChunkSpec` from a + dictionary of chunk dimensions (int's). + Example: + >>> # xdoctest: +SKIP + >>> # Chunk dimension 0 for the "id" argument, 1 for the "mask" argument + >>> kwargs_chunk_spec = TensorChunkSpec.from_dict({"id": 0, "mask": 1}) + """ + kwargs_chunk_spec = map_aggregate( + chunk_dims, + lambda dim: TensorChunkSpec(dim), # type: ignore[arg-type,return-value] + ) + return kwargs_chunk_spec + + +# Class used to specify replication of inputs +class _Replicate: + pass + + +def _shard_dict_of_args( + args_dict, + args_chunk_spec, + num_chunks, +): + """ + Given a dictionary of args, and a dictionary of chunking specs, shard the + args according to the chunking specs. + + Args: + args_dict: Dictionary of args + args_chunk_spec: Dictionary of chunking specs + num_chunks: Number of chunks to shard the args into + + Returns: + args_split: List of sharded args + """ + # Stage 1+2: flatten and shard/replicate + + # args_sharded_replicated : [num args, num flat values, num chunks] + args_sharded_replicated = {} + arg_specs = [] + + real_num_chunks = num_chunks + first_tensor = True + + assert len(args_dict) == len( + args_chunk_spec + ), f"args_dict.keys() = {list(args_dict.keys())} args_chunk_spec.keys() = {list(args_chunk_spec.keys())}" + + for arg_key, arg in args_dict.items(): + flat, spec = tree_flatten(arg) + arg_specs.append(spec) + + chunk_spec = args_chunk_spec[arg_key] + assert chunk_spec is not None # Should have been set by caller + chunk_spec_flat, _ = tree_flatten(chunk_spec) + if len(flat) != len(chunk_spec_flat): + raise ValueError( + f"Argument value {arg} did not have the same number of " + f"values as as chunk spec {chunk_spec}" + ) + + sharded_arg_flat = [] + + for v, chunk_v in zip(flat, chunk_spec_flat): + if chunk_v is _Replicate or not isinstance(v, core.Tensor): + sharded_arg_flat.append([v] * real_num_chunks) + elif isinstance(chunk_v, TensorChunkSpec): + # TODO: check type of v. If it's a tensor, use chunk (or debug mask). + # If it's a collection type, split it as you would expect. Otherwise, + # Throw an error + assert isinstance(v, core.Tensor), f"{v} is not a tensor" + + v_split_dim_size = v.size(chunk_v.split_dim) + if v_split_dim_size < real_num_chunks: + if first_tensor: + # We can only adjust number of chunks when we hit this + # issue at the first tensor encountered + logger.warning( + f"Tensor size on chunking dimension is {v_split_dim_size}, " # noqa: G004 + f"downsizing the number of chunks from {num_chunks} to {v_split_dim_size}." + ) + real_num_chunks = v_split_dim_size + else: + raise RuntimeError( + f"Arg {arg_key} on chunking dimension has a size of {v_split_dim_size}, " + f"smaller than the number of chunks {num_chunks}. " + "PiPPy cannot reduce the number of chunks because " + "other arguments have bigger chunk-dimension sizes. " + "Please adjust your num_chunks setting." + ) + + chunk_tensors = core.tensor_split( + v, real_num_chunks, chunk_v.split_dim + ) + + if _debug_mask_minibatches: + expanded_chunks = [] + + split_dim_idx = 0 + for chunk_tensor in chunk_tensors: + new_val = core.zeros_like(v) + upper_idx = split_dim_idx + chunk_tensor.size(chunk_v.split_dim) + + slice_indices = [slice(None, None, None)] * new_val.ndim + slice_indices[chunk_v.split_dim] = slice( + split_dim_idx, upper_idx + ) + new_val[slice_indices] = chunk_tensor + + expanded_chunks.append(new_val) + + split_dim_idx += chunk_tensor.size(chunk_v.split_dim) + + sharded_arg_flat.append(expanded_chunks) + else: + sharded_arg_flat.append(chunk_tensors) # type: ignore[arg-type] + + first_tensor = False + else: + raise TypeError(f"Unrecognized chunk spec: {chunk_v}") + + args_sharded_replicated[arg_key] = sharded_arg_flat + + # chunks_flat : [num chunks, num args, num flat values] + chunks_flat = [] + for chunk_idx in range(real_num_chunks): + chunk_args = {} + for key, arg in args_sharded_replicated.items(): + arg_single_chunk = [v_flat[chunk_idx] for v_flat in arg] + chunk_args[key] = arg_single_chunk + chunks_flat.append(chunk_args) + + # args_split : [num chunks, num args] + args_split = [] + + for chunk in chunks_flat: + per_chunk_args = {} + assert len(arg_specs) == len(chunk) + for (key, arg), arg_spec in zip(chunk.items(), arg_specs): + per_chunk_args[key] = tree_unflatten(arg, arg_spec) + args_split.append(per_chunk_args) + + return args_split + + +def split_args_kwargs_into_chunks( + args: Tuple[Any, ...], + kwargs: Optional[Dict[str, Any]], + chunks: int, + args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None, + kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None, +) -> Tuple[List[Tuple], List[Dict]]: + """ + Given a sequence of args and kwargs, split them into a number of chunks + according to their respective chunking specs. + + Args: + args: Tuple of args + kwargs: Dict of kwargs + chunks: Number of chunks to split the args and kwargs into + args_chunk_spec: chunking specs for args, in same shape as args + kwargs_chunk_spec: chunking specs for kwargs, in same shape as kwargs + + Returns: + args_split: List of sharded args + kwargs_split: List of sharded kwargs + """ + # Given `args` and `kwargs`, we want to yield a set of `chunks` args and kwargs such that + # the constituent Tensor values have been sharded/replicated according to the `args_chunk_spec` + # and `kwargs_chunk_spec` specifications. The steps are as follows: + # + # 1. Use pytree.tree_flatten to flatten each arg and its spec into nto a 1d array of values. + # To use a running example: suppose our inputs look like + # + # args = ([A, [B, C]], D) args_spec = ([None, [None, TensorChunkSpec]], None) + # (kwargs not shown but it's a similar process) + # + # Then for this step we would end up with + # + # args = ([A, B, C], D) args_spec = ([None, None, TensorChunkSpec], None) + # + # 2. Shard or replicate the arguments subject to the policy in the spec. Suppose chunks = 2 + # + # args = ([[A, A], [B, B], [C_1, C_2]], [D, D]) + # + # 3. Rotate the nesting order such that chunks are the outer dimension + # + # args_chunks = [ + # ([A, B, C_1], D), + # ([A, B, C_2], D), + # ] + # + # 4. Unflatten each chunk according to the spec + # + # args_chunks = [ + # ([A, [B, C_1]], D), + # ([A, [B, C_2]], D), + # ] + + # TODO: _debug_mask_minibatches + # Handle the case where kwargs is None + if kwargs is None: + kwargs = {} + + # If user did not provide args_chunk_spec or kwargs_chunk_spec, we extend + # their format and use default chunking along dim 0 + if args_chunk_spec is None: + args_chunk_spec = (TensorChunkSpec(DEFAULT_CHUNK_DIM),) * len(args) + + if kwargs_chunk_spec is None: + kwargs_chunk_spec = dict.fromkeys(kwargs, TensorChunkSpec(DEFAULT_CHUNK_DIM)) + + args_split_dict = _shard_dict_of_args( + dict(enumerate(args)), + dict(enumerate(args_chunk_spec)), + chunks, + ) + real_num_chunks = len(args_split_dict) + + kwargs_split = _shard_dict_of_args( + kwargs, + kwargs_chunk_spec, + real_num_chunks, + ) + + if len(kwargs_split) < real_num_chunks: + # In case kwargs are sharded into less chunks + # e.g. when `args` has no tensor, just values + real_num_chunks = len(kwargs_split) + # Re-shard args + args_split_dict = _shard_dict_of_args( + dict(enumerate(args)), + dict(enumerate(args_chunk_spec)), + real_num_chunks, + ) + + if len(args_split_dict) != len(kwargs_split): + raise RuntimeError( + "args and kwargs are split into different number of chunks: " + f"{len(args_split_dict)}, {len(kwargs_split)}" + ) + + args_split = [ + tuple(chunk_args[i] for i in range(len(chunk_args))) + for chunk_args in args_split_dict + ] + + return args_split, kwargs_split + + +def merge_chunks( + chunks: List[Any], + chunk_spec, +): + """ + Given a list of chunks, merge them into a single value according to + the chunk spec. + + Args: + chunks: list of chunks + chunk_spec: Chunking spec for the chunks + + Returns: + value: Merged value + """ + # This is essentially the inverse of `split_args_kwargs_into_chunks`, so the + # steps are similar to the steps in that function but in reverse. Given the + # input values: + # + # chunks = [ + # ([A, [B, C_1]], D), + # ([A, [B, C_2]], D), + # ] + # args_spec = ([None, [None, TensorChunkSpec]], None) + # + # 1. Flatten the chunks according to the chunk_spec + # + # chunks_flat = [ + # ([A, B, C_1], D), + # ([A, B, C_2], D), + # ] + # + # 2. Rotate the nesting order such that chunks are the inner dimension + # + # value_inner = ([A, B, [C_1, C_2]], D) + # + # 3. Concatenate sharded arguments + # + # value_combined = ([A, B, C], D) + # + # 4. Unflatten the combined args given the spec + # + # value = ([A, [B, C]], D) + + # Preliminary: flatten the chunk spec + if chunk_spec is not None: + spec_flattened, flatten_spec = tree_flatten(chunk_spec) + else: + # If chunk_spec is not provided, we will merge chunks along the default dimension (0), for all output fields + # We obtain the output structure by flattening chunk 0 and generate the chunk_spec + chunk0_flat, flatten_spec = tree_flatten(chunks[0]) + spec_flattened = [TensorChunkSpec(DEFAULT_CHUNK_DIM)] * len(chunk0_flat) + + # Stage 1: flatten chunks + # chunks_flattened : [num chunks, num args] + chunks_flattened = [] + + for chunk in chunks: + chunk_flattened, _ = tree_flatten(chunk) + if len(chunk_flattened) != len(spec_flattened): + raise ValueError(f"Chunk {chunk} did not match chunk spec {chunk_spec}") + + chunks_flattened.append(chunk_flattened) + + # Stage 2 and 3: Rotate nesting order s.t. chunks are inner dimension and + # concatenate sharded operands + # args_flattened : [num args] + args_flattened = [] + for arg_idx, arg in enumerate(spec_flattened): + if isinstance(arg, TensorChunkSpec): + partial_values = [ + chunks_flattened[chunk_idx][arg_idx] + for chunk_idx in range(len(chunks_flattened)) + ] + + if _debug_mask_minibatches: + # Infer size of individual chunks by running `tensor_split` again + overall_shape = partial_values[0].shape + for val in partial_values[1:]: + assert val.shape == overall_shape + meta_chunks = core.tensor_split( + core.empty(*overall_shape, device="meta"), + sections=len(partial_values), + dim=arg.split_dim, + ) + + values_to_cat = [] + chunk_start_idx = 0 + assert len(partial_values) == len(meta_chunks) + for partial_value, meta_chunk in zip(partial_values, meta_chunks): + chunk_end_idx = chunk_start_idx + meta_chunk.size(arg.split_dim) + + slice_indices = [slice(None, None, None)] * partial_value.ndim + slice_indices[arg.split_dim] = slice(chunk_start_idx, chunk_end_idx) + sliced = partial_value[slice_indices] + values_to_cat.append(sliced) + + chunk_start_idx = chunk_end_idx + + else: + values_to_cat = partial_values + + args_flattened.append(core.cat(values_to_cat, dim=arg.split_dim)) + elif isinstance(arg, _CustomReducer): + reduced_val = arg.init_value + + for chunk_idx in range(len(chunks_flattened)): + reduced_val = arg.reduce_fn( + reduced_val, chunks_flattened[chunk_idx][arg_idx] + ) + + args_flattened.append(reduced_val) + else: + value = chunks_flattened[0][arg_idx] + for chunk_idx in range(1, len(chunks_flattened)): + assert chunks_flattened[chunk_idx][arg_idx] == value + args_flattened.append(value) + + # Stage 4: Unflatten combined args + return tree_unflatten(args_flattened, flatten_spec) diff --git a/mindnlp/core/distributed/pipelining/schedules.py b/mindnlp/core/distributed/pipelining/schedules.py new file mode 100644 index 000000000..8fbedd1e9 --- /dev/null +++ b/mindnlp/core/distributed/pipelining/schedules.py @@ -0,0 +1,2354 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates + +import copy +import csv +import itertools +import logging +import re +from abc import ABC, abstractmethod +from collections import Counter, defaultdict +from enum import Enum +from typing import ( + Any, + Callable, + Dict, + List, + NamedTuple, + Optional, + Set, + Tuple, + TYPE_CHECKING, + Union, +) + +from mindnlp import core +from mindnlp import core.distributed as dist +from core.distributed._composable.fsdp.fully_shard import FSDPModule, UnshardHandle +from core.profiler import record_function + +from .microbatch import merge_chunks, split_args_kwargs_into_chunks, TensorChunkSpec +from .stage import _PipelineStageBase + + +if TYPE_CHECKING: + from core.distributed import Work + +__all__ = [ + "get_schedule_class", + "PipelineScheduleSingle", + "PipelineScheduleMulti", + "Schedule1F1B", + "ScheduleGPipe", + "ScheduleInterleaved1F1B", + "ScheduleLoopedBFS", + "ScheduleInterleavedZeroBubble", +] + +logger = logging.getLogger(__name__) + + +class _ComputationType(Enum): + # TODO(whc) rename to _ActType? + FORWARD = 1 + BACKWARD_INPUT = 2 + BACKWARD_WEIGHT = 3 + UNSHARD = 4 + RESHARD = 5 + SEND_F = 6 + RECV_F = 7 + SEND_B = 8 + RECV_B = 9 + FULL_BACKWARD = 10 + + def __str__(self): + str_map = { + _ComputationType.FORWARD: "F", + _ComputationType.BACKWARD_INPUT: "I", + _ComputationType.BACKWARD_WEIGHT: "W", + _ComputationType.UNSHARD: "UNSHARD", + _ComputationType.RESHARD: "RESHARD", + _ComputationType.SEND_F: "SEND_F", + _ComputationType.RECV_F: "RECV_F", + _ComputationType.SEND_B: "SEND_B", + _ComputationType.RECV_B: "RECV_B", + _ComputationType.FULL_BACKWARD: "B", + } + return str_map[self] + + @staticmethod + def from_str(action): + if action == "F": + return _ComputationType.FORWARD + elif action == "I": + return _ComputationType.BACKWARD_INPUT + elif action == "W": + return _ComputationType.BACKWARD_WEIGHT + elif action == "UNSHARD": + return _ComputationType.UNSHARD + elif action == "RESHARD": + return _ComputationType.RESHARD + elif action == "SEND_F": + return _ComputationType.SEND_F + elif action == "RECV_F": + return _ComputationType.RECV_F + elif action == "SEND_B": + return _ComputationType.SEND_B + elif action == "RECV_B": + return _ComputationType.RECV_B + elif action == "B": + return _ComputationType.FULL_BACKWARD + else: + raise RuntimeError(f"Invalid computation type {action}") + + +FORWARD = _ComputationType.FORWARD +BACKWARD_INPUT = _ComputationType.BACKWARD_INPUT +BACKWARD_WEIGHT = _ComputationType.BACKWARD_WEIGHT +UNSHARD = _ComputationType.UNSHARD +RESHARD = _ComputationType.RESHARD +SEND_F = _ComputationType.SEND_F +RECV_F = _ComputationType.RECV_F +SEND_B = _ComputationType.SEND_B +RECV_B = _ComputationType.RECV_B +FULL_BACKWARD = _ComputationType.FULL_BACKWARD + +# Convenience shorthand for compute actions only since they are used in 'simple schedule format' +F = FORWARD +I = BACKWARD_INPUT +W = BACKWARD_WEIGHT +B = FULL_BACKWARD + +# Helper to parse an action string like 1F0 into a tuple of (stage_index, computation_type, microbatch_index) +_action_regex = re.compile( + r"(\d+)(F|I|B|W|UNSHARD|RESHARD|SEND_F|RECV_F|SEND_B|RECV_B)(\d*)" +) + + +class _Action(NamedTuple): + stage_index: int + computation_type: _ComputationType + microbatch_index: Optional[int] = None + + def __repr__(self): + repr = str(self.stage_index) + repr += str(self.computation_type) + if self.microbatch_index is not None: + repr += str(self.microbatch_index) + return repr + + @staticmethod + def from_str(str): + """ + Reverse of __repr__ + + String should be formatted as [stage][action type][(microbatch)] + e.g. `2F0`, `1UNSHARD`, `3SEND_F1` + """ + if match := _action_regex.match(str): + stage_index, computation_type, microbatch_index = match.groups() + return _Action( + int(stage_index), + _ComputationType.from_str(computation_type), + int(microbatch_index) if len(microbatch_index) else None, + ) + elif str == "" or str.isspace(): + return None + raise RuntimeError( + f"Invalid action string: {str}, should be formatted as [stage][action type][(microbatch)] e.g. 2F0" + ) + + +def _format_pipeline_order(pipeline_order: Dict[int, List[Optional[_Action]]]) -> str: + """ + Formats the pipeline order in a timestep (row) x rank (column) grid of actions + and returns the formatted string + """ + + # don't mutate the original + pipeline_order = copy.deepcopy(pipeline_order) + + # Replace None with "" + for rank in pipeline_order: + for i in range(len(pipeline_order[rank])): + if pipeline_order[rank][i] is None: + # TODO make a real 'None action' that prints as empty string and make mypy happy + pipeline_order[rank][i] = "" # type: ignore[call-overload] + + # Calculate the maximum number of steps across all ranks + num_steps = max(len(actions) for actions in pipeline_order.values()) + step_labels = [ + "Step " + str(i).zfill(len(str(num_steps - 1))) for i in range(num_steps) + ] + # Sorting the dictionary by keys and retrieving values in that order + rank_actions = [ + pipeline_order.get(key, [""] * num_steps) for key in sorted(pipeline_order) + ] + # Transpose the list of lists (rows to columns) + transposed_actions = list(itertools.zip_longest(*rank_actions, fillvalue="")) + # Generate column labels for ranks + num_ranks = len(pipeline_order) + rank_labels = ["Rank " + str(i) for i in range(num_ranks)] + # Calculate the maximum length of each column, considering labels + max_lengths = [ + max(len(str(item)) if item is not None else 0 for item in col) + for col in zip(step_labels, *transposed_actions) + ] + # Format the header row with rank labels + header_row = " " * (len(step_labels[0]) + 2) + " ".join( + f"{label:<{max_lengths[i]}}" for i, label in enumerate(rank_labels) + ) + # Format each row with its corresponding label + formatted_rows = [ + f"{label}: " + + " ".join(f"{str(item):<{max_lengths[i]}}" for i, item in enumerate(row)) + for label, row in zip(step_labels, transposed_actions) + ] + # Join the rows into a single string + formatted_table = header_row + "\n" + "\n".join(formatted_rows) + "\n" + return formatted_table + + +class _PipelineSchedule(ABC): + def __init__( + self, + n_microbatches: int, + loss_fn: Optional[Callable[..., core.Tensor]] = None, + args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None, + kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None, + output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, + ): + # From arguments + self._n_microbatches = n_microbatches + self._loss_fn = loss_fn + # Chunking specification for positional inputs. (default: `None`) + self._args_chunk_spec = args_chunk_spec + # Chunking specification for keyword inputs. (default: `None`) + self._kwargs_chunk_spec = kwargs_chunk_spec + self._output_merge_spec = output_merge_spec + """ + # args_chunk_spec and kwargs_chunk_spec specify how to chunk inputs. + # They are used to convert batch to microbatches in `step(x)`. See + # `TensorChunkSpec` for helper methods for creating them. + """ + + # Derived + self._has_backward = self._loss_fn is not None + + # Holds the losses for each microbatch. + self._internal_losses: List[core.Tensor] = [] + logger.info("Using %s", self.__class__.__name__) + + def _maybe_compute_loss(self, stage, output, target_mbs, mb_index): + if stage.is_last and self._has_backward: + loss = self._compute_loss(output, target_mbs[mb_index]) # type: ignore[index] + self._internal_losses.append(loss) + + def _maybe_get_loss(self, stage, mb_index): + valid_index = 0 <= mb_index < len(self._internal_losses) + if stage.is_last and self._has_backward and valid_index: + return self._internal_losses[mb_index] + elif len(self._internal_losses) != 0 and not valid_index: + raise RuntimeError( + f"Loss for microbatch {mb_index} is not available. " + f"Available losses for microbatches: {self._internal_losses}" + ) + else: + return None + + def _update_losses(self, stages, losses): + """ + Update the losses to those in the internal state + """ + # if stages not a list turn into a list + if not isinstance(stages, list): + stages = [stages] + contains_last_stage = any(stage.is_last for stage in stages) + + # Return losses if there is a container passed in + if contains_last_stage and losses is not None: + if len(self._internal_losses) != self._n_microbatches: + raise RuntimeError( + f"Expecting {self._n_microbatches} losses but got {len(self._internal_losses)}" + ) + + # Clean external container first + losses.clear() + # Copy internal losses to external container + losses.extend(self._internal_losses) + + self._internal_losses.clear() + + @abstractmethod + def _step_microbatches( + self, + arg_mbs: Optional[List] = None, + kwarg_mbs: Optional[List] = None, + target_mbs: Optional[List] = None, + losses: Optional[List] = None, + ): + """ + Run one iteration of the pipeline schedule with list of microbatches. + Will go through all the microbatches according to the schedule + implementation. + + Args: + microbatches: list of microbatch args. + """ + raise NotImplementedError + + @abstractmethod + def step(self, *args, target=None, losses: Optional[List] = None, **kwargs): + """ + Run one iteration of the pipeline schedule with *whole-batch* input. + Will chunk the input into microbatches automatically, and go through the + microbatches according to the schedule implementation. + + args: positional arguments to the model (as in non-pipeline case). + kwargs: keyword arguments to the model (as in non-pipeline case). + target: target for the loss function. + losses: a list to store the losses for each microbatch. + """ + raise NotImplementedError + + def _check_inputs( + self, + arg_mbs: Optional[List] = None, + kwarg_mbs: Optional[List] = None, + target_mbs: Optional[List] = None, + losses: Optional[List] = None, + ): + """ + Pre-process/check inputs + """ + + def check_type_and_len(mbs, name: str): + if not isinstance(mbs, list): + raise TypeError(f"{name} must be a list but got a {type(mbs)}") + if len(mbs) != self._n_microbatches: + raise ValueError( + f"Expecting {self._n_microbatches} {name} but got {len(mbs)}" + ) + + if arg_mbs is not None: + check_type_and_len(arg_mbs, "arg_mbs") + else: + arg_mbs = [()] * self._n_microbatches + + if kwarg_mbs is not None: + check_type_and_len(kwarg_mbs, "kwarg_mbs") + else: + kwarg_mbs = [{}] * self._n_microbatches + + if target_mbs is not None: + check_type_and_len(target_mbs, "target_mbs") + + if losses is not None: + if not isinstance(losses, list): + raise TypeError(f"losses must be a list but got a {type(losses)}") + + return arg_mbs, kwarg_mbs + + def _compute_loss(self, output, target): + return self._loss_fn(output, target) # type: ignore[misc] + + def _split_inputs( + self, + args: Tuple[Any, ...], + kwargs: Optional[Dict[str, Any]] = None, + ): + """ + Splits a full-batch input into chunks (i.e. microbatches) and returns + the chunks + """ + if args or kwargs: + args_split, kwargs_split = split_args_kwargs_into_chunks( + args, + kwargs, + self._n_microbatches, + self._args_chunk_spec, + self._kwargs_chunk_spec, + ) + return args_split, kwargs_split + else: + # Empty inputs (e.g. when called on middle stages) + # Return a list of empty tuples/dicts with matching length as chunks + return [()] * self._n_microbatches, [{}] * self._n_microbatches + + def _merge_outputs(self, output_chunks: List[Any]) -> Any: + """ + Merge output chunks back to a batch state. + If output_merge_spec is None, the utility will merge output chunks by dimension 0 (batch dim). + """ + return merge_chunks( + output_chunks, + self._output_merge_spec, + ) + + +def _batch_p2p(p2p_ops: List[dist.P2POp], desc: Optional[str] = None): + """ + Simple wrapper over batch_isend_irecv from core.distributed, which just adds a descriptive logger on top. + """ + if len(p2p_ops) == 0: + return None + desc_str = f"{desc}, " if desc else "" + logger.debug("batch_p2p %s%s", desc_str, p2p_ops) + return dist.batch_isend_irecv(p2p_ops).pop() + + +def _sorted_batch_p2p( + p2p_ops: List[dist.P2POp], desc: Optional[str] = None +) -> Dict[int, dist.Work]: + """ + Sorts the list of P2P ops by the peer rank, and then calls + batch_isend_irecv. Return a dictionary of works by peer rank. This function + helps us avoid hangs in case of skip connections. + """ + # Arrange p2p_ops by peer rank: + # int is the peer rank; + # List is the list of ops towards the peer + ops_by_peer: Dict[int, List[dist.P2POp]] = defaultdict(list) + work_by_peer: Dict[int, dist.Work] = {} + if len(p2p_ops) == 0: + return work_by_peer + + # Classify the ops by peer rank + for op in p2p_ops: + ops_by_peer[op.peer].append(op) + + # Call batch_isend_irecv per peer, in sorted order of the peers (to avoid hangs) + for peer, ops in sorted(ops_by_peer.items()): + work_by_peer[peer] = _batch_p2p(ops, desc=desc) + + return work_by_peer + + +class PipelineScheduleSingle(_PipelineSchedule): + """ + Base class for single-stage schedules. + Implements the `step` method. + Derived classes should implement `_step_microbatches`. + """ + + def __init__( + self, + stage: _PipelineStageBase, + n_microbatches: int, + loss_fn: Optional[Callable] = None, + args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None, + kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None, + output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, + ): + # Init parent + super().__init__( + n_microbatches=n_microbatches, + loss_fn=loss_fn, + args_chunk_spec=args_chunk_spec, + kwargs_chunk_spec=kwargs_chunk_spec, + output_merge_spec=output_merge_spec, + ) + # Self attributes + self._stage = stage + self._num_stages = stage.num_stages + # Set the same has_backward flag for stage object + self._stage.has_backward = self._has_backward + self._stage_initialized = False + + def _initialize_stage(self, args, kwargs): + self._stage._prepare_forward_infra(self._n_microbatches, args, kwargs) + if self._has_backward: + self._stage._prepare_backward_infra(self._n_microbatches) + self._stage_initialized = True + + def step(self, *args, target=None, losses: Optional[List] = None, **kwargs): + """ + Run one iteration of the pipeline schedule with *whole-batch* input. + Will chunk the input into microbatches automatically, and go through the + microbatches according to the schedule implementation. + + args: positional arguments to the model (as in non-pipeline case). + kwargs: keyword arguments to the model (as in non-pipeline case). + target: target for the loss function. + losses: a list to store the losses for each microbatch. + """ + + # Clean per iteration + self._stage.clear_runtime_states() + + # Split inputs into microbatches + args_split, kwargs_split = self._split_inputs(args, kwargs) + + # Split target into microbatches + if target is not None: + targets_split = list(core.tensor_split(target, self._n_microbatches)) + else: + targets_split = None + + # Run microbatches + self._step_microbatches(args_split, kwargs_split, targets_split, losses) + + # Return merged results per original format + if self._stage.is_last: + return self._merge_outputs(self._stage.output_chunks) + else: + return None + + +class _ScheduleForwardOnly(PipelineScheduleSingle): + """ + The forward-only schedule. + Will go through all the microbatches and perform only the forward pass + """ + + def _step_microbatches( + self, + arg_mbs: Optional[List] = None, + kwarg_mbs: Optional[List] = None, + target_mbs: Optional[List] = None, + losses: Optional[List] = None, + ): + """ + Run one iteration of the pipeline schedule + """ + if target_mbs is not None or losses is not None: + raise RuntimeError( + "Forward-only schedule does not support loss computation" + ) + + arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) + if not self._stage_initialized: + self._initialize_stage(arg_mbs[0], kwarg_mbs[0]) + + # Delay send waits + fwd_sends_to_wait: List[dist.Work] = [] + + # Run microbatches + for i in range(self._n_microbatches): + with record_function(f"Forward {i}"): + ops = self._stage.get_fwd_recv_ops(i) + works = _sorted_batch_p2p(ops, desc="fwd_recv") + for work in works.values(): + work.wait() + + self._stage.forward_one_chunk(i, arg_mbs[i], kwarg_mbs[i]) # type: ignore[index] + + ops = self._stage.get_fwd_send_ops(i) + works = _sorted_batch_p2p(ops, desc="fwd_send") + fwd_sends_to_wait.extend(works.values()) + + logger.debug("[%s] Forwarded microbatch %s", self._stage.stage_index, i) + + # Wait for all forward sends to finish + # This should not have performance impact because by the time the first + # backward arrives all the forward sends should have been finished. + for work in fwd_sends_to_wait: + work.wait() + + +class ScheduleGPipe(PipelineScheduleSingle): + """ + The GPipe schedule. + Will go through all the microbatches in a fill-drain manner. + """ + + def _step_microbatches( + self, + arg_mbs: Optional[List] = None, + kwarg_mbs: Optional[List] = None, + target_mbs: Optional[List] = None, + losses: Optional[List] = None, + ): + """ + Run one iteration of the pipeline schedule with list of microbatches. + Will go through all the microbatches according to the GPipe schedule. + + Args: + microbatches: list of microbatch args. + """ + arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) + + if not self._stage_initialized: + self._initialize_stage(arg_mbs[0], kwarg_mbs[0]) + + # Delay send waits + fwd_sends_to_wait: List[dist.Work] = [] + + # Run microbatches + for i in range(self._n_microbatches): + with record_function(f"Forward {i}"): + ops = self._stage.get_fwd_recv_ops(i) + works = _sorted_batch_p2p(ops, desc="fwd_recv") + for work in works.values(): + work.wait() + + output = self._stage.forward_one_chunk(i, arg_mbs[i], kwarg_mbs[i]) # type: ignore[index] + + ops = self._stage.get_fwd_send_ops(i) + works = _sorted_batch_p2p(ops, desc="fwd_send") + fwd_sends_to_wait.extend(works.values()) + + logger.debug("[%s] Forwarded microbatch %s", self._stage.stage_index, i) + + self._maybe_compute_loss(self._stage, output, target_mbs, i) + + # Wait for all forward sends to finish + # This should not have performance impact because by the time the first + # backward arrives all the forward sends should have been finished. + for work in fwd_sends_to_wait: + work.wait() + + # No loss function, no need to run backward + if not self._has_backward: + return + + # Run backward + # Delay send waits + bwd_sends_to_wait: List[dist.Work] = [] + for i in range(self._n_microbatches): + with record_function(f"Backward {i}"): + ops = self._stage.get_bwd_recv_ops(i) + works = _sorted_batch_p2p(ops, desc="bwd_recv") + for work in works.values(): + work.wait() + + loss = self._maybe_get_loss(self._stage, i) + self._stage.backward_one_chunk( + i, loss=loss, last_backward=i == self._n_microbatches - 1 + ) + + ops = self._stage.get_bwd_send_ops(i) + works = _sorted_batch_p2p(ops, desc="bwd_send") + bwd_sends_to_wait.extend(works.values()) + + logger.debug("[%s] Backwarded microbatch %s", self._stage.stage_index, i) + + # Return losses if there is a container passed in + self._update_losses(self._stage, losses) + + # Wait for all backward sends to finish + for work in bwd_sends_to_wait: + work.wait() + + +class Schedule1F1B(PipelineScheduleSingle): + """ + The 1F1B schedule. + Will perform one forward and one backward on the microbatches in steady state. + """ + + def _step_microbatches( + self, + arg_mbs: Optional[List] = None, + kwarg_mbs: Optional[List] = None, + target_mbs: Optional[List] = None, + losses: Optional[List] = None, + ): + """ + Run one iteration of the pipeline schedule with list of microbatches. + Will go through all the microbatches according to the 1F1B schedule. + + Args: + microbatches: list of microbatch args. + """ + arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) + + if not self._stage_initialized: + self._initialize_stage(arg_mbs[0], kwarg_mbs[0]) + + # Last stage has 1 warmup, second-to-last 2 warmups, ... + # first stage `num_stages` warmups + warmup_chunks = min( + self._n_microbatches, + self._num_stages - self._stage.stage_index, + ) + + # Chunk counters + fwd_mb_index = 0 + bwd_mb_index = 0 + + # Warmup phase + send_work = None + fwd_sends = [] + for _ in range(warmup_chunks): + # Receive activations + fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index) + if recv_work := _batch_p2p(fwd_recvs, desc="fwd_recv"): + recv_work.wait() + + # Compute + output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index] + + # Clear previous chunk's forward sends (hopefully they have well + # finished, otherwise, we are heavily communication bound, in which + # case it doesn't create a lot of benefit to compute next chunk + # eagerly either) + if send_work: + send_work.wait() + + # Send activations + fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index) + if fwd_mb_index != warmup_chunks - 1: + # Safe to fire + send_work = _batch_p2p(fwd_sends, desc="fwd_send") + # otherwise: + # The last foward send is left for fuse with first 1B in 1B1F below + + # Compute loss + self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index) + fwd_mb_index += 1 + + # Now we should have send ops left over, to be fused with first 1B of 1B1F phase below. + + # 1B1F phase + while True: # Don't worry, we have a break inside + # We actually do 1B first as the `1B1F` name indicates, so prepare its recv ops + bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index) + + # Now, we need to fire the fwd_sends and bwd_recvs together + if fuse_work := _batch_p2p(fwd_sends + bwd_recvs, desc="fwd_send_bwd_recv"): + fuse_work.wait() + + # Backward one chunk + loss = self._maybe_get_loss(self._stage, bwd_mb_index) + self._stage.backward_one_chunk( + bwd_mb_index, + loss=loss, + last_backward=bwd_mb_index == self._n_microbatches - 1, + ) + + # Get the bwd send ops, but don't fire, to be fused with the 1F below + bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index) + bwd_mb_index += 1 + + if fwd_mb_index == self._n_microbatches: + # We are done with 1B1F, so break with some left-over bwd_sends + break + + # We prepare 1F of the `1B1F` + fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index) + + # Fuse it with bwd_sends above + if fuse_work := _batch_p2p(bwd_sends + fwd_recvs, desc="bwd_send_fwd_recv"): + fuse_work.wait() + + # Now do the fwd + output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index] + + # Compute loss + self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index) + + # Get the fwd send ops, but don't fire, leave it for the next iter (wrap-around) + fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index) + fwd_mb_index += 1 + + # Remember we still have some bwd_sends left over after the break? Now it is time to fire it + send_work = _batch_p2p(bwd_sends, desc="bwd_send") + + # Cooldown + while bwd_mb_index < self._n_microbatches: + # prepare bwd recv ops + bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index) + if recv_work := _batch_p2p(bwd_recvs, desc="bwd_recv"): + recv_work.wait() + + # Backward one chunk + loss = self._maybe_get_loss(self._stage, bwd_mb_index) + self._stage.backward_one_chunk( + bwd_mb_index, + loss=loss, + last_backward=bwd_mb_index == self._n_microbatches - 1, + ) + + # Clear previous chunk's backward sends (hopefully they have well finished) + if send_work: + send_work.wait() + + # Get the bwd send ops, fire it + bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index) + send_work = _batch_p2p(bwd_sends, desc="bwd_send") + bwd_mb_index += 1 + + # Wait for the last backward send to finish + if send_work: + send_work.wait() + + # Return losses if there is a container passed in + self._update_losses(self._stage, losses) + + +def _add_unshard_reshard( + compute_actions: List[Optional[_Action]], + max_active_stages: int = 3, +) -> List[_Action]: + """Given a basic schedule involving only compute actions (F,B,W), add UNSHARD/RESHARD actions for FSDP. + + UNSHARD refers to fetching the full contents of an FSDP-sharded layer, requiring an all-gather operation. + RESHARD does the opposite, releasing memory (but doing no commmunication) + + We abandon the "timestep lock" during lowering + + max_active_stages controls how many prefetches we allow. It should be measured in mb and tuneable but in practice + 3 stages is probably the thing we want? + (to account for having one f and one b active, and something else prefetching?) + """ + + def next_stage_indices( + count: int, next_actions: List[Optional[_Action]] + ) -> List[int]: + """Remove duplicates (same stage, different microbatch), find next 'count' stages that will do compute.""" + seen: Set[int] = set() + ret: List[int] = [] + + for a in next_actions: + if a is not None and a.stage_index not in seen: + seen.add(a.stage_index) + ret.append(a.stage_index) + if len(ret) == count: + break + return ret + + active_stages: Set[int] = set() + fsdp_aware_actions: List[_Action] = [] + + def _unshard(stage_index: int): + active_stages.add(stage_index) + fsdp_aware_actions.append(_Action(stage_index, UNSHARD, None)) + + def _reshard(stage_index: int): + active_stages.remove(stage_index) + fsdp_aware_actions.append(_Action(stage_index, RESHARD, None)) + + for i, action in enumerate(compute_actions): + if action is None: + continue + + # We prefetch the next N stages we'll see, dropping existing stages to make room + next_n = next_stage_indices(max_active_stages, compute_actions[i:]) + # Fetch needs to be ordered correctly, so don't use a set + fetch = list(filter(lambda s: s not in active_stages, next_n)) + # Unclear what the best policy is for eviction, but we can maintain order so we do + evict = list(filter(lambda s: s not in next_n, active_stages)) + + # logger.debug( + # "_add_unshard_reshard Step %d active: %s fetch %s, evict %s", + # i, + # active_stages, + # fetch, + # evict, + # ) + + for stage in evict: + _reshard(stage) + for stage in fetch: + _unshard(stage) + fsdp_aware_actions.append(action) + + return fsdp_aware_actions + + +def _merge_bw( + compute_actions: List[Optional[_Action]], +) -> List[_Action]: + """Given a basic schedule involving only compute actions (F,I,W), merge adjacent I and W ops into B ops. + (note: I = BACKWARD_INPUT, W = BACKWARD_WEIGHT, B = FULL_BACKWARD) + + B refers to running the whole backward (not separating grad_input and grad_weight), which can be more efficient + in some cases. + """ + merged_actions = [] + while compute_actions: + action = compute_actions.pop(0) + if action is None: + continue + + while len(compute_actions) and (next_action := compute_actions[0]) is None: + # remove any None actions between 'action' and 'next_action' + compute_actions.pop(0) + + if ( + action.computation_type == BACKWARD_INPUT + and next_action is not None + and next_action.computation_type == BACKWARD_WEIGHT + and action.stage_index == next_action.stage_index + and action.microbatch_index == next_action.microbatch_index + ): + merged_actions.append( + _Action(action.stage_index, FULL_BACKWARD, action.microbatch_index) + ) + compute_actions.pop(0) + else: + merged_actions.append(action) + return merged_actions + + +def _add_send_recv( + compute_actions: Dict[int, List[_Action]], + stage_to_rank: Callable[[int], int], + num_stages: int, +) -> Dict[int, List[_Action]]: + comm_actions: Dict[int, List[_Action]] = {rank: [] for rank in compute_actions} + prev_actions: Dict[int, Set[_Action]] = {rank: set() for rank in compute_actions} + + def _has_comms(action: _Action) -> bool: + if action.computation_type == F: + return action.stage_index != num_stages - 1 and stage_to_rank( + action.stage_index + 1 + ) != stage_to_rank(action.stage_index) + elif action.computation_type in (BACKWARD_INPUT, FULL_BACKWARD): + return action.stage_index != 0 and stage_to_rank( + action.stage_index - 1 + ) != stage_to_rank(action.stage_index) + return False + + def _get_comms(action: _Action) -> Tuple[_Action, _Action]: + assert _has_comms(action), f"{action} is not a valid comm action" + stage_idx = action.stage_index + ctype = action.computation_type + mb_idx = action.microbatch_index + send = _Action(stage_idx, SEND_F if ctype == F else SEND_B, mb_idx) + recv_stage_idx = stage_idx + 1 if ctype == F else stage_idx - 1 + recv = _Action(recv_stage_idx, RECV_F if ctype == F else RECV_B, mb_idx) + return send, recv + + def _ready_to_schedule( + action: Optional[_Action], prev_actions: Set[_Action] + ) -> bool: + """We don't put our own recv ops in the schedule, we let a sender on another rank put our recv ops in place. + This helps ensure a sane (non-hanging) ordering of sends and recvs. + But it also means we might not be able to schedule our next compute action yet. + """ + if action is None: + return True + elif action.computation_type == F and not action.stage_index == 0: + if ( + _Action(action.stage_index, RECV_F, action.microbatch_index) + in prev_actions + ): + return True + elif ( + _Action(action.stage_index - 1, F, action.microbatch_index) + in prev_actions + ): + return True + return False + elif ( + action.computation_type in (BACKWARD_INPUT, FULL_BACKWARD) + and not action.stage_index == num_stages - 1 + ): + if ( + _Action(action.stage_index, RECV_B, action.microbatch_index) + in prev_actions + ): + return True + elif ( + _Action(action.stage_index + 1, BACKWARD_INPUT, action.microbatch_index) + in prev_actions + ): + return True + elif ( + _Action(action.stage_index + 1, FULL_BACKWARD, action.microbatch_index) + in prev_actions + ): + return True + return False + else: + return True + + while compute_actions: + progress = False + # go in order of ranks even if dict keys aren't ordered + for rank in sorted(compute_actions): + assert ( + len(compute_actions[rank]) > 0 + ), f"{rank=}, {len(compute_actions[rank])=}" + action = compute_actions[rank][0] + + if not _ready_to_schedule(action, prev_actions[rank]): + continue + + if action is not None: + comm_actions[rank].append(action) + prev_actions[rank].add(action) + if _has_comms(action): + send, recv = _get_comms(action) + # TODO we can avoid send/recv if the 2 stages are on the same rank. + # should we avoid that in the runtime or here? + comm_actions[rank].append(send) + prev_actions[rank].add(send) + comm_actions[stage_to_rank(recv.stage_index)].append(recv) + prev_actions[stage_to_rank(recv.stage_index)].add(recv) + + compute_actions[rank].pop(0) + if len(compute_actions[rank]) == 0: + del compute_actions[rank] + progress = True + assert progress, "Malformed compute schedule, can't schedule sends/recvs" + return comm_actions + + +class PipelineScheduleMulti(_PipelineSchedule): + """ + Base class for multi-stage schedules. + Implements the `step` method. + """ + + def __init__( + self, + stages: List[_PipelineStageBase], + n_microbatches: int, + loss_fn: Optional[Callable] = None, + args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None, + kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None, + output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, + stage_index_to_group_rank: Optional[Dict[int, int]] = None, + use_full_backward: Optional[bool] = None, + ): + # Init parent + super().__init__( + n_microbatches=n_microbatches, + loss_fn=loss_fn, + args_chunk_spec=args_chunk_spec, + kwargs_chunk_spec=kwargs_chunk_spec, + output_merge_spec=output_merge_spec, + ) + # Self attributes + self._stages = stages + self._num_stages = stages[0].num_stages + self.pp_group_size = stages[0].group_size + self.rank = stages[0].group_rank + # Set the pipeline stage states + if stage_index_to_group_rank is not None: + for stage in self._stages: + stage.stage_index_to_group_rank = stage_index_to_group_rank + self.stage_index_to_group_rank = stages[0].stage_index_to_group_rank + + # Set the same has_backward flag for stage object + for stage in self._stages: + stage.has_backward = self._has_backward + self._stages_initialized = False + + # avoid putting a reference to 'self' inside the lambda, it creates a ref cycle + has_loss: bool = self._loss_fn is not None + self._should_compute_loss = lambda stage: stage.is_last and has_loss + + # This will be set during init of derived schedules + self.pipeline_order: Dict[int, List[Optional[_Action]]] = {} + + if use_full_backward is not None: + logger.warning( + "Deprecation warning: 'use_full_backward' is no longer supported. " + "Simply stop passing it, and everything should still work fine." + ) + + def _initialize_stages(self, args: Tuple[Any, ...], kwargs): + # may be 'none' value (if this stage sends its output shapes to the next stage via P2P) + # or real value (if this stage and next stage are on the same device) + next_stage_args: Tuple[Any, ...] = tuple() + for stage in self._stages: + if stage.is_first: + next_stage_args = stage._prepare_forward_infra( + self._n_microbatches, args, kwargs + ) + else: + next_stage_args = stage._prepare_forward_infra( + self._n_microbatches, next_stage_args, kwargs + ) + + if self._has_backward: + stage._prepare_backward_infra(self._n_microbatches) + self._stages_initialized = True + + def _dump_csv(self, filename): + """Dump a CSV representation of the schedule into a file with the provided filename.""" + with open(filename, "w", newline="") as csvfile: + writer = csv.writer(csvfile) + for rank in self.pipeline_order: + writer.writerow(self.pipeline_order[rank]) + + def _validate_schedule(self): + # TODO(whc) this should be merged with the logic in test_schedule.py#L453-L554 + def _validate_rank_actions( + actions: Dict[int, List[_Action | None]], + num_stages: int, + num_microbatches: int, + ): + # We will count all the actions per stage and ensure they happen in a valid order + # (e.g. F before (B, I) before W for a given microbatch) + stage_actions: Dict[int, Dict[_ComputationType, Set]] = { + stage_id: { + F: set(), + B: set(), + W: set(), + } + for stage_id in range(num_stages) + } + for rank in actions: + for action in actions[rank]: + if action is None: + continue + assert isinstance( + action, _Action + ), f"Got an invalid action: {action}, expected instance of _Action" + s_id = action.stage_index + ctype = action.computation_type + mb_id = action.microbatch_index + if ctype == F: + stage_actions[s_id][F].add(mb_id) + elif ctype == B: + assert ( + mb_id in stage_actions[s_id][F] + ), f"Running Full Backward for stage {s_id}, microbatch {mb_id} without first running Forward" + stage_actions[s_id][B].add(mb_id) + elif ctype == I: + assert ( + mb_id in stage_actions[s_id][F] + ), f"Running Backward Input for stage {s_id}, microbatch {mb_id} without first running Forward" + # TODO(whc) do we need to track I separately from B or should we just merge them for simplicity + stage_actions[s_id][B].add(mb_id) + elif ctype == W: + assert ( + mb_id in stage_actions[s_id][B] + ), f"Running Backward Weight for stage {s_id}, microbatch {mb_id} without first running Backward" + stage_actions[s_id][W].add(mb_id) + + for s_id in stage_actions: + for ctype in (F, B, W): + stage_mb = len(stage_actions[s_id][ctype]) + assert ( + stage_mb == num_microbatches + ), f"Got {stage_mb} {ctype} microbatches for stage {s_id}, expected {num_microbatches}" + + assert ( + len(self.pipeline_order) == self.pp_group_size + ), f"Schedule has incorrect number of ranks - expected {self.pp_group_size}, actual {len(self.pipeline_order)}" + for rank in range(self.pp_group_size): + assert ( + rank in self.pipeline_order + ), f"Schedule is missing actions for rank {rank}" + _validate_rank_actions( + self.pipeline_order, + self._num_stages, + self._n_microbatches, + ) + + def _load_csv(self, filename, format="compute_only"): + """Load a CSV representation of the schedule from a file with the provided filename. + This API will most likely get renamed/refactored so is marked as internal for now. + + format must be "compute_only" for PipelineScheduleMulti + """ + assert format == "compute_only" + with open(filename, newline="") as csvfile: + reader = csv.reader(csvfile) + for rank, row in enumerate(reader): + self.pipeline_order[rank] = [_Action.from_str(s) for s in row] + self._validate_schedule() + + def step(self, *args, target=None, losses: Optional[List] = None, **kwargs): + """ + Run one iteration of the pipeline schedule with *whole-batch* input. + Will chunk the input into microbatches automatically, and go through the + microbatches according to the schedule implementation. + + args: positional arguments to the model (as in non-pipeline case). + kwargs: keyword arguments to the model (as in non-pipeline case). + target: target for the loss function. + losses: a list to store the losses for each microbatch. + """ + # Clean per iteration + for stage in self._stages: + stage.clear_runtime_states() + + # Split inputs into microbatches + args_split, kwargs_split = self._split_inputs(args, kwargs) + + # Split target into microbatches + if target is not None: + targets_split = list(core.tensor_split(target, self._n_microbatches)) + else: + targets_split = None + + # Run microbatches + self._step_microbatches(args_split, kwargs_split, targets_split, losses) + + # Return merged results per original format + for stage in self._stages: + if stage.is_last: + return self._merge_outputs(stage.output_chunks) + # Does not contain the last stage + return None + + def _step_microbatches( + self, + arg_mbs: Optional[List] = None, + kwarg_mbs: Optional[List] = None, + target_mbs: Optional[List] = None, + losses: Optional[List] = None, + ): + """ + Operate on the microbatches for looped schedules (multiple stages on each rank). + + TODO: Does not use sorted_batch_isend_irecv(). As a result, this schedule does + not support models with skip connections. + """ + arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) + + if not self._stages_initialized: + self._initialize_stages(arg_mbs[0], kwarg_mbs[0]) + + # Based on the plan in Step 1 created in __init__: + # 2. Perform communication based on the pipeline_order + stage_index_to_stage: Dict[int, _PipelineStageBase] = { + stage.stage_index: stage for stage in self._stages + } + + # determine prev_rank and next_rank based on which ranks are next to + # the stages in the pipeline_order + all_prev_ranks: Set[int] = set() + all_next_ranks: Set[int] = set() + for stage_index in stage_index_to_stage.keys(): + # TODO: assumption that stages only communicate from distances of +1/-1 (no skip connections) + if stage_index > 0: + all_prev_ranks.add(self.stage_index_to_group_rank[stage_index - 1]) + if stage_index < self._num_stages - 1: + all_next_ranks.add(self.stage_index_to_group_rank[stage_index + 1]) + # count either full_backward or backward_weight together, to determine when to sync DP grads + backward_counter: Counter[int] = Counter() + for time_step, action in enumerate(self.pipeline_order[self.rank]): + try: + ops: List[dist.P2POp] = [] + if action is not None: + computation_type = action.computation_type + mb_index = action.microbatch_index + stage_index = action.stage_index + assert ( + mb_index is not None + ), "All currently supported action types require valid microbatch_index" + if computation_type == _ComputationType.FORWARD: + # perform forward computation + stage = stage_index_to_stage[stage_index] + output = stage.forward_one_chunk( + mb_index, arg_mbs[mb_index], kwarg_mbs[mb_index] + ) + self._maybe_compute_loss(stage, output, target_mbs, mb_index) + ops.extend(stage.get_fwd_send_ops(mb_index)) + elif computation_type == _ComputationType.FULL_BACKWARD: + # perform backward computation + stage = stage_index_to_stage[stage_index] + loss = self._maybe_get_loss(stage, mb_index) + backward_counter[stage_index] += 1 + stage.backward_one_chunk( + mb_index, + loss=loss, + full_backward=True, + last_backward=backward_counter[stage_index] + == self._n_microbatches, + ) + ops.extend(stage.get_bwd_send_ops(mb_index)) + elif computation_type == _ComputationType.BACKWARD_INPUT: + # perform backward computation + stage = stage_index_to_stage[stage_index] + loss = self._maybe_get_loss(stage, mb_index) + stage.backward_one_chunk( + mb_index, + loss=loss, + full_backward=False, + last_backward=False, + ) + ops.extend(stage.get_bwd_send_ops(mb_index)) + elif computation_type == _ComputationType.BACKWARD_WEIGHT: + # perform weight update + stage = stage_index_to_stage[stage_index] + backward_counter[stage_index] += 1 + stage.backward_weight_one_chunk( + mb_index, + last_backward=backward_counter[stage_index] + == self._n_microbatches, + ) + else: + raise ValueError(f"Unknown computation type {computation_type}") + + # Look at the neighboring ranks for this current timestep and determine whether + # this current rank needs to do any recv communication + for prev_rank in all_prev_ranks: + prev_rank_ops = self.pipeline_order[prev_rank] + prev_rank_action = None + if time_step < len(prev_rank_ops): + prev_rank_action = prev_rank_ops[time_step] + if prev_rank_action is not None: + computation_type = prev_rank_action.computation_type + mb_index = prev_rank_action.microbatch_index + stage_index = prev_rank_action.stage_index + assert ( + mb_index is not None + ), "All currently supported action types require valid microbatch_index" + # Only handle sends for the forward from a previous rank + if computation_type == _ComputationType.FORWARD: + # If not the last stage, then receive fwd activations + if stage_index + 1 in stage_index_to_stage: + # TODO: We are assuming that stage will always receive from stage-1 + # however that is not necessarily true of get_fwd_recv_ops + stage = stage_index_to_stage[stage_index + 1] + ops.extend(stage.get_fwd_recv_ops(mb_index)) + elif computation_type in ( + FULL_BACKWARD, + BACKWARD_INPUT, + BACKWARD_WEIGHT, + ): + # Previous rank doing backward has no influence for the current rank forward recv + pass + else: + raise ValueError( + f"Unknown computation type {computation_type}" + ) + for next_rank in all_next_ranks: + next_rank_ops = self.pipeline_order[next_rank] + next_rank_action = None + if time_step < len(next_rank_ops): + next_rank_action = next_rank_ops[time_step] + if next_rank_action is not None: + computation_type = next_rank_action.computation_type + mb_index = next_rank_action.microbatch_index + stage_index = next_rank_action.stage_index + assert ( + mb_index is not None + ), "All currently supported action types require valid microbatch_index" + # Only handle receives for the backwards from a next rank + if computation_type in (FORWARD, BACKWARD_WEIGHT): + # Next rank doing forward or weight update has no influence for the current rank backward recv + pass + elif computation_type in (BACKWARD_INPUT, FULL_BACKWARD): + # If not the first stage, then receive bwd gradients + if stage_index - 1 in stage_index_to_stage: + # TODO: We are assuming that stage will always receive from stage+1 + # however that is not necessarily true of get_bwd_recv_ops + stage = stage_index_to_stage[stage_index - 1] + ops.extend(stage.get_bwd_recv_ops(mb_index)) + else: + raise ValueError( + f"Unknown computation type {computation_type}" + ) + + # do the communication + if ops: + _batch_p2p(ops).wait() + except Exception as e: + logger.error( + "[Rank %s] pipeline schedule %s caught the following exception \ + at time_step %s when running action %s", + self.rank, + self.__class__.__name__, + time_step, + action, + ) + logger.error("%s", _format_pipeline_order(self.pipeline_order)) + raise e + # Return losses if there is a container passed in + self._update_losses(self._stages, losses) + + +class _PipelineScheduleRuntime(PipelineScheduleMulti): + """ + Provides a simple runtime that requires a 'schedule IR' including specified communication operations. + + Can be instantiated directly by creating _PipelineScheduleRuntime and calling load_csv, or can be + subclassed and the subclass can be responsible for creating a schedule IR. + """ + + def _load_actions( + self, + actions: Dict[int, List[Optional[_Action]]], + format: str = "compute_only", + ): + """ + Given an in-memory representation for a simple compute-only schedule, lower it to a complex schedule including + communication actions. Stores the schedule in self, and must be called before running step_mo() + """ + assert ( + self.stage_index_to_group_rank is not None + ), "stage_index_to_group_rank is required for PipelineScheduleRuntime" + self.pipeline_order_with_comms: Dict[int, List[_Action]] = {} + if format == "compute_comms": + for rank in actions: + self.pipeline_order_with_comms[rank] = [] + for action in actions[rank]: + assert action is not None + self.pipeline_order_with_comms[rank].append(action) + # TODO what level of validation should we offer for compute+comms schedule? + elif format == "compute_only": + # Perform schedule lowering + for rank in actions: + self.pipeline_order_with_comms[rank] = _add_unshard_reshard( + actions[rank] + ) + + self.pipeline_order_with_comms = _add_send_recv( + self.pipeline_order_with_comms, + stage_to_rank=lambda s: self.stage_index_to_group_rank[s], + num_stages=self._num_stages, + ) + else: + raise NotImplementedError(f"{format=} is not implemented") + + def _load_csv(self, filename: str, format: str = "compute_only"): + """Loads a csv in simple format and then lowers it to include comunication actions + + format must be either "compute_only" or "compute_comms". If compute_only, the lowering passes + will automatically be run to generate a compute_comms schedule. + """ + if format == "compute_only": + # this will populate self.pipeline_order + super()._load_csv(filename) + # this will populate self.pipeline_order_with_comms + self._load_actions(self.pipeline_order) + elif format == "compute_comms": + actions = {} + with open(filename, newline="") as csvfile: + reader = csv.reader(csvfile) + for rank, row in enumerate(reader): + actions[rank] = [_Action.from_str(s) for s in row] + self._load_actions(actions, format=format) + else: + raise NotImplementedError(f"{format=} is not implemented") + + def _dump_csv(self, filename: str): + """Dump a CSV representation of the compute + comms schedule into a file with the provided filename.""" + # TODO should there be an option to dump the compute_only schedule from PipelineScheduleRuntime? It's possible + # that it does not exist if it was created from a compute_comms schedule. + assert ( + self.pipeline_order_with_comms is not None + ), "Must initialize compute_comms schedule before dump_csv" + with open(filename, "w", newline="") as csvfile: + writer = csv.writer(csvfile) + for rank in self.pipeline_order_with_comms: + writer.writerow(self.pipeline_order_with_comms[rank]) + + def _simulate(self): + return _simulate_comms_compute( + self.pipeline_order_with_comms, + lambda s: self.stage_index_to_group_rank[s], + self._num_stages, + ) + + def _step_microbatches( + self, + arg_mbs: Optional[List] = None, + kwarg_mbs: Optional[List] = None, + target_mbs: Optional[List] = None, + losses: Optional[List] = None, + ): + """ + Operate on the microbatches for looped schedules (multiple stages on each rank). + + TODO: Does not use sorted_batch_isend_irecv(). As a result, this schedule does + not support models with skip connections. + """ + arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) + if not self._stages_initialized: + self._initialize_stages(arg_mbs[0], kwarg_mbs[0]) + + # Based on the plan in Step 1 created in __init__: + # 2. Perform communication based on the pipeline_order + stage_index_to_stage: Dict[int, _PipelineStageBase] = { + stage.stage_index: stage for stage in self._stages + } + + assert ( + self.pipeline_order_with_comms is not None + ), "Must call _load_actions() before calling _step_microbatches()" + + # recv ops indexed by (stage_idx, mb_idx) need to be waited on before use + bwd_recv_ops: Dict[Tuple[int, int], Work] = {} + fwd_recv_ops: Dict[Tuple[int, int], Work] = {} + + # send ops should be waited on before step() exists, mainly for hygeine + send_ops: List[Work] = [] + + # we track which stages are 'active' when used with FSDP, and wait on unshard ops before computing on stages + unshard_ops: Dict[int, UnshardHandle] = {} + unsharded_stages = set() + + def _assert_unsharded(stage_idx: int): + """If an unshard is active for `stage_idx`, wait() it and mark `stage_idx` unshared.""" + if stage_idx in unshard_ops: + unshard_ops[stage_idx].wait() + del unshard_ops[stage_idx] + unsharded_stages.add(stage_idx) + assert ( + stage_idx in unsharded_stages + ), f"Attempted to compute on sharded {stage_idx=}" + + # count either full_backward or backward_weight together, to determine when to sync DP grads + backward_counter: Counter[int] = Counter() + for time_step, action in enumerate(self.pipeline_order_with_comms[self.rank]): + try: + comp_type = action.computation_type + mb_index: int = ( + action.microbatch_index + if action.microbatch_index is not None + else -1 + ) + assert mb_index >= 0 or comp_type in ( + UNSHARD, + RESHARD, + ), f"{action=} missing mb_index" + stage_idx = action.stage_index + stage = stage_index_to_stage[stage_idx] + stage_uses_fsdp = isinstance(stage.submod, FSDPModule) + # see [Note: V-schedule special case] + is_next_stage_on_this_rank = stage_idx + 1 in stage_index_to_stage + is_prev_stage_on_this_rank = stage_idx - 1 in stage_index_to_stage + + logger.debug( + "_PipelineScheduleRuntime running time_step %d, action %s", + time_step, + action, + ) + + # TODO(whc) it's not actually safe to use _batch_p2p here in the uncommon case the model has skip-connections, + # since we do not want to batch up ops between more than a pair of ranks. _sorted_batch_p2p would be + # safe to use instead. + # However, I was wondering if I should avoid calling batched operators at all in the case that there is + # only one operator per batch. I could iterate through the 'fwd_send_ops' one by one and run them. + if comp_type == SEND_F: + send_ops.append(_batch_p2p(stage.get_fwd_send_ops(mb_index))) + elif comp_type == SEND_B: + send_ops.append(_batch_p2p(stage.get_bwd_send_ops(mb_index))) + elif comp_type == RECV_F: + assert ( + stage_idx, + mb_index, + ) not in fwd_recv_ops, "Recv twice for {stage_idx=} {mb_index=} without executing forward" + fwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p( + stage.get_fwd_recv_ops(mb_index) + ) + elif comp_type == RECV_B: + assert ( + stage_idx, + mb_index, + ) not in bwd_recv_ops, "Recv twice for {stage_idx=} {mb_index=} without executing backward" + bwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p( + stage.get_bwd_recv_ops(mb_index) + ) + elif comp_type == UNSHARD: + if stage_uses_fsdp: + assert ( + stage_idx not in unsharded_stages + and stage_idx not in unshard_ops + ), f"Unsharding the same {stage_idx=} twice" + unshard_ops[stage_idx] = stage.submod.unshard(async_op=True) # type: ignore[operator] + elif comp_type == RESHARD: + if stage_uses_fsdp: + assert ( + stage_idx in unsharded_stages + ), f"Resharding {stage_idx=} without unsharding" + assert ( + stage_idx not in unshard_ops + ), f"Resharding {stage_idx=} before finishing unshard" + stage.submod.reshard() # type: ignore[operator] + elif comp_type == FORWARD: + if stage_uses_fsdp: + _assert_unsharded(stage_idx) + + if ( + not stage.is_first + # no recv op expected for V-schedule special case (see [Note: V-schedule special case]) + and not is_prev_stage_on_this_rank + ): + assert ( + stage_idx, + mb_index, + ) in fwd_recv_ops, f"Computing {action=} before receiving input" + fwd_recv_ops.pop((stage_idx, mb_index)).wait() + + output = stage.forward_one_chunk( + mb_index, arg_mbs[mb_index], kwarg_mbs[mb_index] + ) + self._maybe_compute_loss(stage, output, target_mbs, mb_index) + + # SEND/RECV op are avoided for special case with 2 adjacent stages on same rank + # see [Note: V-schedule special case] + if is_next_stage_on_this_rank: + stage_index_to_stage[stage_idx + 1].set_local_fwd_input( + output, mb_index + ) + + elif comp_type == FULL_BACKWARD: + if stage_uses_fsdp: + _assert_unsharded(stage_idx) + + if ( + not stage.is_last + # no recv op expected for V-schedule special case (see [Note: V-schedule special case]) + and not is_next_stage_on_this_rank + ): + assert ( + stage_idx, + mb_index, + ) in bwd_recv_ops, ( + f"Attempted to run compute {action=} before receiving input" + ) + bwd_recv_ops.pop((stage_idx, mb_index)).wait() + loss = self._maybe_get_loss(stage, mb_index) + backward_counter[stage_idx] += 1 + stage.backward_one_chunk( + mb_index, + loss=loss, + full_backward=True, + last_backward=backward_counter[stage_idx] + == self._n_microbatches, + ) + # SEND/RECV op are avoided for special case with 2 adjacent stages on same rank + # see [Note: V-schedule special case] + if is_prev_stage_on_this_rank: + stage_index_to_stage[stage_idx - 1].set_local_bwd_input( + stage.get_local_bwd_output(mb_index), mb_index + ) + elif comp_type == BACKWARD_INPUT: + if stage_uses_fsdp: + _assert_unsharded(stage_idx) + + if not stage.is_last: + assert ( + stage_idx, + mb_index, + ) in bwd_recv_ops, ( + f"Attempted to run compute {action=} before receiving input" + ) + bwd_recv_ops.pop((stage_idx, mb_index)).wait() + loss = self._maybe_get_loss(stage, mb_index) + stage.backward_one_chunk( + mb_index, + loss=loss, + full_backward=False, + last_backward=False, + ) + # SEND/RECV op are avoided for special case with 2 adjacent stages on same rank + # see [Note: V-schedule special case] + if is_prev_stage_on_this_rank: + stage_index_to_stage[stage_idx - 1].set_local_bwd_input( + stage.get_local_bwd_output(mb_index), mb_index + ) + elif comp_type == BACKWARD_WEIGHT: + if stage_uses_fsdp: + _assert_unsharded(stage_idx) + backward_counter[stage_idx] += 1 + stage.backward_weight_one_chunk( + mb_index, + last_backward=backward_counter[stage_idx] + == self._n_microbatches, + ) + else: + raise ValueError(f"{action=} is unknown or unsupported") + except Exception as e: + logger.error( + "_PipelineScheduleRuntime caught exception at step %s when running action %s. Full Schedule:", + time_step, + action, + ) + # TODO(whc) what is the best practice for printing a multiline log? + # logger will split it into multiple log lines, but this makes it hard to read (too wide) + print(_format_pipeline_order(self.pipeline_order_with_comms)) # type: ignore[arg-type] + raise e + + # Mostly these operations should have finished long ago, but there isn't an obvious time when to wait for them + while len(send_ops): + send_ops.pop().wait() + + assert len(unshard_ops) == 0, "Unused unshard operations" + + # Return losses if there is a container passed in + self._update_losses(self._stages, losses) + + +class ScheduleLoopedBFS(PipelineScheduleMulti): + """ + Breadth-First Pipeline Parallelism. + See https://arxiv.org/abs/2211.05953 for details. + Simliar to Interleaved 1F1B, Looped BFS supports multiple stages per rank. + What is different is that when microbatches are ready for multiple local + stages, Loops BFS will prioritizes the earlier stage, running all available + microbatches at once. + """ + + def __init__( + self, + stages: List[_PipelineStageBase], + n_microbatches: int, + loss_fn: Optional[Callable] = None, + output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, + ): + super().__init__( + stages=stages, + n_microbatches=n_microbatches, + loss_fn=loss_fn, + output_merge_spec=output_merge_spec, + ) + + # 1. Create the pipeline_order (all ranks do this calculation) + # This will be used to keep track of the current state of the entire pipeline + # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...] + self.pipeline_order: Dict[int, List[Optional[_Action]]] = {} + # ======================================================================== + for rank in range(self.pp_group_size): + rank_ops = self._calculate_single_rank_operations(rank) + self.pipeline_order[rank] = rank_ops + + def _calculate_single_rank_operations(self, rank): + n_local_stages = len(self._stages) + stage_indices = range( + rank, self.pp_group_size * n_local_stages, self.pp_group_size + ) + + # Store the list of operations used for that rank + # Pre-padding, rank starts with no-ops based on the warmup. + rank_ops: List[Optional[_Action]] = [None for _ in range(rank)] + + for stage_index in stage_indices: + rank_ops.extend( + _Action(stage_index, _ComputationType.FORWARD, mb_index) + for mb_index in range(self._n_microbatches) + ) + + # wait for the first backward to trickle up + # which is 2 for every hop away + post_warmup_ops = 2 * (self.pp_group_size - 1 - rank) + rank_ops.extend([None] * post_warmup_ops) + + for stage_index in reversed(stage_indices): + rank_ops.extend( + _Action(stage_index, _ComputationType.FULL_BACKWARD, mb_index) + for mb_index in reversed(range(self._n_microbatches)) + ) + return rank_ops + + +def _get_1f1b_rank_ops( + n_local_stages, + pp_group_size, + warmup_ops, + fwd_bwd_ops, + cooldown_ops, + rank, + forward_stage_index, + backward_stage_index, + num_1f1b_microbatches=0, + enable_zero_bubble=False, +): + # All stages start with handling microbatch 0 + fwd_stage_mb_index: Dict[int, int] = defaultdict(int) + bwd_stage_mb_index: Dict[int, int] = defaultdict(int) + weight_stage_mb_index: Dict[int, int] = defaultdict(int) + + # Store the list of operations used for that rank + # Pre-padding, rank starts with no-ops based on the warmup. + rank_ops: List[Optional[_Action]] = [None for _ in range(rank)] + # These are used to calculate the number of slots to fill with no-ops, to account for the delay in warmup + # when we want to wait for the backward to trickle back up and start 1f1b to align all ranks. + # Formula: + # pre-padding + warmup_ops + post_warmup_ops = earliest time step of first backward + # post_warmup_ops = [earliest time step of first backward] - (warmup_ops + pre-padding) + # earliest time step of first backward = [local_stages * group_size + 2 * (group_size - 1 - rank)] + # warmup_ops = calculated above + post_warmup_ops = ( + n_local_stages * pp_group_size + 2 * (pp_group_size - 1 - rank) + ) - (warmup_ops + rank) + + if enable_zero_bubble: + post_warmup_ops = pp_group_size - rank - 1 + + total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops + + backward_op_ids = [] + weight_op_count = 0 + + FULL_BACKWARD_OR_BACKWARD_INPUT = ( + BACKWARD_INPUT if enable_zero_bubble else FULL_BACKWARD + ) + + for op in range(total_ops): + # Warmup phase + if op < warmup_ops: + fwd_stage_index = forward_stage_index(op) + # This will assign the current microbatch index and update it as well + fwd_stage_mb_index[fwd_stage_index] = ( + mb_index := fwd_stage_mb_index[fwd_stage_index] + ) + 1 + rank_ops.append( + _Action(fwd_stage_index, _ComputationType.FORWARD, mb_index) + ) + if op == warmup_ops - 1: + # This is the last step in the warmup phase, so we need to wait for the backward to trickle back up + rank_ops.extend([None] * post_warmup_ops) + # 1F1B Phase (forward and backward) + elif warmup_ops <= op < warmup_ops + fwd_bwd_ops: + fwd_stage_index = forward_stage_index(op) + fwd_stage_mb_index[fwd_stage_index] = ( + fwd_mb_index := fwd_stage_mb_index[fwd_stage_index] + ) + 1 + rank_ops.append( + _Action(fwd_stage_index, _ComputationType.FORWARD, fwd_mb_index) + ) + bwd_stage_index = backward_stage_index(op) + bwd_stage_mb_index[bwd_stage_index] = ( + bwd_mb_index := bwd_stage_mb_index[bwd_stage_index] + ) + 1 + rank_ops.append( + _Action(bwd_stage_index, FULL_BACKWARD_OR_BACKWARD_INPUT, bwd_mb_index) + ) + backward_op_ids.append(op) + + if enable_zero_bubble and op - warmup_ops >= num_1f1b_microbatches: + weight_stage_index = backward_stage_index( + backward_op_ids[weight_op_count] + ) + weight_stage_mb_index[weight_stage_index] = ( + weight_mb_index := weight_stage_mb_index[weight_stage_index] + ) + 1 + rank_ops.append( + _Action( + weight_stage_index, + _ComputationType.BACKWARD_WEIGHT, + weight_mb_index, + ) + ) + weight_op_count += 1 + # Cooldown phase + else: + # During cooldown phase, we need steps to align with 1f1b happening in other ranks + # TODO: we don't need to always append, after all 1f1b are finished we can stop appending None + if not enable_zero_bubble: + rank_ops.append(None) + + bwd_stage_index = backward_stage_index(op) + bwd_stage_mb_index[bwd_stage_index] = ( + bwd_mb_index := bwd_stage_mb_index[bwd_stage_index] + ) + 1 + rank_ops.append( + _Action(bwd_stage_index, FULL_BACKWARD_OR_BACKWARD_INPUT, bwd_mb_index) + ) + backward_op_ids.append(op) + + if enable_zero_bubble and op - warmup_ops >= num_1f1b_microbatches: + weight_stage_index = backward_stage_index( + backward_op_ids[weight_op_count] + ) + weight_stage_mb_index[weight_stage_index] = ( + weight_mb_index := weight_stage_mb_index[weight_stage_index] + ) + 1 + rank_ops.append( + _Action( + weight_stage_index, + _ComputationType.BACKWARD_WEIGHT, + weight_mb_index, + ) + ) + weight_op_count += 1 + + while enable_zero_bubble and weight_op_count < len(backward_op_ids): + weight_stage_index = backward_stage_index(backward_op_ids[weight_op_count]) + weight_stage_mb_index[weight_stage_index] = ( + weight_mb_index := weight_stage_mb_index[weight_stage_index] + ) + 1 + rank_ops.append( + _Action( + weight_stage_index, _ComputationType.BACKWARD_WEIGHT, weight_mb_index + ) + ) + weight_op_count += 1 + + return rank_ops + + +class ScheduleInterleaved1F1B(PipelineScheduleMulti): + """ + The Interleaved 1F1B schedule. + See https://arxiv.org/pdf/2104.04473 for details. + Will perform one forward and one backward on the microbatches in steady + state and supports multiple stages per rank. When microbatches are ready for + multiple local stages, Interleaved 1F1B prioritizes the earlier microbatch + (also called "depth first"). + + This schedule is mostly similar to the original paper. + It differs by being relaxing the requirement of num_microbatch % pp_size == 0. + Using the flex_pp schedule, we will have num_rounds = max(1, n_microbatches // pp_group_size) and + it works as long as n_microbatches % num_rounds is 0. As a few examples, support + + 1. pp_group_size = 4, n_microbatches = 10. We will have num_rounds = 2 and n_microbatches % 2 is 0. + 2. pp_group_size = 4, n_microbatches = 3. We will have num_rounds = 1 and n_microbatches % 1 is 0. + """ + + def __init__( + self, + stages: List[_PipelineStageBase], + n_microbatches: int, + loss_fn: Optional[Callable] = None, + args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None, + kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None, + output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, + ): + self.pp_group_size = stages[0].group_size + super().__init__( + stages=stages, + n_microbatches=n_microbatches, + loss_fn=loss_fn, + args_chunk_spec=args_chunk_spec, + kwargs_chunk_spec=kwargs_chunk_spec, + output_merge_spec=output_merge_spec, + ) + self.n_local_stages = len(stages) + self.rank = stages[0].group_rank + self.number_of_rounds = max(1, n_microbatches // self.pp_group_size) + self.microbatches_per_round = n_microbatches // self.number_of_rounds + if n_microbatches % self.number_of_rounds != 0: + raise ValueError( + "Interleaved 1F1B requires the number of microbatches to be a " + f"multiple of the number of rounds ({self.number_of_rounds}), " + f"but got {n_microbatches}." + ) + # 1. Create the pipeline_order (all ranks do this calculation) + # This will be used to keep track of the current state of the entire pipeline + # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...] + self.pipeline_order: Dict[int, List[Optional[_Action]]] = {} + for rank in range(self.pp_group_size): + rank_ops = self._calculate_single_rank_operations(rank) + self.pipeline_order[rank] = rank_ops + + def _calculate_single_rank_operations(self, rank) -> List[Optional[_Action]]: + def get_rank_warmup_ops(rank): + # Warms up operations for last stage + warmups_ops_last_stage = ( + self.n_local_stages - 1 + ) * self.microbatches_per_round + # Increment warmup operations by 2 for each hop away from the last stage + multiply_factor = 2 + warmup_ops = warmups_ops_last_stage + multiply_factor * ( + (self.pp_group_size - 1) - rank + ) + + # We cannot have more warmup operations than there are number of microbatches, so cap it there + return min(warmup_ops, self._n_microbatches * self.n_local_stages) + + warmup_ops = get_rank_warmup_ops(rank) + microbatch_ops = self.n_local_stages * self._n_microbatches + # fwd_bwd_ops should encompass the remaining forwards + fwd_bwd_ops = microbatch_ops - warmup_ops + # cooldown_ops should encompass the remaining backwards + cooldown_ops = microbatch_ops - fwd_bwd_ops + # total ops encompass both forward and backward ops + total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops + # warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2 + logger.debug( + "rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %s", + rank, + warmup_ops, + fwd_bwd_ops, + cooldown_ops, + total_ops, + ) + + # Calculates the stage index based on step and pp_group_size + def forward_stage_index(step): + # Get the local index from 0 to n_local_stages-1 + local_index = (step // self.microbatches_per_round) % self.n_local_stages + return (local_index * self.pp_group_size) + rank + + def backward_stage_index(step): + local_index = ( + self.n_local_stages + - 1 + - ((step - warmup_ops) // self.microbatches_per_round) + % self.n_local_stages + ) + return (local_index * self.pp_group_size) + rank + + return _get_1f1b_rank_ops( + self.n_local_stages, + self.pp_group_size, + warmup_ops, + fwd_bwd_ops, + cooldown_ops, + rank, + forward_stage_index, + backward_stage_index, + ) + + +class ScheduleInterleavedZeroBubble(PipelineScheduleMulti): + """ + The Interleaved Zero Bubble schedule. + See https://arxiv.org/pdf/2401.10241 for details. + Will perform one forward and one backward on inputs for the microbatches in steady + state and supports multiple stages per rank. Uses the backward for weights to fill in + the pipeline bubble. + + In particular this is implementing the ZB1P schedule in the paper. + """ + + def __init__( + self, + stages: List[_PipelineStageBase], + n_microbatches: int, + loss_fn: Optional[Callable] = None, + args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None, + kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None, + output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, + ): + self.pp_group_size = stages[0].group_size + super().__init__( + stages=stages, + n_microbatches=n_microbatches, + loss_fn=loss_fn, + args_chunk_spec=args_chunk_spec, + kwargs_chunk_spec=kwargs_chunk_spec, + output_merge_spec=output_merge_spec, + ) + self.n_local_stages = len(stages) + self.rank = stages[0].group_rank + self.number_of_rounds = max(1, n_microbatches // self.pp_group_size) + self.microbatches_per_round = n_microbatches // self.number_of_rounds + if n_microbatches % self.number_of_rounds != 0: + raise ValueError( + "Zero bubble requires the number of microbatches to be a " + f"multiple of the number of rounds ({self.number_of_rounds}), " + f"but got {n_microbatches}." + ) + # 1. Create the pipeline_order (all ranks do this calculation) + # This will be used to keep track of the current state of the entire pipeline + # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...] + self.pipeline_order: Dict[int, List[Optional[_Action]]] = {} + for rank in range(self.pp_group_size): + rank_ops = self._calculate_single_rank_operations(rank) + self.pipeline_order[rank] = rank_ops + + # This function add bubbles to the generated schedule based on dependencies of actions + # Note that the ZB1P schedule will not require bubbles to be manually added and it is + # only useful when n_microbatches <= microbatches_per_round + self.pipeline_order = self._add_bubbles_to_actions( + self.n_local_stages * self.pp_group_size, + ) + + def _calculate_single_rank_operations(self, rank) -> List[Optional[_Action]]: + def get_rank_warmup_ops(rank): + # Warms up operations for last stage + warmups_ops_last_stage = ( + self.n_local_stages - 1 + ) * self.microbatches_per_round + # Increment warmup operations by 2 for each hop away from the last stage + multiply_factor = 1 + warmup_ops = warmups_ops_last_stage + multiply_factor * ( + (self.pp_group_size - 1) - rank + ) + + # We cannot have more warmup operations than there are number of microbatches, so cap it there + return min(warmup_ops, self._n_microbatches * self.n_local_stages) + + warmup_ops = get_rank_warmup_ops(rank) + microbatch_ops = self.n_local_stages * self._n_microbatches + # fwd_bwd_ops should encompass the remaining forwards + fwd_bwd_ops = microbatch_ops - warmup_ops + # cooldown_ops should encompass the remaining backwards + cooldown_ops = microbatch_ops - fwd_bwd_ops + # total ops encompass both forward and backward ops + total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops + # warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2 + logger.debug( + "rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %s", + rank, + warmup_ops, + fwd_bwd_ops, + cooldown_ops, + total_ops, + ) + + # Calculates the stage index based on step and pp_group_size + + def forward_stage_index(step): + # Get the local index from 0 to n_local_stages-1 + local_index = (step // self.microbatches_per_round) % self.n_local_stages + return (local_index * self.pp_group_size) + rank + + def backward_stage_index(step): + local_index = ( + self.n_local_stages + - 1 + - ((step - warmup_ops) // self.microbatches_per_round) + % self.n_local_stages + ) + return (local_index * self.pp_group_size) + rank + + num_1f1b_microbatches = rank + + return _get_1f1b_rank_ops( + self.n_local_stages, + self.pp_group_size, + warmup_ops, + fwd_bwd_ops, + cooldown_ops, + rank, + forward_stage_index, + backward_stage_index, + num_1f1b_microbatches, + enable_zero_bubble=True, + ) + + def _add_bubbles_to_actions(self, num_stages_global): + actions = self.pipeline_order + + def need_bubble(stage, op, microbatch, num_stages_global, seen_ops): + if op == _ComputationType.FORWARD: + if stage != 0 and (stage - 1, op, microbatch) not in seen_ops: + return True + elif op == _ComputationType.FULL_BACKWARD: + if stage == num_stages_global - 1: + return (stage, _ComputationType.FORWARD, microbatch) not in seen_ops + return (stage + 1, op, microbatch) not in seen_ops + return False + + seen_ops: Set[Tuple[int, _ComputationType, int]] = set() + result: Dict[int, List[Optional[_Action]]] = {} + next_pointer: Dict[int, int] = {} + bubbles_added: Dict[int, int] = {} + total_bubbles_added = 0 + + for rank in range(self.pp_group_size): + result[rank] = [] + next_pointer[rank] = 0 + bubbles_added[rank] = 0 + + while True: + should_stop = True + + temp_seen_ops: Set[Tuple[int, _ComputationType, int]] = set() + + for rank in range(self.pp_group_size): + timestamp = next_pointer[rank] + if timestamp >= len(actions[rank]): + continue + + should_stop = False + + if actions[rank][timestamp] is not None: + temp_action = actions[rank][timestamp] + assert temp_action is not None + stage_index, op, microbatch = temp_action + if not need_bubble( + stage_index, op, microbatch, num_stages_global, seen_ops + ): + result[rank].append(actions[rank][timestamp]) + if microbatch is not None: + temp_seen_ops.add((stage_index, op, microbatch)) + next_pointer[rank] += 1 + else: + result[rank].append(None) + bubbles_added[rank] += 1 + else: + next_pointer[rank] += 1 + result[rank].append(None) + + seen_ops.update(temp_seen_ops) + if should_stop: + break + + if total_bubbles_added > 0: + logger.warning( + "Non zero bubbles added: total_bubbles_added=%s bubbles_added=%s", + total_bubbles_added, + bubbles_added, + ) + return result + + +def get_schedule_class(schedule_name: str): + """ + Maps a schedule name (case insensitive) to its corresponding class object. + + Args: + schedule_name (str): The name of the schedule. + """ + schedule_map = { + "1F1B": Schedule1F1B, + "Interleaved1F1B": ScheduleInterleaved1F1B, + "GPipe": ScheduleGPipe, + "LoopedBFS": ScheduleLoopedBFS, + "InterleavedZeroBubble": ScheduleInterleavedZeroBubble, + "PipelineScheduleSingle": PipelineScheduleSingle, + "PipelineScheduleMulti": PipelineScheduleMulti, + } + lowercase_keys = {k.lower(): k for k in schedule_map.keys()} + lowercase_schedule_name = schedule_name.lower() + if lowercase_schedule_name not in lowercase_keys: + raise ValueError( + f"Unknown schedule name '{schedule_name}'. The valid options are {list(schedule_map.keys())}" + ) + return schedule_map[lowercase_keys[lowercase_schedule_name]] + + +def _simulate_comms_compute( + pipeline_order, stage_to_rank: Callable[[int], int], num_stages: int +): + """This function dry-run simulates the actions in the schedule from the perspective of all ranks, and flags + any deadlocks caused by missing or misordered communications. It also simulates any bubbles in time where a rank + can not execute any action due to waiting for unmet dependencies. The total number of simulator steps can be used + as a metric for unit tests involving IR optimization passes as reordering and merging of IR can reduce the number + of simulated steps. + + The simulation is not high-fidelity and does not model overlapping of compute and communication, or cuda streams. + Future work may be to enhance this and model the compute time, comms overlap, and even memory. + """ + pipeline_order = { + rank: [a for a in pipeline_order[rank] if a is not None] + for rank in sorted(pipeline_order) + } + _schedule: Dict[int, List[_Action | None]] = { + rank: [] for rank in sorted(pipeline_order) + } + + _prev_ops_rank: Dict[int, Set[_Action]] = {rank: set() for rank in _schedule} + + def add_to_schedule(rank: int, action: Optional[_Action]): + _schedule[rank].append(action) + if action is not None: + _prev_ops_rank[rank].add(action) + + def _ready_to_schedule(action: Optional[_Action]) -> bool: + if action is None: + return True + + stage_idx = action.stage_index + prev_ops = _prev_ops_rank[stage_to_rank(stage_idx)] + if action.computation_type == F: + if action.stage_index == 0: + return True + elif ( + _Action(action.stage_index, RECV_F, action.microbatch_index) in prev_ops + ): + return True + elif ( + _Action(action.stage_index - 1, F, action.microbatch_index) in prev_ops + ): + return True + return False + elif action.computation_type in (BACKWARD_INPUT, FULL_BACKWARD): + if action.stage_index == num_stages - 1: + return True + if _Action(action.stage_index, RECV_B, action.microbatch_index) in prev_ops: + return True + if ( + _Action(action.stage_index + 1, BACKWARD_INPUT, action.microbatch_index) + in prev_ops + ): + return True + if ( + _Action(action.stage_index + 1, FULL_BACKWARD, action.microbatch_index) + in prev_ops + ): + return True + return False + elif action.computation_type == BACKWARD_WEIGHT: + return True + elif action.computation_type == SEND_F: + expected_f = _Action(action.stage_index, F, action.microbatch_index) + return expected_f in prev_ops + elif action.computation_type == RECV_F: + peer_stage_idx = stage_idx - 1 + expected_send = _Action(peer_stage_idx, SEND_F, action.microbatch_index) + return expected_send in _prev_ops_rank[stage_to_rank(peer_stage_idx)] + elif action.computation_type == SEND_B: + expected_b = _Action( + action.stage_index, BACKWARD_INPUT, action.microbatch_index + ) + expected_bw = _Action( + action.stage_index, FULL_BACKWARD, action.microbatch_index + ) + return expected_b in prev_ops or expected_bw in prev_ops + elif action.computation_type == RECV_B: + peer_stage_idx = stage_idx + 1 + expected_send = _Action(peer_stage_idx, SEND_B, action.microbatch_index) + return expected_send in _prev_ops_rank[stage_to_rank(peer_stage_idx)] + else: + raise ValueError(f"Unsupported action type {action}") + + while pipeline_order: + progress = False + for rank in sorted(pipeline_order): + if len(pipeline_order[rank]) == 0: + continue + + action = pipeline_order[rank][0] + if _ready_to_schedule(action): + if action is not None: + add_to_schedule(rank, action) + pipeline_order[rank].pop(0) + progress = True + else: + add_to_schedule(rank, None) + + for i in sorted(pipeline_order, reverse=True): + if len(pipeline_order[i]) == 0: + del pipeline_order[i] + + # hacky, but do a second pass to replace any 'none' at this timestep with a real action, if it got unblocked + # by one of the later ranks + for rank in sorted(pipeline_order): + if len(pipeline_order[rank]) == 0: + continue + + if _schedule[rank][-1] is not None: + continue + + action = pipeline_order[rank][0] + if _ready_to_schedule(action): + if action is not None: + _schedule[rank][-1] = action + _prev_ops_rank[rank].add(action) + pipeline_order[rank].pop(0) + + for i in sorted(pipeline_order, reverse=True): + if len(pipeline_order[i]) == 0: + del pipeline_order[i] + + if not progress: + print("WIP comms schedule:\n", _format_pipeline_order(_schedule)) + for rank in pipeline_order: + print(f"{rank=} next action= {pipeline_order[rank][0]}") + raise ValueError("Schedule is not progressing") + + return _schedule + + +def _dump_chrometrace(schedule, filename): + """ + This function dumps a schedule IR into a chrometrace format so it can be visualized. + + It is currently very basic and only serves as a graphical alternative to dumping the schedule IR as text. + + As future work we may extend this to include more accurate heuristics for durations, or let users input durations, + add 'flow events' to let the UI show the connection between sends and recvs, and model cuda streams for comm/compute + as separate streams on the chrometrace view. + """ + events = [] + for rank in sorted(schedule): + for timestep, action in enumerate(schedule[rank]): + if action is None: + continue + events.append( + { + "name": str(action), + "cat": ( + "computation" + if action.computation_type in (F, B, W) + else "communication" + ), + "ph": "X", + "pid": rank, + "tid": rank, + "ts": timestep, + "dur": 1, + } + ) + import json + + with open(filename, "w") as f: + json.dump({"traceEvents": events}, f) diff --git a/mindnlp/core/distributed/pipelining/stage.py b/mindnlp/core/distributed/pipelining/stage.py new file mode 100644 index 000000000..3269bfb88 --- /dev/null +++ b/mindnlp/core/distributed/pipelining/stage.py @@ -0,0 +1,1506 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import logging +import operator +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +from mindnlp import core +from mindnlp import core.distributed as dist +from mindnlp import core.fx as fx +from mindnlp import core.nn as nn +from core._subclasses.fake_tensor import FakeTensor +from core.distributed._composable.fsdp.fully_shard import FSDPModule, fully_shard +from core.fx.node import map_aggregate +from core.nn.parallel import DistributedDataParallel +from core.utils._pytree import tree_map_only + +from ._backward import stage_backward, stage_backward_input, stage_backward_weight +from ._debug import map_debug_info +from ._utils import flatten_args, PipeInfo, validate_tensors_metadata + + +__all__ = [ + "PipelineStage", + "build_stage", +] + +logger = logging.getLogger(__name__) + + +def _normalize_model_output_as_tuple(output: Any) -> Tuple[Any]: + """[Note: pipeline model output type] + + The output of the model passed to pipelining can be any type, controlled by the user. + + However, there are 2 API surfaces that complicate this. + (1) the outputs of intermediate stages are passed via Send/Recv ops to subsequent stages. The implicit assumption + is that each element of the outputs is a tensor. Otherwise, Send/Recv would not be supported. The exception + is the last layer of the model, which can output anything any which won't be communicated via Send/Recv. + (2) the outputs of the last layer of the model are returned to the user, or, passed to the loss function. + The loss function can be written in any way, such that its inputs match the outputs of the model. + + It would be convenient if we could strictly type the output signature of the pipeline stage wrapping the model, + but we do not want to impose an unnecessary constraint on user provided models. + + Currently, we let user provided models return either a Tensor or a tuple of Tensors from each stage. Due to + core.export tracing, compiled models may also return a list instead of a Tuple, which we will normalize back to a + tuple for consistency. + + TODO: should we be stricter about asserting that stage modules (intermediate and output) all return only Tensor + values? + """ + if type(output) is list: + # HACK: this is a hacky workaround for the fact that export creates + # output in list format + output = tuple(output) + + # Unify output form to tuple for easy correspondance with + # `act_send_info` + output_tuple = output if type(output) is tuple else (output,) + return output_tuple + + +class _RootArgPlaceholder: + """ + Placeholder for model-level inputs. + """ + + def __init__(self, tensor): + self.meta = tensor.to("meta") + + +class _RecvInfo: + """ + Represents a stage input. + """ + + def __init__( + self, + input_name: str, + source: int, + buffer: core.Tensor, + ): + # Name of this input + self.input_name = input_name + # Stage index of the source of this input + self.source = source + # Buffer to receive the input into. + self.buffer = buffer + + def __repr__(self): + return f"_RecvInfo(input={self.input_name}, source={self.source}, shape={self.buffer.size()})" + + +# An input can be either a received activation or a model input +InputInfo = Union[_RecvInfo, _RootArgPlaceholder] + + +def _make_tensor_from_meta( + example: Union[core.Tensor, FakeTensor], + device: core.device, +) -> core.Tensor: + """ + Create a real tensor from a tensor. + """ + return core.empty( + example.size(), + dtype=example.dtype, + layout=example.layout, + device=device, + ) + + +class _PipelineStageBase(ABC): + """ + Base class for pipeline stages. + Defines or implements common methods used by the `_PipelineStage` used by + the tracing frontend and `PipelineStage` used by manual frontend. + """ + + def __init__( + self, + submodule: core.nn.Module, + stage_index: int, + num_stages: int, + device: core.device, + group: Optional[dist.ProcessGroup] = None, + dw_builder: Optional[Callable[[], Callable[..., None]]] = None, + ): + """ + Args: + submodule (core.nn.Module): The module to be executed in this stage. + stage_index (int): The index of this stage. + num_stages (int): The total number of stages in this pipeline. + device (core.device): The device to run this stage on. + group (Optional[dist.ProcessGroup]): The process group to use for communication. + If `None`, the default process group will be used. + Default: `None`. + dw_builder (Optional[Callable[[], Callable[..., None]]): If provided, dw_runner is a builder function + that will build a new dw_runner function that will run parts of module backward that were intentionally + skipped during the module's actual backward pass. The builder must be invoked by stage after stage runs + model backwards, and stage should save the latest dw_runner to run during weight pass. + If not provided, a dw_runner will be generated automatically by traversing the autograd graph. + When used with schedules that only have F and B steps, the fresh dw_runner function will be called as + part of B. + When used with F,B,W schedules, the dw_runner function implements 'W'. + """ + super().__init__() + if stage_index >= num_stages: + raise ValueError( + f"Stage index {stage_index} is out of range of {num_stages}" + ) + + self.submod = submodule + self.stage_index = stage_index + self.num_stages = num_stages + self.device = device + self.group = group + + self.dw_builder = dw_builder + + # backward state + self.backward_state: Dict[int, Tuple[Any, ...]] = {} + + # store dw_runner per microbatch_id + self.dw_runner: Dict[int, Callable[..., None]] = {} + + # `group_rank` is rank in process group `group`. + self.group_rank = dist.get_rank(self.group) + self.group_size = dist.get_world_size(self.group) + if self.group_size > self.num_stages: + raise RuntimeError( + f"Pipeline group size {self.group_size} cannot be larger than number of stages {self.num_stages}" + ) + + # Run time states + self._outputs_meta: Optional[Tuple[core.Tensor, ...]] = None + # map microbatch ID to list of forward tensor args + self.fwd_cache: Dict[int, Tuple[Any, List[core.Tensor]]] = {} + # map microbatch ID to list of backward grad tensor args + self.bwd_cache: Dict[int, Tuple[Optional[core.Tensor], ...]] = {} + # Caching chunk outputs for final output merge or reduction + self.output_chunks: List[Any] = [] + + # Initialize has_backward to false; this will be set to true if loss + # function is passed to pipeline schedule + self.has_backward = False + # Log prefix + self.log_prefix = f"[Stage {self.stage_index}]" + + # Forward infra + self.args_recv_info: Dict[int, Tuple[InputInfo, ...]] = {} + self.act_send_info: Dict[int, List] = {} + + # Backward infra will created lazily + self.grad_recv_info: Dict = {} + self.grad_send_info: Optional[List] = None + + # To be populated later by the Schedule + self.chunks: Optional[int] = None + self.stage_index_to_group_rank: Dict[int, int] = { + i: i % self.group_size for i in range(self.num_stages) + } + + @property + def has_backward(self) -> bool: + """ + Returns true if this stage has a backward pass. + """ + return self._has_backward + + @has_backward.setter + def has_backward(self, has_backward: bool): + self._has_backward = has_backward + + @property + def is_first(self): + """ + Returns true if this stage is the first stage in the pipeline. + """ + return self.stage_index == 0 + + @property + def is_last(self): + """ + Returns true if this stage is the last stage in the pipeline. + """ + return self.stage_index == self.num_stages - 1 + + def _check_chunk_id(self, chunk_id: int): + if self.chunks is None: + raise RuntimeError( + "Attempted to access chunk_id before chunks have been configured." + ) + if chunk_id >= self.chunks: + raise RuntimeError( + f"Chunk id {chunk_id} is out of range [0, {self.chunks})" + ) + + def _configure_outputs_meta(self, outputs_meta: Tuple[core.Tensor, ...]): + """ + Track the output shapes/dtype of this stage since they determine the send operation(s) which must match + recv operations of the next stage. The next stage _will_ be freezing its recv buffers based on its initial + configuration, so it's important to also freeze/validate the output side to avoid any send/recv mismatches + which could show up as hangs, silent corruption, or other errors. + """ + assert ( + self._outputs_meta is None + ), "Attempting to reconfigure output_meta, which is not supported" + self._outputs_meta = tuple(outputs_meta) # type: ignore[assignment] + + def get_outputs_meta(self) -> Tuple[core.Tensor, ...]: + """Get the output metadata (meta tensors) reprensenting the outputs of this stage""" + assert ( + self._outputs_meta is not None + ), "Attempted to get_outputs_meta() without configuring output meta" + return self._outputs_meta + + def _create_grad_send_info( + self, + args_recv_info: Tuple, + ) -> List[Optional[int]]: + """ + Create a list of stage indices to send gradients to. + """ + grad_send_info: List[Optional[int]] = [] + + def map_recv_to_send(a): + # Note: we send gradients back to previous stage as long as in + # forward it is a received input, regardless of whether it requires + # grad. It is up to the previous stage to disgard this gradient. + if isinstance(a, _RecvInfo): + grad_send_info.append(a.source) + return a.source + else: + grad_send_info.append(None) + return None + + map_aggregate(args_recv_info, map_recv_to_send) + + logger.debug("%s Grad send info: %s", self.log_prefix, grad_send_info) + return grad_send_info + + @abstractmethod + def _prepare_forward_infra( + self, + num_microbatches: int, + args: Tuple[Any, ...], + kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[Any, ...]: + raise NotImplementedError + + def _prepare_backward_infra(self, num_microbatches: int): + # TODO: this is needed for backward_maybe_with_nosync + self.chunks = num_microbatches + + for mb_index in range(num_microbatches): + # `grad_recv_info` is a mirror of `act_send_info` + self.grad_recv_info[mb_index] = self._create_grad_recv_info( + self.act_send_info + ) + + @abstractmethod + def _create_grad_recv_info( + self, + act_send_info: Dict, + ) -> Tuple[_RecvInfo, ...]: + raise NotImplementedError + + def _get_recv_ops( + self, + recv_infos: Tuple[InputInfo, ...], + ) -> List[dist.P2POp]: + """ + Helper function shared by `get_fwd_recv_ops` and `get_bwd_recv_ops`. + Returns a list of ops that correspond to the recv infos. + """ + ops: List[dist.P2POp] = [] + for info in recv_infos: + if not isinstance(info, _RecvInfo): + continue + + peer_rank = self.stage_index_to_group_rank[info.source] + peer_global_rank = ( + peer_rank + if self.group is None + else dist.get_global_rank(self.group, peer_rank) + ) # TODO + ops.append( + dist.P2POp(dist.irecv, info.buffer, peer_global_rank, self.group) + ) + + return ops + + """[Note: V-schedule special case] + + V-Schedules have a special case where 2 stages with adjacent stage_id are on the same rank. + + ex: 2 ranks, 4 stages forms a simple V: + rank0: stage 0 stage 3 + rank1: stage 1 stage 2 + + stage 0,1 and 2,3 communicate activations using send/recv as usual, but stage 1,2 do not need to + use communication ops. Instead, they should pass tensor data directly via function call. + + set_local_fwd_input and (get_local_bwd_output + set_local_bwd_input) facilitate this optimization, and + should be called at the appropriate time during the pipeline schedule (after forward or backward execution). + """ + + def set_local_fwd_input(self, prev_stage_outputs: Any, mb_index: int) -> None: + """ + Moves 'prev_stage_outputs' from another stage on the same rank into place as inputs for this stage. Avoids + copying tensor data or using send/recv op. Detaches original tensor and sets requires_grad so the + tensor can serve as a leaf for autograd and gradients can be collected from it during backward. + """ + recv_infos: Tuple[InputInfo, ...] = self.args_recv_info[mb_index] + + # See [Note: pipeline model output type] + prev_stage_outputs = _normalize_model_output_as_tuple(prev_stage_outputs) + + for info, tensor in zip(recv_infos, prev_stage_outputs): + assert isinstance( + tensor, core.Tensor + ), f"expected tensor values as outputs from prev stage, got {type(tensor)}" + assert isinstance( + info, _RecvInfo + ), "set_local_Fwd_input should only be called on non-first stage, which should always have RecvInfo" + + # We don't need to do a data copy here, since we can directly pass the activation tensor reference from + # one stage to the next. However, we do need to mark the activation as a leaf tensor since it will serve + # as the input tensor for a fresh autograd graph, not part of the previous stage's autograd graph. + # TODO: confirm, do we use this activation as the root of the backward call for the previous stage? does + # detach have any affect on that? + info.buffer = tensor.detach().requires_grad_(True) + + def get_local_bwd_output(self, mb_index): + """ + Returns the input grad tensors for this stage, which correspond to the stage inputs during forward. + """ + assert ( + self.has_backward + ), "can't steal_bwd_input if this stage doesn't have backward" + assert not self.is_first, "can't get bwd output if this stage is first" + + self._check_chunk_id(mb_index) + return self.bwd_cache.pop(mb_index) + + def set_local_bwd_input( + self, next_stage_bwd_outputs: Tuple[Optional[core.Tensor], ...], mb_index: int + ) -> None: + """ + Moves 'grad input' tensors from the next stage to 'grad_output' on this stage, avoiding a copy or send/recv. + Does not detach or set '_requires_grad'. + """ + assert isinstance( + next_stage_bwd_outputs, tuple + ), f"Expected tuple, got {type(next_stage_bwd_outputs)}" + + assert ( + self.has_backward + ), "can't set bwd input if this stage doesn't have backward" + assert not self.is_last, "can't set bwd input if this stage is last" + recv_infos = self.grad_recv_info[mb_index] + for info, tensor in zip(recv_infos, next_stage_bwd_outputs): + assert isinstance( + tensor, core.Tensor + ), f"expected tensor values as outputs from prev stage, got {type(tensor)}" + assert isinstance( + info, _RecvInfo + ), f"Expected a recv info, got {type(info)}" + info.buffer = tensor + + def get_fwd_recv_ops(self, fwd_chunk_id: int) -> List[dist.P2POp]: + """ + Returns a list of ops that are needed to receive the input arguments + for this stage. + """ + recv_infos: Tuple[InputInfo, ...] = self.args_recv_info[fwd_chunk_id] + + return self._get_recv_ops(recv_infos) + + def get_bwd_recv_ops(self, bwd_chunk_id: int) -> List[dist.P2POp]: + """ + Returns a list of ops that are needed to receive the gradients + for this stage. + """ + if not self.has_backward or self.is_last: + return [] + + recv_infos = self.grad_recv_info[bwd_chunk_id] + return self._get_recv_ops(recv_infos) + + def get_fwd_send_ops(self, fwd_chunk_id: int) -> List[dist.P2POp]: + """ + Get the activation send ops for current stage's forward. + """ + output = self.output_chunks[fwd_chunk_id] + # Unify output form to tuple for easy correspondance with + # `act_send_info` + output_tuple = output if type(output) is tuple else (output,) + + ops: List[dist.P2POp] = [] + + for idx, out in enumerate(output_tuple): + dst_stages = self.act_send_info[idx] + for dst in dst_stages: + if dst is None: + continue + logger.debug( + "%s Sending tensor to Stage %s: %s", + self.log_prefix, + dst, + out.size(), + ) + peer_rank = self.stage_index_to_group_rank[dst] + peer_global_rank = ( + peer_rank + if self.group is None + else dist.get_global_rank(self.group, peer_rank) + ) # TODO + ops.append(dist.P2POp(dist.isend, out, peer_global_rank, self.group)) + + return ops + + def get_bwd_send_ops(self, bwd_chunk_id: int) -> List[dist.P2POp]: + """ + Get the gradient send ops for current stage's backward. + """ + self._check_chunk_id(bwd_chunk_id) + + if not self.has_backward or self.is_first: + return [] + + # Create bwd send infra lazily + if self.grad_send_info is None: + # Send info for input grads during backward: + # List of destinations corresponding to input grads + # Can be None if an input has no grad + # `grad_send_info` is a mirror of `args_recv_info` + self.grad_send_info = self._create_grad_send_info(self.args_recv_info[0]) + + ops: List[dist.P2POp] = [] + grads_input = self.bwd_cache.pop(bwd_chunk_id) + for grad, grad_recv_stage in zip(grads_input, self.grad_send_info): + if isinstance(grad, core.Tensor) and grad_recv_stage is not None: + logger.debug( + "%s Sending gradient to Stage %s: %s", + self.log_prefix, + grad_recv_stage, + grad.size(), + ) + peer_rank = self.stage_index_to_group_rank[grad_recv_stage] + peer_global_rank = ( + peer_rank + if self.group is None + else dist.get_global_rank(self.group, peer_rank) + ) # TODO + ops.append(dist.P2POp(dist.isend, grad, peer_global_rank, self.group)) + else: + if not (grad is None and grad_recv_stage is None): + raise RuntimeError( + f"[{self.stage_index}] for chunk {bwd_chunk_id} has gradients {grad} " + f"and is expecting to send gradients to stage {grad_recv_stage}" + ) + return ops + + def clear_runtime_states(self) -> None: + """ + Clear runtime states of the stage. + """ + # map microbatch ID to list of forward tensor args + self.fwd_cache.clear() + # Caching chunk outputs for final output merge or reduction + self.output_chunks.clear() + + # Clear grad of input buffers in between schedule steps. This is because + # `core.autograd.backward()` will accumulate gradients into leaf + # tensors by default. For gradients to pass back to previous stages, we + # don't want such accumulation. + for recv_tuple in self.args_recv_info.values(): # iterate over all chunks + for a in recv_tuple: # iterate over all input args + if isinstance(a, _RecvInfo): + # Set to None is the newer and recommended way to clear grads, compared to `zero_()`. + # See https://github.com/pytorch/pytorch/pull/92731 + a.buffer.grad = None + + def _map_tensor_from_recv_info( + self, + recv_infos: Tuple[InputInfo, ...], + ): + """ + Map tensors from recv infos to a list. + """ + + def get_recv_tensor(info): + if isinstance(info, _RecvInfo): + return info.buffer + else: + raise AssertionError(f"Expected _RecvInfo but got {type(info)}") + + tensors = map_aggregate( + recv_infos, + get_recv_tensor, + ) + + return tensors + + def _retrieve_recv_activations(self, fwd_chunk_id: int): + """ + Retrieve the activations received for the current stage during forward. + """ + recv_infos = self.args_recv_info[fwd_chunk_id] + activations = self._map_tensor_from_recv_info(recv_infos) + return activations + + def _retrieve_recv_grads( + self, + bwd_chunk_id: int, + ): + """ + Retrieve the gradients received for the current stage during backward. + """ + recv_infos = self.grad_recv_info[bwd_chunk_id] + grads = self._map_tensor_from_recv_info(recv_infos) + return grads + + def forward_maybe_with_nosync(self, *args, **kwargs): + # If submod is wrapped with DDP, we use the `no_sync` context manager to + # avoid gradient all-reduce per microbatch + if isinstance(self.submod, DistributedDataParallel): + with self.submod.no_sync(): # type: ignore[operator] + out_val = self.submod(*args, **kwargs) + else: + out_val = self.submod(*args, **kwargs) + return out_val + + def backward_maybe_with_nosync( + self, backward_type, bwd_kwargs: Dict, last_backward=False + ) -> Tuple[Tuple[Optional[core.Tensor], ...], Optional[List[Dict[str, Any]]]]: + """ + Whether using PP with FSDP or DDP, there are some runtime differences between the last backward step and the + other steps. Namely, we need to accumulate gradients on previous steps and reduce them on the last step, but + there are additional state-variables and performance considerations depending on the data parallelism used. + This helper should adapt any pipeline parallel schedule to work with common/supported data parallel libraries. + """ + + def perform_backward( + backward_type, + ) -> Callable[ + [], + Tuple[Tuple[Optional[core.Tensor], ...], Optional[List[Dict[str, Any]]]], + ]: + if backward_type == "full": + return lambda: ( + stage_backward( + bwd_kwargs["stage_output"], + bwd_kwargs["output_grads"], + bwd_kwargs["input_values"], + ), + None, + ) + elif backward_type == "input": + return lambda: stage_backward_input( + bwd_kwargs["stage_output"], + bwd_kwargs["output_grads"], + bwd_kwargs["input_values"], + self.submod.parameters(), + ) + elif backward_type == "weight": + return lambda: ( + stage_backward_weight( + self.submod.parameters(), bwd_kwargs["param_groups"] + ), + None, + ) + else: + raise RuntimeError(f"Unknown backward type: {backward_type}") + + # If submod is wrapped by DDP + if isinstance(self.submod, DistributedDataParallel): + if last_backward: + # Last chunk, prepare for gradient reduction + # HACK: reaching into DDP implementation details here. Is there a better way? + self.submod.reducer.prepare_for_backward( # type: ignore[union-attr, operator] + list( + core.nn.parallel.distributed._find_tensors( # type: ignore[attr-defined] + bwd_kwargs["stage_output"] + ) + ) + ) + result = perform_backward(backward_type)() + else: + with self.submod.no_sync(): # type: ignore[operator] + result = perform_backward(backward_type)() + # If submod is a FSDP module + elif isinstance(self.submod, FSDPModule): + self.submod.set_is_last_backward(False) + self.submod.set_reshard_after_backward(False) + self.submod.set_requires_gradient_sync(False) + result = perform_backward(backward_type)() + if last_backward: + # Manually call post backward for FSDP + def run_post_backward(fsdp_module: FSDPModule) -> None: + fsdp_module.set_is_last_backward(True) + fsdp_module.set_reshard_after_backward(True) + fsdp_module.set_requires_gradient_sync(True) + fsdp_state = fully_shard.state(fsdp_module) + for state in fsdp_state._state_ctx.all_states: + if state._fsdp_param_group: + state._fsdp_param_group.post_backward() + + run_post_backward(self.submod) + else: + # Non-DP submodule, regular backward + result = perform_backward(backward_type)() + + grads, param_groups = result + return grads, param_groups + + def forward_one_chunk( + self, + fwd_chunk_id: int, + args: Tuple[Any, ...], + kwargs: Optional[Dict[str, Any]] = None, + ): + """ + Perform forward pass on the stage with one microbatch. + `args` and `kwargs` are the inputs from *external* to this stage. + As of Sept 2024: + - `args` applies to the first stage only, other stages receives args + through activation transmission. + - `kwargs` can be passed to all stages via respective `step` calls. + """ + + if self.is_first: + # First stage doesn't need to receive anything + composite_args = args + else: + # Receive activations for this chunk + # Activations only come in args form + composite_args = self._retrieve_recv_activations(fwd_chunk_id) + + composite_kwargs = kwargs or {} + + self._validate_fwd_input(args, kwargs) + + # Compute forward + try: + output = self.forward_maybe_with_nosync(*composite_args, **composite_kwargs) + + except Exception as e: + exc_msg = f""" + {self.log_prefix} failed to run forward: + args: {map_debug_info(composite_args)} + kwargs: {map_debug_info(composite_kwargs)} + """ + raise RuntimeError(exc_msg) from e + + # See [Note: pipeline model output type] + output_tuple = _normalize_model_output_as_tuple(output) + + # Prepare for final output merge or reduction + self.output_chunks.append(output) + + # Save activations and inputs for backward + flat_args = flatten_args(composite_args) + flat_kwargs = flatten_args(composite_kwargs) + flatten_input_tensors = flat_args + flat_kwargs + self.fwd_cache[fwd_chunk_id] = ( + output_tuple, # stage_output + flatten_input_tensors, # input_values + ) + + logger.debug( + "%s Forwarded chunk %s, outputs: %s", + self.log_prefix, + fwd_chunk_id, + map_debug_info(output), + ) + self._validate_fwd_outputs(output_tuple) + + # We return the original user-provied output, not normalized to tuple. + # See [Note: pipeline model output type] + return output + + def backward_one_chunk( + self, + bwd_chunk_id: int, + loss=None, + full_backward: bool = True, + last_backward=False, + ): + """ + Perform backward pass on the module. + This should only be called once per microbatch. + + If full_backward is True (the default), the full backward pass including weight and input gradients will be run, + and it is an error to call `backward_weight_one_chunk` for this bwd_chunk_id. + + If full_backward is False, it is optional that `dw_runner` was provided to the PipelineStage at __init__ time, + and a subsequent call to `backward_weight_one_chunk` is required to invoke dw_runner and complete the backward. + + last_backward is controlled by the schedule and signals synchronization of gradients across DP groups + after the last backward. + """ + self._check_chunk_id(bwd_chunk_id) + + ( + stage_output, + input_values, + ) = self.fwd_cache.pop(bwd_chunk_id) + + # Compute backward + if self.is_last: + # Last stage computes gradients from loss and has no gradients from + # next stage + bwd_kwargs = { + "stage_output": loss, + "output_grads": None, + "input_values": input_values, + } + else: + # Otherwise, receive gradients from next stage + grads_output = self._retrieve_recv_grads(bwd_chunk_id) + # If an input to the pipeline requires gradient, + # `core.autograd.backward` will accumulate the gradient into the + # `.grad` field of such input + bwd_kwargs = { + "stage_output": stage_output, + "output_grads": grads_output, + "input_values": input_values, + } + + grads_input: Tuple[Optional[core.Tensor], ...] = () + + # Custom backward function + if self.dw_builder: + # TODO: We may want to change our semantics so we are allowed to ignore + # the 'dw_builder' and call full_backward directly when it is a full_backward op. + grads_input, _ = self.backward_maybe_with_nosync( + "full", bwd_kwargs, last_backward=last_backward + ) + if full_backward: + self.dw_builder()() + else: + self.dw_runner[bwd_chunk_id] = self.dw_builder() + else: + if full_backward: + grads_input, _ = self.backward_maybe_with_nosync( + "full", bwd_kwargs, last_backward=last_backward + ) + else: + param_groups: List[Dict[str, Any]] | None = None + # Skip the backward for the first stage since we will perform the weight update with + # autograd.backward in backward_weight_one_chunk + if not self.is_first: + if isinstance(bwd_kwargs["stage_output"], core.Tensor): + bwd_kwargs["stage_output"] = (bwd_kwargs["stage_output"],) + + # perform the partial backwards for the inputs with a custom backward function + # when the "stage_ouput" is a loss, then it is a tensor, otherwise it is a tuple of tensors + grads_input, param_groups = self.backward_maybe_with_nosync( + "input", bwd_kwargs, last_backward=last_backward + ) + + # TODO: we dont need to save this, add to dw_runner? + self.backward_state[bwd_chunk_id] = ( + bwd_kwargs["input_values"], + param_groups, + bwd_kwargs["stage_output"], + bwd_kwargs["output_grads"], + ) + # Save a placeholder for the dw_runner + self.dw_runner[bwd_chunk_id] = lambda: None + + self.bwd_cache[bwd_chunk_id] = grads_input + + if self.is_last and not self.is_first: + # Autograd dependencies: + # rest_of_autograd_graph -> stage_output -> loss + # stage_output is no longer used in the last stage for backward and only needed + # to return to the user in merge_output_chunks, therefore + # this should be detached to release autograd graph context and free memory earlier + for t in stage_output: + t.detach_() + + logger.debug("%s Backwarded chunk %s", self.log_prefix, bwd_chunk_id) + + def backward_weight_one_chunk(self, bwd_chunk_id: int, last_backward=False): + assert bwd_chunk_id in self.dw_runner, ( + f"{self.log_prefix} Attempted to run backward_weight_one_chunk for chunk {bwd_chunk_id}" + " without first calling `backward_one_chunk(full_backward=False)`" + ) + + if self.dw_builder is not None: + self.dw_runner.pop(bwd_chunk_id)() + else: + ( + input_values, + param_groups, + stage_output, + output_grads, + ) = self.backward_state.pop(bwd_chunk_id) + + if self.stage_index != 0: + bwd_kwargs = { + "stage_output": stage_output, + "param_groups": param_groups, + } + self.backward_maybe_with_nosync( + "weight", bwd_kwargs, last_backward=last_backward + ) + else: + # TODO: figure out a better way to do this: + # if inputs does not require gradient, + # then the parameter group will not be fully captured during stage_backward_input + # in this case, we need call grad directly on the parameters + # To solve: make input fn do the intersect compute and then finish it off during W + bwd_kwargs = { + "stage_output": stage_output, + "output_grads": output_grads, + "input_values": input_values, + } + self.backward_maybe_with_nosync( + "full", bwd_kwargs, last_backward=last_backward + ) + + def _validate_fwd_input(self, args, kwargs): + """Raises a RuntimeError if shapes of input args/kwargs do not match the shapes configured for this stage.""" + + if self.is_first: + # TODO why is there a separate recv_info for each pipeline chunk? + # kwen2501: to avoid passing a `fwd_chunk_id` to this function, we + # check all chunks against args_recv_info[0] + expected_args = self.args_recv_info[0] + else: + # We don't check inputs for non-0 stages assuming they don't accept + # user inputs in canonical pipeline scenarios + return + + if len(kwargs): + # TODO- need a mapping of kwarg to position in self.args_recv_info + # Without it, we are not 100% sure how to match the args and + # expected_args. + return + + # TODO- need a mapping of kwarg to position in self.args_recv_info + # maybe it's impossible to tell whether the len mismatches because + # (a) the user passed an extra arg or missed an arg + # (b) the user did not pass a kwarg, which has a default value baked into expected_args + expected_tensors_meta = [ + e.meta if isinstance(e, _RootArgPlaceholder) else e.buffer + for e in expected_args + ] + validate_tensors_metadata( + f"Stage {self.stage_index} forward inputs", expected_tensors_meta, args + ) + + def _validate_fwd_outputs(self, outputs: Tuple[core.Tensor, ...]): + """Raises a RuntimeError if this stage produces an output of unexpected shape/dtype. + Most likely, this could be cause either by incorrect user specification of output shapes, or becuase + shape inference was done on the original model but then at runtime the model is wrapped with something like + mixed precision which changes output dtype. + """ + expected_tensors_meta = self.get_outputs_meta() + validate_tensors_metadata( + f"Stage {self.stage_index} forward outputs", expected_tensors_meta, outputs + ) + + +class _PipelineStage(_PipelineStageBase): + def __init__( + self, + stage_module: core.nn.Module, + stage_index: int, + pipe_info: PipeInfo, + device: core.device, + group: Optional[dist.ProcessGroup] = None, + ): + """ + Create a pipeline stage given a stage_module to be wrapped by this stage + and a `pipe_info` describing the stage relationship of the pipeline. + + Args: + stage_module (core.nn.Module): the module to be wrapped by this stage + stage_index (int): the index of this stage in the pipeline + pipe_info (PipeInfo): information about the pipeline, can be retrieved by `pipe.info()` + device (core.device): the device to be used by this stage + group (Optional[dist.ProcessGroup]): the process group to be used by this stage + """ + _PipelineStageBase.__init__( + self, + stage_module, + stage_index, + pipe_info.num_stages, + device, + group, + ) + self.pipe_info = pipe_info + + # Find stage nodes in graph + submod_nodes = [ + node for node in pipe_info.graph.nodes if node.op == "call_module" + ] + if len(submod_nodes) != self.num_stages: + raise AssertionError( + f"Number of submodules in pipe graph {len(submod_nodes)} does not match number of stages {self.num_stages}" + ) + + # Find my stage node in graph + self.node = submod_nodes[self.stage_index] + self.name = self.node.name + logger.info( + "[%s] Creating PipelineStage %s for %s", + self.group_rank, + stage_index, + self.name, + ) + + # Create mapping from stage name to stage index + self.submod_to_stage_index: Dict[str, int] = {} + for i, node in enumerate(submod_nodes): + self.submod_to_stage_index.setdefault(node.name, i) + + # Cast submodule to device + self._move_submod_to_device() + + def _move_submod_to_device(self): + # Move submodule to indicated device if possible + # Note: we cannot move meta module to real devices because meta tensors + # do not support to() method. One needs to do an in-place tensor swap in + # that case. + has_meta_param = any( + isinstance(p, FakeTensor) or p.is_meta for p in self.submod.parameters() + ) + if has_meta_param: + logger.debug("%s Found meta parameters!", self.log_prefix) + else: + self.submod.to(self.device) + + def _prepare_forward_infra( + self, + num_microbatches: int, + args: Tuple[Any, ...], + kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[Any, ...]: + """ + Create send/recv infrastructures for activations (during forward) + """ + # TODO(whc) + # this method should be deleted once lazy buffer allocation is implemented + # for now, it ignores args/kwargs becuase it should not need to do shape inference + for chunk in range(num_microbatches): + self.args_recv_info[chunk] = self._create_act_recv_info() + + # Send info during forward for each activation + self.act_send_info = self._create_act_send_info() + return tuple() + + def get_stage_index_of_submod( + self, + submod_name: str, + ): + """ + Given a submodule name, return the stage index of the submodule. + """ + if submod_name not in self.submod_to_stage_index: + raise AssertionError(f"Stage id of {submod_name} not found") + + return self.submod_to_stage_index[submod_name] + + def _create_act_recv_info( + self, + ): + """ + Create a tuple of `_RecvInfo` for inputs to the stage. + """ + + def create_recv_tensor(placeholder, arg_node): + """ + Create a receive buffer for a placeholder. + """ + example_value = placeholder.meta["val"] + if arg_node.op == "placeholder": + # This is a root level placeholder, thus an input argument to the entire model. + # We are likely at stage 0, hence no need to create a receive buffer. + return _RootArgPlaceholder(example_value) + + # Figure out the source stage of this input + while arg_node.target is operator.getitem: + # If the input is a getitem, we need to go deeper + arg_node = arg_node.args[0] + + assert ( + arg_node.op == "call_module" + ), f"Expecting call_module, got {arg_node.op}" + src_stage = self.get_stage_index_of_submod(arg_node.name) + + # Create a receive buffer for this placeholder + logger.debug( + "%s Creating recv buffer for input '%s' : %s, %s", + self.log_prefix, + placeholder.name, + example_value.shape, + example_value.dtype, + ) + buffer = _make_tensor_from_meta(example_value, self.device) + # In case there is backward pass, set requires_grad for receive buffers + # before first forward + if self.has_backward: + buffer.requires_grad_(True) + + return _RecvInfo( + arg_node.name, + src_stage, + buffer, + ) + + args_recv_info: List[InputInfo] = [] + # Filter out placeholder nodes from `self.submod` (a GraphModule) + placeholders = filter( # type: ignore[var-annotated] + lambda node: node.op == "placeholder", self.submod.graph.nodes # type: ignore[arg-type, union-attr] + ) + # `placeholders` are nodes internal to submod. + # `self.node.args` are dependency nodes in the outer graph. + # The two are 1:1. + for placeholder, arg_node in zip(placeholders, self.node.args): + # Create a receive buffer for this placeholder + recv_info = create_recv_tensor(placeholder, arg_node) + args_recv_info.append(recv_info) + + logger.debug( + "%s Activation recv / args info: %s", self.log_prefix, args_recv_info + ) + # `args` is a Tuple, hence we will return a Tuple[InputInfo] + return tuple(args_recv_info) + + def find_dst_rank( + self, + user: fx.Node, + ) -> Optional[int]: + """ + Find the destination rank of a `user` node. + If the `user` is not a submod, `None` may be returned. + """ + if user.op == "call_module": + # User is a stage (`call_module`) + return self.get_stage_index_of_submod(user.name) + else: + # - If user.op == "output": + # No need to send back to rank 0 + # - If user.target is stage_backward: + # No need to send assuming submod output is stored locally or + # should be re-calucated in case of activation checkpointing + return None + + def _create_act_send_info(self): + """ + Create a dict of send info for activations. + The dict is of the form: + { + output_index: [dst_rank_0, dst_rank_1, ...], + ... + } + where the list of `dst_rank`s covers the case where an output value may + be consumed by multiple stages. + """ + # Output index: List of receiver ranks + act_send_info: Dict[int, List] = {} + out_idx = 0 + + for user in self.node.users: + if user.target is operator.getitem: + # Recursively find the real destination + gi_dsts = act_send_info.setdefault(out_idx, []) + for gi_user in user.users: + dst_rank = self.find_dst_rank(gi_user) + if dst_rank is not None: + gi_dsts.append(dst_rank) + # Next `getitem` will point to the next output index + out_idx += 1 + else: + # In case of single output value, `out_idx` will not increase + dsts = act_send_info.setdefault(out_idx, []) + dst_rank = self.find_dst_rank(user) + if dst_rank is not None: + dsts.append(dst_rank) + + output_node = self._get_output_node() + output_vals: Tuple[core.Tensor] = tuple( + v.meta["val"] for v in flatten_args(output_node.args) + ) + self._configure_outputs_meta(output_vals) + + logger.debug("%s Send info: %s", self.log_prefix, act_send_info) + return act_send_info + + def _get_output_node(self): + output_nodes = [node for node in self.submod.graph.nodes if node.op == "output"] # type: ignore[union-attr] + assert len(output_nodes) == 1 + output_node = output_nodes[0] + return output_node + + def _create_grad_recv_info( + self, + act_send_info: Dict, + ) -> Tuple[_RecvInfo, ...]: + """ + Create a tuple of `_RecvInfo` for gradients. + """ + # Dict[output_index, _RecvInfo] + grad_recv_info: Dict[int, _RecvInfo] = {} + output_node = self._get_output_node() + + # The output node may take multiple args, meaning the submod having multiple output values. + output_vals = flatten_args(output_node.args) + + for out_idx, dst_list in act_send_info.items(): + if not dst_list: + # No actual receiver for activation so no grad coming back + continue + + output = output_vals[out_idx] + example_value = output.meta["val"] + logger.debug( + f"{self.log_prefix} Creating grad recv buffer for output {output.name} " # noqa: G004 + f": {example_value.shape}, {example_value.dtype}" + ) + + # TODO: otherwise needs grad accumulation + assert len(dst_list) == 1, "Backward of skip connections not supported yet" + grad_src = dst_list[0] + grad_recv_info[out_idx] = _RecvInfo( + f"{grad_src}", # noqa: G004 + grad_src, + _make_tensor_from_meta(example_value, self.device), + ) + + # Convert to tuple for convenience in get_ops and retrieve tensor + grad_recv_info_tuple = tuple(grad_recv_info.values()) + logger.debug("%s Grad recv info: %s", self.log_prefix, grad_recv_info_tuple) + return grad_recv_info_tuple + + +# A helper function to create a pipeline stage based on traced pipeline information +def build_stage( + stage_module: core.nn.Module, + stage_index: int, + pipe_info: PipeInfo, + device: core.device, + group: Optional[dist.ProcessGroup] = None, +) -> _PipelineStage: + """ + Create a pipeline stage given a stage_module to be wrapped by this stage + and pipeline information. + + Args: + stage_module (core.nn.Module): the module to be wrapped by this stage + stage_index (int): the index of this stage in the pipeline + pipe_info (PipeInfo): information about the pipeline, can be retrieved by `pipe.info()` + device (core.device): the device to be used by this stage + group (Optional[dist.ProcessGroup]): the process group to be used by this stage + + Returns: + _PipelineStage: a pipeline stage that can run with `PipelineSchedules`. + """ + return _PipelineStage( + stage_module, + stage_index, + pipe_info, + device, + group, + ) + + +class PipelineStage(_PipelineStageBase): + """ + A class representing a pipeline stage in a pipeline parallelism setup. + + PipelineStage assumes sequential partitioning of the model, i.e. the model is split into chunks where outputs from + one chunk feed into inputs of the next chunk, with no skip connections. + + PipelineStage performs runtime shape/dtype inference automatically by propagating the outputs from stage0 to + stage1 and so forth, in linear order. To bypass shape inference, pass the `input_args` and `output_args` to each + PipelineStage instance. + + Args: + submodule (nn.Module): The PyTorch module wrapped by this stage. + stage_index (int): The ID of this stage. + num_stages (int): The total number of stages. + device (core.device): The device where this stage is located. + input_args (Union[core.Tensor, Tuple[core.tensor]], optional): The input arguments for the submodule. + output_args (Union[core.Tensor, Tuple[core.tensor]], optional): The output arguments for the submodule. + group (dist.ProcessGroup, optional): The process group for distributed training. If None, default group. + dw_builder: TODO clean up comments + """ + + def __init__( + self, + submodule: nn.Module, + stage_index: int, + num_stages: int, + device: core.device, + input_args: Optional[Union[core.Tensor, Tuple[core.Tensor, ...]]] = None, + output_args: Optional[Union[core.Tensor, Tuple[core.Tensor, ...]]] = None, + group: Optional[dist.ProcessGroup] = None, + dw_builder: Optional[Callable[[], Callable[..., None]]] = None, + ): + super().__init__(submodule, stage_index, num_stages, device, group, dw_builder) + self.inputs: Optional[List[core.Tensor]] = None + self.inputs_meta: Optional[Tuple[core.Tensor, ...]] = None + # Note: inputs and submod should ideally be on meta device. We decided not to assert this (yet) becuase it + # might be breaking for existing users. + if input_args is None: + assert output_args is None, ( + "If specifying output_args, input_args must also be specified. " + "Otherwise, shape inference will be performed at runtime" + ) + else: + self.inputs_meta = ( + (input_args,) if isinstance(input_args, core.Tensor) else input_args + ) + if output_args is None: + logger.warning( + "Deprecation warning: passing input_args and performing init-time shape inference is deprecated. " + "PipelineStage now supports runtime shape inference using the real inputs provided to schedule step(). " + "Either delete `input_args` arg to `PipelineStage` to opt-into runtime shape inference, " + "or additionally pass `output_args` to `PipelineStage` to fully override shape inference. " + ) + try: + with core.no_grad(): + output_args = submodule(*self.inputs_meta) + output_args = tree_map_only( + core.Tensor, lambda x: x.to("meta"), output_args + ) + except Exception as e: + raise RuntimeError( + "Failed to perform pipeline shape inference- are your inputs on the same device as your module?" + ) from e + assert ( + output_args is not None + ), "If passing input_args, also pass output_args to override shape inference" + self._configure_outputs_meta( + (output_args,) if isinstance(output_args, core.Tensor) else output_args + ) + + # these are the buffers used in backwards send/recv, they are allocated later + self.outputs_grad: List[core.Tensor] = [] + + def stage_global_rank(peer_rank): + return ( + peer_rank + if self.group is None + else dist.get_global_rank(self.group, peer_rank) + ) + + self.prev_rank = stage_global_rank((self.group_rank - 1) % self.group_size) + self.next_rank = stage_global_rank((self.group_rank + 1) % self.group_size) + + dbg_str = ( + f"Finished pipeline stage init, {self.stage_index=}, {self.is_first=}, " # noqa: G004 + f"{self.is_last=}, {self.num_stages=}, " + ) + if self.inputs_meta is not None: + dbg_str += ( + f"inputs: {[inp.shape for inp in self.inputs_meta]}, " + f"output: {[output.shape for output in self.get_outputs_meta()]}" + ) + else: + dbg_str += " running shape-inference at runtime" + + logger.debug(dbg_str) + + def _shape_inference( + self, + args: Tuple[Any, ...], + kwargs: Optional[Dict[str, Any]] = None, + ): + if kwargs is None: + kwargs = {} + assert args is not None, "Args may be an empty tuple but not None" + + # We skip recv communication if we're the first stage, but also if the previous stage is on the same rank + # and can pass its output shapes in as args instead of using send/recv. + if ( + self.is_first + # if not first stage, then check if prev stage is on the same rank + or self.stage_index_to_group_rank[self.stage_index - 1] == self.group_rank + ): + logger.debug( + "Shape inference: stage %s skipping recv, because shape info passed in via `args`", + self.stage_index, + ) + args = tree_map_only(core.Tensor, lambda x: x.to("meta"), args) + else: + assert ( + len(args) == 0 + ), "Can't supply input args for shape inference on non-first stage" + objects = [None] + logger.debug( + "Shape inference: stage %s receiving from stage %s", + self.stage_index, + self.stage_index - 1, + ) + dist.recv_object_list( + objects, src=self.prev_rank, group=self.group, device=self.device + ) + recv_args = objects[0] + assert isinstance(recv_args, tuple), type(recv_args) + args = recv_args + + # cache input shapes for use during recv buffer allocation + self.inputs_meta = args + args = tree_map_only( + core.Tensor, lambda x: core.zeros_like(x, device=self.device), args + ) + + # set attributes needed for forward + with core.no_grad(): + logger.debug("Shape inference: stage %s running forward", self.stage_index) + outputs = self.submod(*args, **kwargs) + + # if single tensor, convert so it is always a list + if isinstance(outputs, core.Tensor): + outputs = [outputs] + + # communicate meta outputs not real outputs for two reasons + # 1 - its faster (esp. since obj coll pickles tensor data!) + # 2 - avoid activating a cuda context for the src rank when unpickling on the recv end! + outputs_meta = tuple( + tree_map_only(core.Tensor, lambda x: x.to("meta"), outputs) + ) + self._configure_outputs_meta(outputs_meta) + + # Passing outputs to the next stage: + # two cases- + # 1. Usually: use send/recv communication to pass the output + # 2. Special case: for V-schedules, 2 'adjacent' stages (e.g. stage 3, 4 in an 8-stage 4-rank V) + # pass their shape info via return value and function args rather than send/recv. + if ( + self.is_last + # if not last stage, then check if next stage is on the same rank + or self.stage_index_to_group_rank[self.stage_index + 1] == self.group_rank + ): + # Case (2) above: pass shape info via return value and caller passes it as args to next stage's + # _shape_inference call + logger.debug( + "Shape inference: stage %s skipping send to next stage", + self.stage_index, + ) + + else: + # Case (1): send shapes via send operation, and ensure not to return it to the caller + logger.debug( + "Shape inference: stage %s sending to stage %s", + self.stage_index, + self.stage_index + 1, + ) + dist.send_object_list( + [outputs_meta], + dst=self.next_rank, + group=self.group, + device=self.device, + ) + outputs_meta = tuple() + + return outputs_meta + + def _prepare_forward_infra( + self, + num_microbatches: int, + args: Tuple[Any, ...], + kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[Any, ...]: + # TODO move self.device to an argument from step API (from its input tensors)? + assert num_microbatches is not None, "TODO fix num_microbatches" + + outputs: Tuple[Any, ...] = tuple() + if self.inputs_meta is None: + outputs = self._shape_inference(args, kwargs) + + assert self.inputs_meta is not None + # Receive info during forward + # TODO: create args_recv_info lazily? (same needed for PipelineStage) + for chunk_id in range(num_microbatches): + if not self.is_first: + # We assume that we always receive from stage - 1 + recv_infos = tuple( + [ + _RecvInfo( + f"recv_for_{self.stage_index}_from_{self.stage_index - 1}", + self.stage_index - 1, + _make_tensor_from_meta(inp, self.device), + ) + for inp in self.inputs_meta + ] + ) + # In case there is backward pass, set requires_grad for receive buffers + if self.has_backward: + for r in recv_infos: + r.buffer.requires_grad_(True) + + self.args_recv_info[chunk_id] = recv_infos + else: + self.args_recv_info[chunk_id] = tuple( + [_RootArgPlaceholder(i) for i in self.inputs_meta] + ) + + # Send info during forward for each activation + # only need the rank that is being sent to + self.act_send_info: Dict[int, List] = {} + + for idx in range(len(self.get_outputs_meta())): + # We assume we always send to stage + 1 + if not self.is_last: + self.act_send_info[idx] = [self.stage_index + 1] + else: + self.act_send_info[idx] = [] + + return outputs + + def _create_grad_recv_info( + self, + act_send_info: Dict, + ) -> Tuple[_RecvInfo, ...]: + grad_recv_info: Tuple[_RecvInfo, ...] = () + if not self.is_last: + # Receiving gradients from multiple sources is not supported + # hence we only take the first destination + grad_recv_info = tuple( + [ + _RecvInfo( + f"recv_grad_for_{self.stage_index}_from_{dst_list[0]}", + dst_list[0], + _make_tensor_from_meta( + self.get_outputs_meta()[idx], self.device + ), + ) + for idx, dst_list in act_send_info.items() + ] + ) + return grad_recv_info + + def _init_p2p_neighbors(self): + """ + Set up p2p communitors between previous and next stages + by sending a dummy tensor. + + If this is used, must be called for all pipeline stages. + """ + ops = [] + recv_tensor = core.zeros(1, device="cuda") + send_tensor = core.ones(1, device="cuda") + # forward + if not self.is_first: + ops.append(dist.P2POp(dist.irecv, recv_tensor, self.prev_rank, self.group)) + if not self.is_last: + ops.append(dist.P2POp(dist.isend, send_tensor, self.next_rank, self.group)) + + # backward + if not self.is_first: + ops.append(dist.P2POp(dist.isend, send_tensor, self.prev_rank, self.group)) + if not self.is_last: + ops.append(dist.P2POp(dist.irecv, recv_tensor, self.next_rank, self.group)) + + return True diff --git a/mindnlp/core/distributed/remote_device.py b/mindnlp/core/distributed/remote_device.py new file mode 100644 index 000000000..798d43197 --- /dev/null +++ b/mindnlp/core/distributed/remote_device.py @@ -0,0 +1,120 @@ +# mypy: allow-untyped-defs +from typing import Optional, Union + +from mindnlp import core + + +class _remote_device: + """ + Represents a device on a remote worker. + + Args: + remote_device (str or core.device): Represents a device on a remote worker. + The string format should be one of the following: + + 1. "/", where the device field can be parsed as core.device type. + E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0". + In addition, the device field can be optional and the default value is "cpu". + 2. "rank:/", where is the rank of the + process and device can be parsed as core.device type. + E.g., "rank:0/cpu", "rank:0", "rank:0/cuda:0" + 3. and are optional and formats like "cpu" + and "cuda:1", just represent local devices. + """ + + def __init__(self, remote_device: Union[str, core.device]): + PARSE_ERROR = ( + f"Could not parse remote_device: {remote_device}. The valid format is " + "'/' or 'rank:/' or ''" + ) + self._worker_name = None + self._rank = None + self._device: Optional[Union[str, int, core.device]] = None + + if isinstance(remote_device, core.device): + self._device = remote_device + elif isinstance(remote_device, str): + fields = remote_device.split("/") + if len(fields) == 2: + self._worker_name, self._device = fields + elif len(fields) == 1: + # Check if this is a valid device. + if _remote_device._is_valid_local_device(fields[0]): + self._device = fields[0] + else: + self._worker_name = fields[0] + self._device = "cpu" + else: + raise ValueError(PARSE_ERROR) + else: + raise TypeError(f"Invalid type for remote_device: {type(remote_device)}") + + # Do some basic sanity check (no empty string) + if self._worker_name is not None and not self._worker_name: + raise ValueError(PARSE_ERROR) + + # Validate the device. + self._device = core.device(self._device) + + # Check for rank based format. + if self._worker_name is not None: + fields = self._worker_name.split(":") + if len(fields) == 2: + # rank:/device format, extract rank + if fields[0] == "rank" and fields[1].isdigit(): + self._rank = int(fields[1]) # type: ignore[assignment] + self._worker_name = None + else: + raise ValueError(PARSE_ERROR) + elif len(fields) > 2: + raise ValueError(PARSE_ERROR) + + @staticmethod + def _is_valid_local_device(device): + # Check for core.device + try: + core.device(device) + return True + except Exception: + return False + + def worker_name(self) -> Optional[str]: + """Return the name of remote worker representing the remote device and ``None`` if no worker name is available.""" + return self._worker_name + + def rank(self) -> Optional[int]: + """ + Returns the rank of remote worker representing the remote device. + Returns ``None`` if no rank is available. + """ + return self._rank + + def device(self) -> core.device: + """Return the local device on the remote worker.""" + return self._device # type: ignore[return-value] + + def __repr__(self): + if self._device is not None: + if self._worker_name is not None: + return f"{self._worker_name}/{self._device}" + elif self._rank is not None: + return f"rank:{self._rank}/{self._device}" + else: + return str(self._device) + else: + if self._worker_name is not None: + return f"{self._worker_name}" + elif self._rank is not None: + return f"{self._rank}" + else: + raise RuntimeError("Invalid state!") + + def __eq__(self, other): + return isinstance(other, _remote_device) and ( + self._worker_name == other._worker_name + and self._device == other._device + and self._rank == other._rank + ) + + def __hash__(self): + return hash(self._worker_name) ^ hash(self._device) ^ hash(self._rank) diff --git a/mindnlp/core/distributed/rendezvous.py b/mindnlp/core/distributed/rendezvous.py new file mode 100644 index 000000000..57883364d --- /dev/null +++ b/mindnlp/core/distributed/rendezvous.py @@ -0,0 +1,286 @@ +# mypy: allow-untyped-defs +try: + from urllib.parse import urlparse, urlunparse +except ImportError as e: + raise ImportError( + "urllib cannot be found, urlparse from python2 is no longer supported." + ) from e + +import numbers +import os +import sys +from datetime import timedelta +from typing import Callable, Dict, Iterator, Optional, Tuple + +from core.distributed import FileStore, Store, TCPStore + +from .constants import default_pg_timeout + + +_rendezvous_handlers: Dict[str, Callable[..., Iterator[Tuple[Store, int, int]]]] = {} + +__all__ = ["register_rendezvous_handler", "rendezvous"] + + +def register_rendezvous_handler(scheme, handler): + """ + Register a new rendezvous handler. + + Before we can run collective algorithms, participating processes + need to find each other and exchange information to be able to + communicate. We call this process rendezvous. + + The outcome of the rendezvous process is a triplet containing a + shared key/value store, the rank of the process, and the total + number of participating processes. + + If none of the bundled rendezvous methods apply to your execution + environment you can opt to register your own rendezvous handler. + Pick a unique name and use the URL scheme to identify it when + calling the `rendezvous()` function. + + Args: + scheme (str): URL scheme to identify your rendezvous handler. + handler (function): Handler that is invoked when the + `rendezvous()` function is called with a URL that uses + the corresponding scheme. It must be a generator function + that yields the triplet. + """ + global _rendezvous_handlers + if scheme in _rendezvous_handlers: + raise RuntimeError(f"Rendezvous handler for {scheme}:// already registered") + _rendezvous_handlers[scheme] = handler + + +# Query will have format "rank=0&world_size=1" and is +# converted into {"rank": 0, "world_size": 1} +def _query_to_dict(query: str) -> Dict[str, str]: + return { + pair[0]: pair[1] + for pair in (pair.split("=") for pair in filter(None, query.split("&"))) + } + + +def _get_use_libuv_from_query_dict(query_dict: Dict[str, str]) -> bool: + # libuv is the default backend for TCPStore. To enable the non-libuv backend, + # user can explicitly specify ``use_libuv=0`` in the URL parameter. + return query_dict.get("use_libuv", os.environ.get("USE_LIBUV", "1")) == "1" + + +def _rendezvous_helper(url: str, rank: int, world_size_opt: Optional[int], **kwargs): + result = urlparse(url) + if world_size_opt is None: + world_size = -1 + if result.scheme == "env": + rank = int(os.environ.get("RANK", rank)) + # If the world_size env variable is not present then it is a dynamic group + world_size = int(os.environ.get("WORLD_SIZE", world_size)) + else: + world_size = world_size_opt + if rank != -1 or world_size != -1 or world_size_opt is None: + query_dict = _query_to_dict(result.query) + assert ( + "rank" not in query_dict and "world_size" not in query_dict + ), f"The url: {url} has node-specific arguments(rank, world_size) already." + if rank != -1: + query_dict["rank"] = str(rank) + if world_size != -1 or world_size_opt is None: + query_dict["world_size"] = str(world_size) + result = result._replace( + query=f"{'&'.join([f'{k}={v}' for k, v in query_dict.items()])}" + ) + url = urlunparse(result) + + if result.scheme not in _rendezvous_handlers: + raise RuntimeError(f"No rendezvous handler for {result.scheme}://") + return _rendezvous_handlers[result.scheme](url, **kwargs) + + +def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs): + if not isinstance(url, (str, bytes)): + raise RuntimeError(f"`url` must be a string. {type(url)}: {url}") + + if not isinstance(rank, numbers.Integral): + raise RuntimeError(f"`rank` must be an integer. {rank}") + + if not isinstance(world_size, numbers.Integral): + raise RuntimeError(f"`world_size` must be an integer. {world_size}") + + return _rendezvous_helper(url, rank, world_size, **kwargs) + + +def _create_store_from_options(backend_options, rank): + store, _, _ = next(_rendezvous_helper(backend_options.init_method, rank, None)) + return store + + +def _rendezvous_error(msg): + return ValueError("Error initializing core.distributed using " + msg) + + +def _file_rendezvous_handler(url: str, **kwargs): + def _error(msg): + return _rendezvous_error("file:// rendezvous: " + msg) + + result = urlparse(url) + path = result.path + if sys.platform == "win32": + import urllib.request + + full_path = result.netloc + result.path + path = urllib.request.url2pathname(full_path) + if path: + # Normalizing an empty string produces ".", which is not expected. + path = os.path.normpath(path) + + if not path: + raise _error("path missing") + query_dict = _query_to_dict(result.query) + if "rank" not in query_dict: + raise _error("rank parameter missing") + if "world_size" not in query_dict: + raise _error("world size parameter missing") + + rank = int(query_dict["rank"]) + world_size = int(query_dict["world_size"]) + store = FileStore(path, world_size) + yield (store, rank, world_size) + + # If this configuration is invalidated, there is nothing we can do about it + raise RuntimeError("Unable to perform rerendezvous using file:// method") + + +def _torchelastic_use_agent_store() -> bool: + return os.environ.get("TORCHELASTIC_USE_AGENT_STORE", None) == str(True) + + +def _create_c10d_store( + hostname, port, rank, world_size, timeout, use_libuv=True +) -> Store: + """ + Smartly creates a c10d Store object on ``rank`` based on whether we need to re-use agent store. + + The TCPStore server is assumed to be hosted + on ``hostname:port``. + + By default, the TCPStore server uses the asynchronous implementation + ``LibUVStoreDaemon`` which utilizes libuv. + + If ``torchelastic_use_agent_store()`` is ``True``, then it is assumed that + the agent leader (node rank 0) hosts the TCPStore server (for which the + endpoint is specified by the given ``hostname:port``). Hence + ALL ranks will create and return a TCPStore client (e.g. ``start_daemon=False``). + + If ``torchelastic_use_agent_store()`` is ``False``, then rank 0 will host + the TCPStore (with multi-tenancy) and it is assumed that rank 0's hostname + and port are correctly passed via ``hostname`` and ``port``. All + non-zero ranks will create and return a TCPStore client. + """ + # check if port is uint16_t + if not 0 <= port < 2**16: + raise ValueError(f"port must have value from 0 to 65535 but was {port}.") + + if _torchelastic_use_agent_store(): + # We create a new TCPStore for every retry so no need to add prefix for each attempt. + return TCPStore( + host_name=hostname, + port=port, + world_size=world_size, + is_master=False, + timeout=timeout, + ) + else: + start_daemon = rank == 0 + return TCPStore( + host_name=hostname, + port=port, + world_size=world_size, + is_master=start_daemon, + timeout=timeout, + multi_tenant=True, + use_libuv=use_libuv, + ) + + +def _tcp_rendezvous_handler( + url: str, timeout: timedelta = default_pg_timeout, **kwargs +): + def _error(msg): + return _rendezvous_error("tcp:// rendezvous: " + msg) + + result = urlparse(url) + if not result.port: + raise _error("port number missing") + query_dict = _query_to_dict(result.query) + if "rank" not in query_dict: + raise _error("rank parameter missing") + if "world_size" not in query_dict: + raise _error("world size parameter missing") + + rank = int(query_dict["rank"]) + world_size = int(query_dict["world_size"]) + use_libuv = _get_use_libuv_from_query_dict(query_dict) + + assert result.hostname is not None + + store = _create_c10d_store( + result.hostname, result.port, rank, world_size, timeout, use_libuv + ) + + yield (store, rank, world_size) + + # If this configuration is invalidated, there is nothing we can do about it + raise RuntimeError("Unable to perform re-rendezvous using tcp:// method") + + +def _env_rendezvous_handler( + url: str, timeout: timedelta = default_pg_timeout, **kwargs +): + def _error(msg): + return _rendezvous_error("env:// rendezvous: " + msg) + + def _env_error(var): + return _error(f"environment variable {var} expected, but not set") + + def _get_env_or_raise(env_var: str) -> str: + env_val = os.environ.get(env_var, None) + if not env_val: + raise _env_error(env_var) + else: + return env_val + + result = urlparse(url) + query_dict = _query_to_dict(result.query) + + rank: int + world_size: int + master_port: int + master_addr: str + + if "rank" in query_dict: + rank = int(query_dict["rank"]) + else: + rank = int(_get_env_or_raise("RANK")) + + if "world_size" in query_dict: + world_size = int(query_dict["world_size"]) + else: + world_size = int(_get_env_or_raise("WORLD_SIZE")) + + master_addr = _get_env_or_raise("MASTER_ADDR") + master_port = int(_get_env_or_raise("MASTER_PORT")) + use_libuv = _get_use_libuv_from_query_dict(query_dict) + + store = _create_c10d_store( + master_addr, master_port, rank, world_size, timeout, use_libuv + ) + + yield (store, rank, world_size) + + # If this configuration is invalidated, there is nothing we can do about it + raise RuntimeError("Unable to perform re-rendezvous using env:// method") + + +register_rendezvous_handler("tcp", _tcp_rendezvous_handler) +register_rendezvous_handler("env", _env_rendezvous_handler) +register_rendezvous_handler("file", _file_rendezvous_handler) diff --git a/mindnlp/core/distributed/rpc/__init__.py b/mindnlp/core/distributed/rpc/__init__.py new file mode 100644 index 000000000..3200e5efb --- /dev/null +++ b/mindnlp/core/distributed/rpc/__init__.py @@ -0,0 +1,249 @@ +# mypy: allow-untyped-defs +import logging +import os +import threading +import warnings +from datetime import timedelta +from typing import Generator, Tuple +from urllib.parse import urlparse + +from mindnlp import core +from mindnlp import core.distributed as dist + + +__all__ = ["is_available"] + + +logger = logging.getLogger(__name__) + + +_init_counter = 0 +_init_counter_lock = threading.Lock() + + +def is_available() -> bool: + return False + + +# if is_available() and not core._C._rpc_init(): +# raise RuntimeError("Failed to initialize core.distributed.rpc") + + +if is_available(): + import numbers + + from mindnlp import core.distributed.autograd as dist_autograd + from ..c10d import Store + from core._C._distributed_rpc import ( # noqa: F401 + _cleanup_python_rpc_handler, + _DEFAULT_INIT_METHOD, + _DEFAULT_NUM_WORKER_THREADS, + _DEFAULT_RPC_TIMEOUT_SEC, + _delete_all_user_and_unforked_owner_rrefs, + _destroy_rref_context, + _disable_jit_rref_pickle, + _disable_server_process_global_profiler, + _enable_jit_rref_pickle, + _enable_server_process_global_profiler, + _get_current_rpc_agent, + _invoke_remote_builtin, + _invoke_remote_python_udf, + _invoke_remote_torchscript, + _invoke_rpc_builtin, + _invoke_rpc_python_udf, + _invoke_rpc_torchscript, + _is_current_rpc_agent_set, + _reset_current_rpc_agent, + _rref_context_get_debug_info, + _set_and_start_rpc_agent, + _set_profiler_node_id, + _set_rpc_timeout, + _TensorPipeRpcBackendOptionsBase, + _UNSET_RPC_TIMEOUT, + enable_gil_profiling, + get_rpc_timeout, + PyRRef, + RemoteProfilerManager, + RpcAgent, + RpcBackendOptions, + TensorPipeAgent, + WorkerInfo, + ) + + from . import api, backend_registry, functions + from .api import * # noqa: F401,F403 + from .backend_registry import BackendType + from .options import TensorPipeRpcBackendOptions # noqa: F401 + from .server_process_global_profiler import _server_process_global_profile + + rendezvous_iterator: Generator[Tuple[Store, int, int], None, None] + + __all__ += ["init_rpc", "BackendType", "TensorPipeRpcBackendOptions"] + __all__ = __all__ + api.__all__ + backend_registry.__all__ # noqa: PLE0605 + + def init_rpc( + name, + backend=None, + rank=-1, + world_size=None, + rpc_backend_options=None, + ): + r""" + Initializes RPC primitives such as the local RPC agent + and distributed autograd, which immediately makes the current + process ready to send and receive RPCs. + + Args: + name (str): a globally unique name of this node. (e.g., + ``Trainer3``, ``ParameterServer2``, ``Master``, ``Worker1``) + Name can only contain number, alphabet, underscore, colon, + and/or dash, and must be shorter than 128 characters. + backend (BackendType, optional): The type of RPC backend + implementation. Supported values is + ``BackendType.TENSORPIPE`` (the default). + See :ref:`rpc-backends` for more information. + rank (int): a globally unique id/rank of this node. + world_size (int): The number of workers in the group. + rpc_backend_options (RpcBackendOptions, optional): The options + passed to the RpcAgent constructor. It must be an agent-specific + subclass of :class:`~core.distributed.rpc.RpcBackendOptions` + and contains agent-specific initialization configurations. By + default, for all agents, it sets the default timeout to 60 + seconds and performs the rendezvous with an underlying process + group initialized using ``init_method = "env://"``, + meaning that environment variables ``MASTER_ADDR`` and + ``MASTER_PORT`` need to be set properly. See + :ref:`rpc-backends` for more information and find which options + are available. + """ + core._C._log_api_usage_once("core.distributed.init_rpc") + if backend is not None and not isinstance( + backend, backend_registry.BackendType + ): + raise TypeError("Argument backend must be a member of BackendType") + + if rpc_backend_options is not None and not isinstance( + rpc_backend_options, RpcBackendOptions + ): + raise TypeError( + "Argument rpc_backend_options must be an instance of RpcBackendOptions" + ) + + # Try to detect the backend from the options + if backend is None and rpc_backend_options is not None: + for candidate_backend in BackendType: + if isinstance( + rpc_backend_options, + type( + backend_registry.construct_rpc_backend_options( + candidate_backend + ) + ), + ): + backend = candidate_backend + break + else: + raise TypeError( + f"Could not infer backend for options {rpc_backend_options}" + ) + # Ignore type error because mypy doesn't handle dynamically generated type objects (#4865) + if backend != BackendType.TENSORPIPE: # type: ignore[attr-defined] + logger.warning( + "RPC was initialized with no explicit backend but with options " # type: ignore[attr-defined] + "corresponding to %(backend)s, hence that backend will be used " + "instead of the default BackendType.TENSORPIPE. To silence this " + "warning pass `backend=%(backend)s` explicitly.", + {"backend": backend}, + ) + + if backend is None: + backend = BackendType.TENSORPIPE # type: ignore[attr-defined] + + if rpc_backend_options is None: + # default construct a set of RPC backend options. + rpc_backend_options = backend_registry.construct_rpc_backend_options( + backend + ) + + # Create store, performs rendezvous for static RPC group. + if not world_size: + # If world_size is not set in construction and also not set in environment variables + # The store will be created for the dynamic group setting + store = dist._create_store_from_options(rpc_backend_options, rank) + else: + # This rendezvous state sometimes is destroyed before all processes + # finishing handshaking. To avoid that issue, we make it global to + # keep it alive. + global rendezvous_iterator + rendezvous_iterator = dist.rendezvous( + rpc_backend_options.init_method, rank=rank, world_size=world_size + ) + store, _, _ = next(rendezvous_iterator) + # Use same timeout as RPC. + store.set_timeout(timedelta(seconds=rpc_backend_options.rpc_timeout)) + + # Use a PrefixStore to distinguish multiple invocations. + with _init_counter_lock: + global _init_counter + store = dist.PrefixStore(str(f"rpc_prefix_{_init_counter}"), store) + _init_counter += 1 + + # Initialize autograd before RPC since _init_rpc_backend guarantees all + # processes sync via the store. If we initialize autograd after RPC, + # there could be a race where some nodes might have initialized autograd + # and others might not have. As a result, a node calling + # core.distributed.autograd.backward() would run into errors since + # other nodes might not have been initialized. + dist_autograd._init(rank) + + _set_profiler_node_id(rank) + # Initialize RPC. + _init_rpc_backend(backend, store, name, rank, world_size, rpc_backend_options) + + def _validate_rpc_args(backend, store, name, rank, world_size, rpc_backend_options): + type_mapping = { + backend: backend_registry.BackendType, + store: dist.Store, + name: str, + rank: numbers.Integral, + # world_size can be None for a dynamic group + world_size: (numbers.Integral, type(None)), + rpc_backend_options: RpcBackendOptions, + } + for arg, arg_type in type_mapping.items(): + if not isinstance(arg, arg_type): # type: ignore[arg-type] + raise RuntimeError( + f"Argument {arg} must be of type {arg_type} but got type {type(arg)}" + ) + + def _init_rpc_backend( + backend=BackendType.TENSORPIPE, # type: ignore[attr-defined] + store=None, + name=None, + rank=-1, + world_size=None, + rpc_backend_options=None, + ): + _validate_rpc_args(backend, store, name, rank, world_size, rpc_backend_options) + + if _is_current_rpc_agent_set(): + raise RuntimeError("RPC is already initialized") + + # Initialize RPC. + rpc_agent = backend_registry.init_backend( + backend, + store=store, + name=name, + rank=rank, + world_size=world_size, + rpc_backend_options=rpc_backend_options, + ) + + api._init_rpc_states(rpc_agent) + + @api._require_initialized + def _get_debug_info(): + info = _rref_context_get_debug_info() + info.update(api._get_current_rpc_agent().get_debug_info()) + info.update(dist_autograd._get_debug_info()) + return info diff --git a/mindnlp/core/distributed/rpc/_testing/__init__.py b/mindnlp/core/distributed/rpc/_testing/__init__.py new file mode 100644 index 000000000..5cdc5af17 --- /dev/null +++ b/mindnlp/core/distributed/rpc/_testing/__init__.py @@ -0,0 +1,20 @@ +# mypy: allow-untyped-defs + +from mindnlp import core + + +def is_available(): + return hasattr(core._C, "_faulty_agent_init") + + +if is_available() and not core._C._faulty_agent_init(): + raise RuntimeError("Failed to initialize core.distributed.rpc._testing") + +if is_available(): + # Registers FAULTY_TENSORPIPE RPC backend. + from core._C._distributed_rpc_testing import ( + FaultyTensorPipeAgent, + FaultyTensorPipeRpcBackendOptions, + ) + + from . import faulty_agent_backend_registry diff --git a/mindnlp/core/distributed/rpc/_testing/faulty_agent_backend_registry.py b/mindnlp/core/distributed/rpc/_testing/faulty_agent_backend_registry.py new file mode 100644 index 000000000..bfbc84c1b --- /dev/null +++ b/mindnlp/core/distributed/rpc/_testing/faulty_agent_backend_registry.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +from mindnlp import core.distributed as dist +from mindnlp import core.distributed.rpc as rpc + + +def _faulty_tensorpipe_construct_rpc_backend_options_handler( + rpc_timeout, + init_method, + num_worker_threads, + messages_to_fail, + messages_to_delay, + num_fail_sends, + **kwargs, +): + from . import FaultyTensorPipeRpcBackendOptions + + return FaultyTensorPipeRpcBackendOptions( + num_worker_threads=num_worker_threads, + rpc_timeout=rpc_timeout, + init_method=init_method, + messages_to_fail=messages_to_fail, + messages_to_delay=messages_to_delay, + num_fail_sends=num_fail_sends, + ) + + +def _faulty_tensorpipe_init_backend_handler( + store, name, rank, world_size, rpc_backend_options +): + from core.distributed.rpc import api + + from . import FaultyTensorPipeAgent, FaultyTensorPipeRpcBackendOptions + + if not isinstance(store, dist.Store): + raise TypeError(f"`store` must be a c10d::Store. {store}") + + if not isinstance(rpc_backend_options, FaultyTensorPipeRpcBackendOptions): + raise TypeError( + f"`rpc_backend_options` must be a `FaultyTensorPipeRpcBackendOptions`. {rpc_backend_options}" + ) + + agent = FaultyTensorPipeAgent( + store, + name, + rank, + world_size, + rpc_backend_options, + {}, # reverse_device_map + [], # devices + ) + api._init_rpc_states(agent) + + return agent + + +rpc.backend_registry.register_backend( + "FAULTY_TENSORPIPE", + _faulty_tensorpipe_construct_rpc_backend_options_handler, + _faulty_tensorpipe_init_backend_handler, +) diff --git a/mindnlp/core/distributed/rpc/_utils.py b/mindnlp/core/distributed/rpc/_utils.py new file mode 100644 index 000000000..d1138478b --- /dev/null +++ b/mindnlp/core/distributed/rpc/_utils.py @@ -0,0 +1,47 @@ +# mypy: allow-untyped-defs +import logging +from contextlib import contextmanager +from typing import cast + +from . import api, TensorPipeAgent + + +logger = logging.getLogger(__name__) + + +@contextmanager +def _group_membership_management(store, name, is_join): + token_key = "RpcGroupManagementToken" + join_or_leave = "join" if is_join else "leave" + my_token = f"Token_for_{name}_{join_or_leave}" + while True: + # Retrieve token from store to signal start of rank join/leave critical section + returned = store.compare_set(token_key, "", my_token).decode() + if returned == my_token: + # Yield to the function this context manager wraps + yield + # Finished, now exit and release token + # Update from store to signal end of rank join/leave critical section + store.set(token_key, "") + # Other will wait for this token to be set before they execute + store.set(my_token, "Done") + break + else: + # Store will wait for the token to be released + try: + store.wait([returned]) + except RuntimeError: + logger.error( + "Group membership token %s timed out waiting for %s to be released.", + my_token, + returned, + ) + raise + + +def _update_group_membership(worker_info, my_devices, reverse_device_map, is_join): + agent = cast(TensorPipeAgent, api._get_current_rpc_agent()) + ret = agent._update_group_membership( + worker_info, my_devices, reverse_device_map, is_join + ) + return ret diff --git a/mindnlp/core/distributed/rpc/api.py b/mindnlp/core/distributed/rpc/api.py new file mode 100644 index 000000000..c570e489f --- /dev/null +++ b/mindnlp/core/distributed/rpc/api.py @@ -0,0 +1,965 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs + +import collections +import contextlib +import functools +import inspect +import logging +import threading +from typing import Any, Dict, Generic, Set, TYPE_CHECKING, TypeVar + +from mindnlp import core +from core._C._distributed_rpc import ( + _cleanup_python_rpc_handler, + _delete_all_user_and_unforked_owner_rrefs, + _destroy_rref_context, + _get_current_rpc_agent, + _invoke_remote_builtin, + _invoke_remote_python_udf, + _invoke_remote_torchscript, + _invoke_rpc_builtin, + _invoke_rpc_python_udf, + _invoke_rpc_torchscript, + _is_current_rpc_agent_set, + _reset_current_rpc_agent, + _set_and_start_rpc_agent, + get_rpc_timeout, + PyRRef, + RemoteProfilerManager, + TensorPipeAgent, + WorkerInfo, +) +from core.futures import Future + +from ._utils import _group_membership_management, _update_group_membership +from .constants import DEFAULT_SHUTDOWN_TIMEOUT, UNSET_RPC_TIMEOUT +from .internal import ( + _build_rpc_profiling_key, + _internal_rpc_pickler, + PythonUDF, + RPCExecMode, +) + + +__all__ = [ + "shutdown", + "get_worker_info", + "remote", + "rpc_sync", + "rpc_async", + "RRef", + "AllGatherStates", + "method_factory", + "new_method", +] + + +logger = logging.getLogger(__name__) + +# NB: Ignoring RRef leaks during shutdown. Without this, applications have to +# make sure there is no references to any RRef in the application code and +# Python GC has done its job to delete those RRefs. This is could result in bad +# debugging experiences especially when for large applications. Therefore, by +# default, we are going to ignore RRef leaks during shutdown. This is usually +# fine as shutdown means applications have done training and no longer care +# about states. +# +# To enable RRef leak checking, set this _ignore_rref_leak to False +_ignore_rref_leak = True +_default_pickler = _internal_rpc_pickler + + +@contextlib.contextmanager +def _use_rpc_pickler(rpc_pickler): + r""" + rpc_pickler: (.internal._InternalRPCPickler) Overrides the default RPC pickler + """ + global _default_pickler + _default_pickler = rpc_pickler + try: + yield + finally: + _default_pickler = _internal_rpc_pickler + + +def _require_initialized(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if not _is_current_rpc_agent_set(): + raise RuntimeError( + "RPC has not been initialized. Call " + "core.distributed.rpc.init_rpc first." + ) + return func(*args, **kwargs) + + return wrapper + + +class AllGatherStates: + def __init__(self): + # Each `gathered_objects` is an empty dict at beginning. + # The leader worker is elected as the first worker in a sorted worker + # name list. Whenever there is a worker entering `_all_gather()`, it + # runs `_gather_to_leader()` on the leader to add its own name and + # data obj to this dict. The leader also adds itself's name to the dict + # on calling `_all_gather()`. + # Once `set(gathered_objects.keys()) == _ALL_WORKER_NAMES`, the leader + # will broadcast the gathered dict to all follower workers and set their + # `gathered_objects` field and the `proceed_signal` field. + self.gathered_objects = {} + # All workers wait on this signal until it receives all gathered + # objects. + self.proceed_signal = threading.Event() + + +# States used by `def _all_gather()`. +# `_ALL_WORKER_NAMES` is initialized on initializing RPC layer. +_ALL_WORKER_NAMES: Set[Any] = set() +_all_gather_dict_lock = threading.RLock() +_all_gather_sequence_id: Dict[str, int] = {} +_all_gather_sequence_id_to_states: collections.defaultdict = collections.defaultdict( + AllGatherStates +) + + +def _init_rpc_states(agent): + worker_infos = agent.get_worker_infos() + global _ALL_WORKER_NAMES + _ALL_WORKER_NAMES = {worker_info.name for worker_info in worker_infos} + + # NB: backend implementation might have already set the rpc_agent. + if not _is_current_rpc_agent_set(): + _set_and_start_rpc_agent(agent) + + +def _gather_to_leader(sequence_id, worker_name, obj, worker_names=None): + with _all_gather_dict_lock: + if not worker_names: + worker_names = _ALL_WORKER_NAMES + assert ( + worker_name in worker_names + ), f"{worker_name} is not expected by leader." + states = _all_gather_sequence_id_to_states[sequence_id] + assert ( + worker_name not in states.gathered_objects + ), f"{worker_name} reported intent sequence id {sequence_id} twice. " + states.gathered_objects[worker_name] = obj + if worker_names == set(states.gathered_objects.keys()): + states.proceed_signal.set() + + +def _broadcast_to_followers(sequence_id, objects_map): + with _all_gather_dict_lock: + states = _all_gather_sequence_id_to_states[sequence_id] + + assert ( + not states.proceed_signal.is_set() + ), f"Termination signal sequence id {sequence_id} got set twice." + states.gathered_objects = objects_map + states.proceed_signal.set() + + +_thread_local_var = threading.local() + + +@contextlib.contextmanager +def _wait_all(): + r""" + A context manager that collects all futures returned by ``rpc_async`` and + waits them on the context manager's exit; relieving the user of needing + to explicitly call wait. + + + Example:: + >>> # xdoctest: +SKIP("distributed") + >>> # On worker 0: + >>> from mindnlp import core + >>> from mindnlp import core.distributed.rpc as rpc + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> with rpc._wait_all(): + >>> fut_1 = rpc.rpc_async(dst, core.add, (core.ones(2, 2), 1)) + >>> fut_2 = rpc.rpc_async(dst, core.add, (core.ones(2, 2), 1)) + >>> #fut_1 and fut_2 are waited on + """ + _thread_local_var.future_list = [] + try: + yield + finally: + try: + core.futures.wait_all(_thread_local_var.future_list) + finally: + del _thread_local_var.future_list + + +@_require_initialized +def _all_gather(obj, worker_names=None, timeout: float = UNSET_RPC_TIMEOUT): + r""" + This is similar to core.distributed.all_gather(), but is using RPC. It + picks the worker with the smallest name (alphabetic order) as the leader. + Then all followers send their data ``obj`` to the leader. After the leader + has received all, it will broadcast the results back to all followers. This + function blocks until all workers have received the gathered results. + """ + if not worker_names: + assert ( + _ALL_WORKER_NAMES is not None + ), "`_ALL_WORKER_NAMES` is not initialized for `def _all_gather`." + worker_names = _ALL_WORKER_NAMES + leader_name = min(worker_names) + + self_name = _get_current_rpc_agent().get_worker_info().name + + with _all_gather_dict_lock: + concat_names = "".join(sorted(worker_names)) + sequence_num = _all_gather_sequence_id.get(concat_names, 0) + _all_gather_sequence_id[concat_names] = sequence_num + 1 + sequence_id = concat_names + str(sequence_num) + + is_leader = leader_name == self_name + + if timeout == UNSET_RPC_TIMEOUT: + # Timeout is specified by agent for RPC calls + rpc_timeout = get_rpc_timeout() + # No timeout for signal + signal_timeout = None + elif timeout == DEFAULT_SHUTDOWN_TIMEOUT: + # No timeout for RPC + rpc_timeout = timeout + # No timeout for signal + signal_timeout = None + else: + # Signal and RPC timeout use the same timeout + signal_timeout = rpc_timeout = timeout + + # Phase 1: Followers send it's object to the leader + if is_leader: + _gather_to_leader(sequence_id, self_name, obj, worker_names) + else: + rpc_sync( + leader_name, + _gather_to_leader, + args=(sequence_id, self_name, obj, worker_names), + timeout=rpc_timeout, + ) + + with _all_gather_dict_lock: + states = _all_gather_sequence_id_to_states[sequence_id] + + # Timeout is either set by function parameter or None (which is indefinite) + states.proceed_signal.wait(timeout=signal_timeout) + + # Phase 2: Leader broadcast gathered results to all followers + # Leader's signal is the first to be unblocked, after receiving all + # followers' data objects. + if is_leader: + worker_name_to_response_future_dict = {} + for follower_name in worker_names - {leader_name}: + fut = rpc_async( + follower_name, + _broadcast_to_followers, + args=(sequence_id, states.gathered_objects), + timeout=rpc_timeout, + ) + worker_name_to_response_future_dict[follower_name] = fut + + errors = [] + for follower_name, fut in worker_name_to_response_future_dict.items(): + try: + fut.wait() + except RuntimeError as ex: + errors.append((follower_name, ex)) + + if errors: + raise RuntimeError( + f"Followers {[e[0] for e in errors]} timed out in _all_gather " + f"after {rpc_timeout:.2f} seconds. The first exception is {errors[0][1]}" + ) + + # Clean up for the states using the sequence_id + with _all_gather_dict_lock: + states = _all_gather_sequence_id_to_states.pop(sequence_id) + return states.gathered_objects + + +@_require_initialized +def _barrier(worker_names): + r""" + Synchronizes local and remote RPC processes. + + This will block until all local and remote RPC processes specified under worker_names + reach this method to wait for all outstanding work to complete. + + Args: + worker_names (List[str]): The set of workers to synchronize. + + """ + try: + _all_gather(None, set(worker_names)) + except RuntimeError as ex: + logger.error("Failed to complete barrier, got error %s", ex) + + +@_require_initialized +def _wait_all_workers(timeout=DEFAULT_SHUTDOWN_TIMEOUT): + r""" + Block until all local and remote RPC processes reach this method and wait + for all outstanding work to complete. Every RPC process must call this + method before exit to perform a graceful shutdown. This should be used to + terminate the RPC framework, and there is no guarantee that the RPC + framework will work after this method returns. + """ + try: + _all_gather(None, timeout=timeout) + except RuntimeError as ex: + logger.error( + "Failed to respond to 'Shutdown Proceed' in time, got error %s", ex + ) + raise ex + + +@_require_initialized +def shutdown(graceful=True, timeout=DEFAULT_SHUTDOWN_TIMEOUT): + r""" + Perform a shutdown of the RPC agent, and then destroy the RPC agent. This + stops the local agent from accepting outstanding requests, and shuts + down the RPC framework by terminating all RPC threads. If ``graceful=True``, + this will block until all local and remote RPC processes reach this method + and wait for all outstanding work to complete. Otherwise, if + ``graceful=False``, this is a local shutdown, and it does not wait for other + RPC processes to reach this method. + + .. warning:: + For :class:`~core.futures.Future` objects returned by + :meth:`~core.distributed.rpc.rpc_async`, ``future.wait()`` should not + be called after ``shutdown()``. + + Args: + graceful (bool): Whether to do a graceful shutdown or not. If True, + this will 1) wait until there is no pending system + messages for ``UserRRefs`` and delete them; 2) block + until all local and remote RPC processes have reached + this method and wait for all outstanding work to + complete. + + Example:: + Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly + on both workers. Refer to :meth:`~core.distributed.init_process_group` + API for more details. For example, + + export MASTER_ADDR=localhost + export MASTER_PORT=5678 + + Then run the following code in two different processes: + + >>> # xdoctest: +SKIP + >>> # On worker 0: + >>> from mindnlp import core + >>> from mindnlp import core.distributed.rpc as rpc + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> # do some work + >>> result = rpc.rpc_sync("worker1", core.add, args=(core.ones(1), 1)) + >>> # ready to shutdown + >>> rpc.shutdown() + + >>> # On worker 1: + >>> from mindnlp import core.distributed.rpc as rpc + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> # wait for worker 0 to finish work, and then shutdown. + >>> rpc.shutdown() + """ + if graceful: + try: + agent = _get_current_rpc_agent() + if not isinstance(agent, TensorPipeAgent) or agent.is_static_group: + _wait_all_workers(timeout) + _delete_all_user_and_unforked_owner_rrefs() + agent.join(shutdown=True, timeout=timeout) + else: + # This is a dynamic group so we need to grab the token for the operation + my_worker_info = agent.get_worker_info() + my_name = my_worker_info.name + with _group_membership_management(agent.store, my_name, False): + all_worker_infos = agent.get_worker_infos() + for worker in all_worker_infos: + if worker.name != my_name: + rpc_sync( + worker.name, + _update_group_membership, + args=(my_worker_info, [], {}, False), + ) + agent.join(shutdown=True, timeout=timeout) + finally: + # In case of errors, continue to complete the local shutdown. + _finalize_shutdown() + else: + _finalize_shutdown() + + +def _finalize_shutdown(): + try: + # This raises a `TORCH_CHECK()` exception on RRef leak detected. + _destroy_rref_context(_ignore_rref_leak) + finally: + _get_current_rpc_agent().shutdown() + # clean up python rpc handler in shutdown(), see comments in + # PythonRpcHandler::cleanup(), call it in python API because the + # cleanup() function has python dependency, it assumes python + # interpreter exists. + # No matter if RRef leak exception is raised, this clean-up code + # must run to avoid destruction segfault in Python 3.5. + # + # future.wait() should not be called after shutdown(). + # pythonRpcHandler is cleaned up in shutdown(), after + # shutdown(), python objects returned from rpc python call can not be + # resolved. + _cleanup_python_rpc_handler() + _reset_current_rpc_agent() + + +@_require_initialized +def get_worker_info(worker_name=None): + r""" + Get :class:`~core.distributed.rpc.WorkerInfo` of a given worker name. + Use this :class:`~core.distributed.rpc.WorkerInfo` to avoid passing an + expensive string on every invocation. + + Args: + worker_name (str): the string name of a worker. If ``None``, return the + the id of the current worker. (default ``None``) + + Returns: + :class:`~core.distributed.rpc.WorkerInfo` instance for the given + ``worker_name`` or :class:`~core.distributed.rpc.WorkerInfo` of the + current worker if ``worker_name`` is ``None``. + """ + if worker_name is not None: + return _get_current_rpc_agent().get_worker_info(worker_name) + else: + return _get_current_rpc_agent().get_worker_info() + + +def _to_worker_info(to): + if isinstance(to, WorkerInfo): + return to + elif isinstance(to, (str, int)): + return get_worker_info(to) + else: + raise ValueError(f"Cannot get WorkerInfo from name {to}") + + +def _rref_typeof_on_owner(rref, blocking: bool = True): + rref_type = type(rref.local_value()) + if blocking: + return rref_type + else: + # Wrap result into a completed Future. This is so that if blocking=`False` + # is specified, we return a future regardless of if this call is on user + # or owner. + future = Future[type]() + future.set_result(rref_type) + return future + + +def _rref_typeof_on_user( + rref, timeout: float = UNSET_RPC_TIMEOUT, blocking: bool = True +): + fut = rpc_async(rref.owner(), _rref_typeof_on_owner, args=(rref,), timeout=timeout) + if blocking: + return fut.wait() + else: + return fut + + +T = TypeVar("T") +GenericWithOneTypeVar = Generic[T] + + +if TYPE_CHECKING: + + class RRef(PyRRef[T], Generic[T]): + pass + +else: + try: + # Combine the implementation class and the type class. + class RRef(PyRRef, Generic[T]): + pass + + except TypeError: + # TypeError: metaclass conflict: the metaclass of a derived class + # must be a (non-strict) subclass of the metaclasses of all its bases + # Mypy doesn't understand __class__ (mypy bug #4177) + class RRefMeta(PyRRef.__class__, GenericWithOneTypeVar.__class__): # type: ignore[name-defined, misc, valid-type] + pass + + # Combine the implementation class and the type class. + # Types for classes expecting a certain generic parameter (mypy bug #7791) + class RRef(PyRRef, GenericWithOneTypeVar, metaclass=RRefMeta): # type: ignore[misc, no-redef, valid-type] + pass + + +# Install docstrings from `PyRRef` to `RRef`. +# +# This is for the fact that pybind11 generates the parameter +# `self` as type `rpc.PyRRef`, so a `:inherited-members:` +# under `.. autoclass:: RRef` does not work. +# we have to do the following process to replace `rpc.PyRRef` with `rpc.RRef`. +# +def method_factory(method_name, docstring): + def method(self, *args, **kwargs): + return getattr(super(RRef, self), method_name)(*args, **kwargs) + + if method.__doc__: + method.__doc__ = docstring + return method + + +for method_name, method in inspect.getmembers(PyRRef): + # Ignore magic methods, except "__str__". + if method_name.startswith("_") and method_name != "__str__": + continue + + # Get pybind11 generated docstring. + # It's like, + """ + to_here(self: core.distributed.rpc.PyRRef, timeout: float=-1.0) -> object + + Blocking call that copies the value of the RRef from the owner + to the local node and returns it. If the current node is the + owner, returns a reference to the local value. + """ + docstring = getattr(method, "__doc__", None) + assert docstring is not None, "RRef user-facing methods should all have docstrings." + + # Do surgery on pybind11 generated docstrings. + docstring = docstring.replace( + "core.distributed.rpc.PyRRef", "core.distributed.rpc.RRef" + ) + + # Attach user-facing RRef method with modified docstring. + new_method = method_factory(method_name, docstring) + setattr(RRef, method_name, new_method) + + +@_require_initialized +def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT): + r""" + Make a remote call to run ``func`` on worker ``to`` and return an + :class:`~core.distributed.rpc.RRef` to the result value immediately. + Worker ``to`` will be the owner of the returned + :class:`~core.distributed.rpc.RRef`, and the worker calling ``remote`` is + a user. The owner manages the global reference count of its + :class:`~core.distributed.rpc.RRef`, and the owner + :class:`~core.distributed.rpc.RRef` is only destructed when globally there + are no living references to it. + + Args: + to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker. + func (Callable): a callable function, such as Python callables, builtin + operators (e.g. :meth:`~core.add`) and annotated + TorchScript functions. + args (tuple): the argument tuple for the ``func`` invocation. + kwargs (dict): is a dictionary of keyword arguments for the ``func`` + invocation. + + timeout (float, optional): timeout in seconds for this remote call. If the + creation of this + :class:`~core.distributed.rpc.RRef` on worker + ``to`` is not successfully processed on this + worker within this timeout, then the next time + there is an attempt to use the RRef (such as + ``to_here()``), a timeout will be raised + indicating this failure. A value of 0 indicates + an infinite timeout, i.e. a timeout error will + never be raised. If not provided, the default + value set during initialization or with + ``_set_rpc_timeout`` is used. + + Returns: + A user :class:`~core.distributed.rpc.RRef` instance to the result + value. Use the blocking API :meth:`core.distributed.rpc.RRef.to_here` + to retrieve the result value locally. + + .. warning :: + The ``remote`` API does not copy storages of argument tensors until + sending them over the wire, which could be done by a different thread + depending on the RPC backend type. The caller should make sure that the + contents of those tensors stay intact until the returned RRef is + confirmed by the owner, which can be checked using the + :meth:`core.distributed.rpc.RRef.confirmed_by_owner` API. + + .. warning :: + Errors such as timeouts for the ``remote`` API are handled on a + best-effort basis. This means that when remote calls initiated by + ``remote`` fail, such as with a timeout error, we take a best-effort + approach to error handling. This means that errors are handled and set + on the resulting RRef on an asynchronous basis. If the RRef has not been + used by the application before this handling (such as ``to_here`` or + fork call), then future uses of the ``RRef`` will appropriately raise + errors. However, it is possible that the user application will use the + ``RRef`` before the errors are handled. In this case, errors may not be + raised as they have not yet been handled. + + Example:: + + Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly + on both workers. Refer to :meth:`~core.distributed.init_process_group` + API for more details. For example, + + export MASTER_ADDR=localhost + export MASTER_PORT=5678 + + Then run the following code in two different processes: + + >>> # xdoctest: +SKIP + >>> # On worker 0: + >>> from mindnlp import core + >>> from mindnlp import core.distributed.rpc as rpc + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> rref1 = rpc.remote("worker1", core.add, args=(core.ones(2), 3)) + >>> rref2 = rpc.remote("worker1", core.add, args=(core.ones(2), 1)) + >>> x = rref1.to_here() + rref2.to_here() + >>> rpc.shutdown() + + >>> # On worker 1: + >>> from mindnlp import core.distributed.rpc as rpc + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> rpc.shutdown() + + Below is an example of running a TorchScript function using RPC. + + >>> # On both workers: + >>> @core.jit.script + >>> def my_script_add(tensor: core.Tensor, scalar: int): + >>> return core.add(tensor, scalar) + + >>> # On worker 0: + >>> from mindnlp import core.distributed.rpc as rpc + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> rref = rpc.remote("worker1", my_script_add, args=(core.ones(2), 3)) + >>> rref.to_here() + >>> rpc.shutdown() + + >>> # On worker 1: + >>> from mindnlp import core.distributed.rpc as rpc + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> rpc.shutdown() + """ + core._C._log_api_usage_once("core.distributed.rpc_remote") + qualified_name = core.jit._builtins._find_builtin(func) + dst_worker_info = _to_worker_info(to) + should_profile = _get_should_profile() + + ctx_manager = _enable_rpc_profiler( + should_profile, qualified_name, func, RPCExecMode.REMOTE, dst_worker_info + ) + + with ctx_manager as rf: + args = args if args else () + kwargs = kwargs if kwargs else {} + + is_async_exec = hasattr(func, "_wrapped_async_rpc_function") + + if is_async_exec: + wrapped = func._wrapped_async_rpc_function + if isinstance(wrapped, core.jit.ScriptFunction): + func = wrapped + + if qualified_name is not None: + rref = _invoke_remote_builtin( + dst_worker_info, qualified_name, timeout, *args, **kwargs + ) + elif isinstance(func, core.jit.ScriptFunction): + rref = _invoke_remote_torchscript( + dst_worker_info.name, + core._jit_internal._qualified_name(func), + timeout, + is_async_exec, + *args, + **kwargs, + ) + else: + (pickled_python_udf, tensors) = _default_pickler.serialize( + PythonUDF(func, args, kwargs) + ) + rref = _invoke_remote_python_udf( + dst_worker_info, pickled_python_udf, tensors, timeout, is_async_exec + ) + # attach profiling information + if should_profile: + assert core.autograd._profiler_enabled() + assert rf is not None + fut = rf._call_end_callbacks_on_future(rref._get_future()) + rref._set_profiling_future(fut) + + return rref + + +def _invoke_rpc( + to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = UNSET_RPC_TIMEOUT +): + if not callable(func): + raise TypeError("function should be callable.") + + qualified_name = core.jit._builtins._find_builtin(func) + dst_worker_info = _to_worker_info(to) + + should_profile = _get_should_profile() + + ctx_manager = _enable_rpc_profiler( + should_profile, qualified_name, func, rpc_type, dst_worker_info + ) + + with ctx_manager as rf: + args = args if args else () + kwargs = kwargs if kwargs else {} + + is_async_exec = hasattr(func, "_wrapped_async_rpc_function") + + if is_async_exec: + wrapped = func._wrapped_async_rpc_function + if isinstance(wrapped, core.jit.ScriptFunction): + func = wrapped + + if qualified_name is not None: + fut = _invoke_rpc_builtin( + dst_worker_info, qualified_name, rpc_timeout, *args, **kwargs + ) + elif isinstance(func, core.jit.ScriptFunction): + fut = _invoke_rpc_torchscript( + dst_worker_info.name, + core._jit_internal._qualified_name(func), + args, + kwargs, + rpc_timeout, + is_async_exec, + ) + else: + (pickled_python_udf, tensors) = _default_pickler.serialize( + PythonUDF(func, args, kwargs) + ) + fut = _invoke_rpc_python_udf( + dst_worker_info, pickled_python_udf, tensors, rpc_timeout, is_async_exec + ) + if should_profile: + assert core.autograd._profiler_enabled() + assert rf is not None + # Schedule profiling callbacks to run when the future completes. + # This returns a future that is completed when the original future + # completes and the profiling callbacks have been completed as well, + # to guarantee that fut.wait() completes the profiling. This new + # future will contain the same value as the original future. + fut = rf._call_end_callbacks_on_future(fut) + return fut + + +@_require_initialized +def rpc_sync(to, func, args=None, kwargs=None, timeout: float = UNSET_RPC_TIMEOUT): + r""" + Make a blocking RPC call to run function ``func`` on worker ``to``. RPC + messages are sent and received in parallel to execution of Python code. This + method is thread-safe. + + Args: + to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker. + func (Callable): a callable function, such as Python callables, builtin + operators (e.g. :meth:`~core.add`) and annotated + TorchScript functions. + args (tuple): the argument tuple for the ``func`` invocation. + kwargs (dict): is a dictionary of keyword arguments for the ``func`` + invocation. + timeout (float, optional): timeout in seconds to use for this RPC. If + the RPC does not complete in this amount of + time, an exception indicating it has + timed out will be raised. A value of 0 + indicates an infinite timeout, i.e. a timeout + error will never be raised. If not provided, + the default value set during initialization + or with ``_set_rpc_timeout`` is used. + + Returns: + Returns the result of running ``func`` with ``args`` and ``kwargs``. + + Example:: + Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly + on both workers. Refer to :meth:`~core.distributed.init_process_group` + API for more details. For example, + + export MASTER_ADDR=localhost + export MASTER_PORT=5678 + + Then run the following code in two different processes: + + >>> # xdoctest: +SKIP + >>> # On worker 0: + >>> from mindnlp import core + >>> from mindnlp import core.distributed.rpc as rpc + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> ret = rpc.rpc_sync("worker1", core.add, args=(core.ones(2), 3)) + >>> rpc.shutdown() + + >>> # On worker 1: + >>> from mindnlp import core.distributed.rpc as rpc + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> rpc.shutdown() + + Below is an example of running a TorchScript function using RPC. + + >>> # On both workers: + >>> @core.jit.script + >>> def my_script_add(tensor: core.Tensor, scalar: int): + >>> return core.add(tensor, scalar) + + >>> # On worker 0: + >>> from mindnlp import core.distributed.rpc as rpc + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> ret = rpc.rpc_sync("worker1", my_script_add, args=(core.ones(2), 3)) + >>> rpc.shutdown() + + >>> # On worker 1: + >>> from mindnlp import core.distributed.rpc as rpc + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> rpc.shutdown() + + """ + core._C._log_api_usage_once("core.distributed.rpc_sync") + fut = _invoke_rpc(to, func, RPCExecMode.SYNC, args, kwargs, timeout) + return fut.wait() + + +@_require_initialized +def rpc_async(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT): + r""" + Make a non-blocking RPC call to run function ``func`` on worker ``to``. RPC + messages are sent and received in parallel to execution of Python code. This + method is thread-safe. This method will immediately return a + :class:`~core.futures.Future` that can be awaited on. + + Args: + to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker. + func (Callable): a callable function, such as Python callables, builtin + operators (e.g. :meth:`~core.add`) and annotated + TorchScript functions. + args (tuple): the argument tuple for the ``func`` invocation. + kwargs (dict): is a dictionary of keyword arguments for the ``func`` + invocation. + timeout (float, optional): timeout in seconds to use for this RPC. If + the RPC does not complete in this amount of + time, an exception indicating it has + timed out will be raised. A value of 0 + indicates an infinite timeout, i.e. a timeout + error will never be raised. If not provided, + the default value set during initialization + or with ``_set_rpc_timeout`` is used. + + + Returns: + Returns a :class:`~core.futures.Future` object that can be waited + on. When completed, the return value of ``func`` on ``args`` and + ``kwargs`` can be retrieved from the :class:`~core.futures.Future` + object. + + .. warning :: + Using GPU tensors as arguments or return values of ``func`` is not + supported since we don't support sending GPU tensors over the wire. You + need to explicitly copy GPU tensors to CPU before using them as + arguments or return values of ``func``. + + .. warning :: + The ``rpc_async`` API does not copy storages of argument tensors until + sending them over the wire, which could be done by a different thread + depending on the RPC backend type. The caller should make sure that the + contents of those tensors stay intact until the returned + :class:`~core.futures.Future` completes. + + Example:: + Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly + on both workers. Refer to :meth:`~core.distributed.init_process_group` + API for more details. For example, + + export MASTER_ADDR=localhost + export MASTER_PORT=5678 + + Then run the following code in two different processes: + + >>> # xdoctest: +SKIP + >>> # On worker 0: + >>> from mindnlp import core + >>> from mindnlp import core.distributed.rpc as rpc + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> fut1 = rpc.rpc_async("worker1", core.add, args=(core.ones(2), 3)) + >>> fut2 = rpc.rpc_async("worker1", min, args=(1, 2)) + >>> result = fut1.wait() + fut2.wait() + >>> rpc.shutdown() + + >>> # On worker 1: + >>> from mindnlp import core.distributed.rpc as rpc + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> rpc.shutdown() + + Below is an example of running a TorchScript function using RPC. + + >>> # On both workers: + >>> @core.jit.script + >>> def my_script_add(tensor: core.Tensor, scalar: int): + >>> return core.add(tensor, scalar) + + >>> # On worker 0: + >>> from mindnlp import core.distributed.rpc as rpc + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> fut = rpc.rpc_async("worker1", my_script_add, args=(core.ones(2), 3)) + >>> ret = fut.wait() + >>> rpc.shutdown() + + >>> # On worker 1: + >>> from mindnlp import core.distributed.rpc as rpc + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> rpc.shutdown() + """ + core._C._log_api_usage_once("core.distributed.rpc_async") + fut = _invoke_rpc(to, func, RPCExecMode.ASYNC, args, kwargs, timeout) + if hasattr(_thread_local_var, "future_list"): + _thread_local_var.future_list.append(fut) + return fut + + +def _get_should_profile(): + # Legacy profiler should be enabled. RPC profiling is not supported with + # Kineto profiler. + ActiveProfilerType = core._C._profiler.ActiveProfilerType + return ( + core.autograd._profiler_enabled() + and core._C._autograd._profiler_type() + == ActiveProfilerType.LEGACY # type: ignore[attr-defined] + ) + + +def _enable_rpc_profiler( + should_profile, qualified_name, func, rpc_type, dst_worker_info +): + ctx_manager = contextlib.nullcontext() + + if should_profile: + # Create appropriate string representation based on type of func + # (builtin, script, python) + if qualified_name is None: + func_name = ( + core._jit_internal._qualified_name(func) + if isinstance(func, core.jit.ScriptFunction) + else func.__qualname__ + ) + else: + func_name = qualified_name + # Build RPC profiling key. + rpc_profiling_key = _build_rpc_profiling_key( + rpc_type, + func_name, + get_worker_info().name, + dst_worker_info.name, + ) + RemoteProfilerManager.set_current_profiling_key(rpc_profiling_key) + # Mypy doesn't support re-def of a variable not in the same block (#1174) + ctx_manager = core.autograd.profiler.record_function(rpc_profiling_key) # type: ignore[assignment] + + return ctx_manager diff --git a/mindnlp/core/distributed/rpc/backend_registry.py b/mindnlp/core/distributed/rpc/backend_registry.py new file mode 100644 index 000000000..1d787999d --- /dev/null +++ b/mindnlp/core/distributed/rpc/backend_registry.py @@ -0,0 +1,432 @@ +# mypy: allow-untyped-defs + + +import collections +import enum +from typing import cast, Dict, List, Set, Tuple + +from mindnlp import core +from mindnlp import core.distributed as dist + +from . import api, constants as rpc_constants +from ._utils import _group_membership_management, _update_group_membership + + +__all__ = [ + "backend_registered", + "register_backend", + "construct_rpc_backend_options", + "init_backend", + "BackendValue", + "BackendType", +] + +BackendValue = collections.namedtuple( + "BackendValue", ["construct_rpc_backend_options_handler", "init_backend_handler"] +) + + +def _backend_type_repr(self): + return "BackendType." + self.name + + +_backend_type_doc = """ + An enum class of available backends. + + PyTorch ships with a builtin ``BackendType.TENSORPIPE`` backend. + Additional ones can be registered using the + :func:`~core.distributed.rpc.backend_registry.register_backend` function. +""" + +# Create an enum type, `BackendType`, with empty members. +# Can't handle Function Enum API (mypy bug #9079) +BackendType = enum.Enum(value="BackendType", names={}) # type: ignore[misc] +# Unable to assign a function a method (mypy bug #2427) +BackendType.__repr__ = _backend_type_repr # type: ignore[assignment] + +if BackendType.__doc__: + BackendType.__doc__ = _backend_type_doc + + +def backend_registered(backend_name): + """ + Checks if backend_name is registered as an RPC backend. + + Args: + backend_name (str): string to identify the RPC backend. + Returns: + True if the backend has been registered with ``register_backend``, else + False. + """ + return backend_name in BackendType.__members__.keys() + + +def register_backend( + backend_name, construct_rpc_backend_options_handler, init_backend_handler +): + """Registers a new RPC backend. + + Args: + backend_name (str): backend string to identify the handler. + construct_rpc_backend_options_handler (function): + Handler that is invoked when + rpc_backend.construct_rpc_backend_options(**dict) is called. + init_backend_handler (function): Handler that is invoked when the + `_init_rpc_backend()` function is called with a backend. + This returns the agent. + """ + global BackendType + if backend_registered(backend_name): + raise RuntimeError(f"RPC backend {backend_name}: already registered") + # Create a new enum type, `BackendType`, with extended members. + existing_enum_dict = {member.name: member.value for member in BackendType} + extended_enum_dict = dict( + { + backend_name: BackendValue( + construct_rpc_backend_options_handler=construct_rpc_backend_options_handler, + init_backend_handler=init_backend_handler, + ) + }, + **existing_enum_dict, + ) + # Can't handle Function Enum API (mypy bug #9079) + BackendType = enum.Enum(value="BackendType", names=extended_enum_dict) # type: ignore[misc] + # Unable to assign a function a method (mypy bug #2427) + BackendType.__repr__ = _backend_type_repr # type: ignore[assignment] + if BackendType.__doc__: + BackendType.__doc__ = _backend_type_doc + return BackendType[backend_name] + + +def construct_rpc_backend_options( + backend, + rpc_timeout=rpc_constants.DEFAULT_RPC_TIMEOUT_SEC, + init_method=rpc_constants.DEFAULT_INIT_METHOD, + **kwargs, +): + return backend.value.construct_rpc_backend_options_handler( + rpc_timeout, init_method, **kwargs + ) + + +def init_backend(backend, *args, **kwargs): + return backend.value.init_backend_handler(*args, **kwargs) + + +def _init_process_group(store, rank, world_size): + # Initialize ProcessGroup. + process_group_timeout = rpc_constants.DEFAULT_PROCESS_GROUP_TIMEOUT + + # We're using a bunch of private APIs here since `new_group` requires the + # default group to be initialized. + group = dist.ProcessGroupGloo(store, rank, world_size, process_group_timeout) + + assert group is not None, "Failed to initialize default ProcessGroup." + + if (rank != -1) and (rank != group.rank()): + raise RuntimeError(f"rank argument {rank} doesn't match pg rank {group.rank()}") + if (world_size != -1) and (world_size != group.size()): + raise RuntimeError( + f"world_size argument {world_size} doesn't match pg size {group.size()}" + ) + return group + + +def _tensorpipe_construct_rpc_backend_options_handler( + rpc_timeout, + init_method, + num_worker_threads=rpc_constants.DEFAULT_NUM_WORKER_THREADS, + _transports=None, + _channels=None, + **kwargs, +): + from . import TensorPipeRpcBackendOptions + + return TensorPipeRpcBackendOptions( + rpc_timeout=rpc_timeout, + init_method=init_method, + num_worker_threads=num_worker_threads, + _transports=_transports, + _channels=_channels, + ) + + +def _tensorpipe_validate_devices(devices, device_count): + return all( + d.type == "cpu" or (d.type == "cuda" and 0 <= d.index < device_count) + for d in devices + ) + + +# detect if any worker has invalid device_map configurations, and return +# reverse device maps +def _tensorpipe_exchange_and_check_all_device_maps( + my_name, my_device_count, my_device_maps, my_devices, group +): + gathered: List[ + Tuple[str, int, Dict[str, Dict[core.device, core.device]], List[core.device]] + ] = [("", 0, {}, []) for _ in range(group.size())] + dist.all_gather_object( + gathered, (my_name, my_device_count, my_device_maps, my_devices), group + ) + all_names = [name for name, _, _, _ in gathered] + all_device_counts = {name: count for name, count, _, _ in gathered} + all_device_maps = {name: map_ for name, _, map_, _ in gathered} + all_devices = {name: devices for name, _, _, devices in gathered} + + _validate_device_maps(all_names, all_device_counts, all_device_maps, all_devices) + + # passed all checked, construct reverse mapping and get list of devices handled by this agent + reverse_device_maps = _create_reverse_mapping(my_name, all_names, all_device_maps) + my_devices = _create_device_list(my_devices, my_device_maps, reverse_device_maps) + return reverse_device_maps, my_devices + + +def _validate_device_maps( + all_names, all_device_counts, all_device_maps, all_devices, is_static_group=True +): + for node in all_names: + devices = all_devices[node] + if len(set(devices)) != len(devices): + raise ValueError( + f"Node {node} has duplicated devices\n" f"devices = {devices}" + ) + if not _tensorpipe_validate_devices(devices, all_device_counts[node]): + raise ValueError( + f"Node {node} has devices with invalid indices\n" + f"devices = {devices}\n" + f"device count = {all_device_counts[node]}" + ) + + for source_node in all_names: + # For dynamic group (non-static) do not check the target node name since it may not have joined yet + if is_static_group and not set(all_device_maps[source_node].keys()).issubset( + all_names + ): + raise ValueError( + f"Node {source_node} has invalid target node names in its device maps\n" + f"device maps = {all_device_maps[source_node].keys()}\n" + f"node names = {all_names}" + ) + for target_node, map_ in all_device_maps[source_node].items(): + if len(set(map_.values())) != len(map_): + raise ValueError( + f"Node {source_node} has duplicated target devices " + f"in its device map for {target_node}\n" + f"device map = {map_}" + ) + if all_devices[source_node]: + if not set(map_.keys()).issubset(all_devices[source_node]): + raise ValueError( + f"Node {source_node} has unexpected source devices " + f"in its device map for {target_node}\n" + f"device map = {map_}\n" + f"devices = {all_devices[source_node]}" + ) + elif not _tensorpipe_validate_devices( + map_.keys(), all_device_counts[source_node] + ): + raise ValueError( + f"Node {source_node} has source devices with invalid indices " + f"in its device map for {target_node}\n" + f"device map = {map_}\n" + f"device count = {all_device_counts[source_node]}" + ) + if all_devices.get(target_node, []): + if not set(map_.values()).issubset(all_devices[target_node]): + raise ValueError( + f"Node {source_node} has unexpected target devices " + f"in its device map for {target_node}\n" + f"device map = {map_}\n" + f"devices = {all_devices[target_node]}" + ) + elif target_node in all_device_counts and not _tensorpipe_validate_devices( + map_.values(), all_device_counts[target_node] + ): + raise ValueError( + f"Node {source_node} has target devices with invalid indices " + f"in its device map for {target_node}\n" + f"device map = {map_}\n" + f"device count = {all_device_counts[target_node]}" + ) + + +def _create_device_list(my_devices, my_device_maps, reverse_device_maps): + if not my_devices: + devices_set: Set[core.device] = set() + for map_ in my_device_maps.values(): + devices_set.update(map_.keys()) + for map_ in reverse_device_maps.values(): + devices_set.update(map_.keys()) + devices_set.discard(core.device("cpu")) + my_devices = list(devices_set) + my_devices = sorted(my_devices, key=lambda d: d.index) + return my_devices + + +def _create_reverse_mapping(my_name, all_names, all_device_maps): + reverse_device_maps: Dict[str, Dict[core.device, core.device]] = {} + for node in all_names: + if my_name in all_device_maps[node]: + reverse_device_maps[node] = { + v: k for k, v in all_device_maps[node][my_name].items() + } + return reverse_device_maps + + +def _get_device_infos(): + from . import TensorPipeAgent + + agent = cast(TensorPipeAgent, api._get_current_rpc_agent()) + opts = agent._get_backend_options() + device_count = core.cuda.device_count() + if core.cuda.is_available() and opts.devices: + core.cuda.init() + return device_count, opts.device_maps, opts.devices + + +def _set_devices_and_reverse_device_map(agent): + from . import TensorPipeAgent + + agent = cast(TensorPipeAgent, agent) + # Group state is retrieved from local agent + # On initialization, tensorpipe agent retrieves information from all existing workers, so group state is valid + my_worker_info = agent.get_worker_info() + my_name = my_worker_info.name + all_worker_infos = agent.get_worker_infos() + # One round to get device_maps of all workers and construct reverse device maps + all_device_counts, all_device_maps, all_devices, all_names = {}, {}, {}, [] + for worker_info in all_worker_infos: + worker_name = worker_info.name + if worker_name != my_name: + # TODO: make async? + device_count, device_map, devices = api.rpc_sync( + worker_name, _get_device_infos + ) + else: + opts = agent._get_backend_options() + device_count, device_map, devices = ( + core.cuda.device_count(), + opts.device_maps, + opts.devices, + ) + all_device_counts[worker_name] = device_count + all_device_maps[worker_name] = device_map + all_devices[worker_name] = devices + all_names.append(worker_name) + + _validate_device_maps( + all_names, + all_device_counts, + all_device_maps, + all_devices, + is_static_group=False, + ) + reverse_device_maps = _create_reverse_mapping(my_name, all_names, all_device_maps) + + # Perform RPC call to all workers, including itself, to include newly joined worker information and device maps + for worker_name in all_names: + # Set device list for each worker + all_devices[worker_name] = _create_device_list( + all_devices[worker_name], all_device_maps[worker_name], reverse_device_maps + ) + api.rpc_sync( + worker_name, + _update_group_membership, + args=(my_worker_info, all_devices[worker_name], reverse_device_maps, True), + ) + + +def _tensorpipe_init_backend_handler( + store, name, rank, world_size, rpc_backend_options +): + from . import TensorPipeAgent, TensorPipeRpcBackendOptions + + if not isinstance(store, dist.Store): + raise TypeError(f"`store` must be a c10d::Store. {store}") + + if not isinstance(rpc_backend_options, TensorPipeRpcBackendOptions): + raise TypeError( + f"`rpc_backend_options` must be a `TensorPipeRpcBackendOptions`. {rpc_backend_options}" + ) + + device_count = core.cuda.device_count() + + is_static_group = True if world_size else False + # world_size is specified so this is a static group (ranks cannot join and leave) + if is_static_group: + # The agent's join method is required to behave like a barrier and perform + # collective operations, for which it relies on a process group, instead of + # re-implementing this on top of RPCs. + group = _init_process_group(store, rank, world_size) + + reverse_device_maps, devices = _tensorpipe_exchange_and_check_all_device_maps( + name, + device_count, + rpc_backend_options.device_maps, + rpc_backend_options.devices, + group, + ) + + if core.cuda.is_available() and devices: + # It's necessary to initialize PyTorch CUDA states here (e.g., + # CUDACachingAllocator). If this is missing, we could hit errors like + # "allocator not initialized", because other processes might send + # CUDA-related RPC request to this process before user code in this + # process initializes its PyTorch CUDA states. + core.cuda.init() + + # TODO: add try-except and destroy _agent in all processes if any fails. + agent = TensorPipeAgent( + store, + name, + rank, + world_size, + rpc_backend_options, + reverse_device_maps, + devices, + ) + + api._init_rpc_states(agent) + + # Run one dummy round of RPC to initialize channels/transports. Without + # this, it's easy to hit timeout in rpc.shutdown() if there is no other RPC + # on that process before rpc.shutdown(), as the agent initialization can + # take longer than 5s. + api._all_gather(None, timeout=rpc_backend_options.rpc_timeout) + # Need a barrier here to make sure no peers leave before the rank0 finishes + # _all_gather + group.barrier().wait() + + return agent + # initialization for dynamic rpc (ranks can join and leave) + else: + with _group_membership_management(store, name, True): + # Construct TPAgent with empty reverse_device_map and devices + # these properties will be updated after initialization + agent = TensorPipeAgent( + store, + name, + rank, + world_size, + rpc_backend_options, + {}, + [], + ) + api._init_rpc_states(agent) + + try: + # Notify all workers in group this rank has joined and set devices and reverse_device_map + # This is a synchronous operation that completes once all existing ranks are updated + _set_devices_and_reverse_device_map(agent) + except Exception: + api.shutdown() + raise + return agent + + +register_backend( + "TENSORPIPE", + _tensorpipe_construct_rpc_backend_options_handler, + _tensorpipe_init_backend_handler, +) diff --git a/mindnlp/core/distributed/rpc/constants.py b/mindnlp/core/distributed/rpc/constants.py new file mode 100644 index 000000000..9d954fecc --- /dev/null +++ b/mindnlp/core/distributed/rpc/constants.py @@ -0,0 +1,25 @@ +from datetime import timedelta +from typing import List + +from core._C._distributed_rpc import ( + _DEFAULT_INIT_METHOD, + _DEFAULT_NUM_WORKER_THREADS, + _DEFAULT_RPC_TIMEOUT_SEC, + _UNSET_RPC_TIMEOUT, +) + + +# For any RpcAgent. +DEFAULT_RPC_TIMEOUT_SEC: float = _DEFAULT_RPC_TIMEOUT_SEC +DEFAULT_INIT_METHOD: str = _DEFAULT_INIT_METHOD +DEFAULT_SHUTDOWN_TIMEOUT: float = 0 + +# For TensorPipeAgent. +DEFAULT_NUM_WORKER_THREADS: int = _DEFAULT_NUM_WORKER_THREADS +# Ensure that we don't time out when there are long periods of time without +# any operations against the underlying ProcessGroup. +DEFAULT_PROCESS_GROUP_TIMEOUT: timedelta = timedelta(milliseconds=2**31 - 1) +# Value indicating that timeout is not set for RPC call, and the default should be used. +UNSET_RPC_TIMEOUT: float = _UNSET_RPC_TIMEOUT + +__all__: List[str] = [] diff --git a/mindnlp/core/distributed/rpc/functions.py b/mindnlp/core/distributed/rpc/functions.py new file mode 100644 index 000000000..34d44d6b3 --- /dev/null +++ b/mindnlp/core/distributed/rpc/functions.py @@ -0,0 +1,169 @@ +# mypy: allow-untyped-defs +import functools + + +def async_execution(fn): + r""" + A decorator for a function indicating that the return value of the function + is guaranteed to be a :class:`~core.futures.Future` object and this + function can run asynchronously on the RPC callee. More specifically, the + callee extracts the :class:`~core.futures.Future` returned by the wrapped + function and installs subsequent processing steps as a callback to that + :class:`~core.futures.Future`. The installed callback will read the value + from the :class:`~core.futures.Future` when completed and send the + value back as the RPC response. That also means the returned + :class:`~core.futures.Future` only exists on the callee side and is never + sent through RPC. This decorator is useful when the wrapped function's + (``fn``) execution needs to pause and resume due to, e.g., containing + :meth:`~core.distributed.rpc.rpc_async` or waiting for other signals. + + .. note:: To enable asynchronous execution, applications must pass the + function object returned by this decorator to RPC APIs. If RPC detected + attributes installed by this decorator, it knows that this function + returns a ``Future`` object and will handle that accordingly. + However, this does not mean this decorator has to be outmost one when + defining a function. For example, when combined with ``@staticmethod`` + or ``@classmethod``, ``@rpc.functions.async_execution`` needs to be the + inner decorator to allow the target function be recognized as a static + or class function. This target function can still execute asynchronously + because, when accessed, the static or class method preserves attributes + installed by ``@rpc.functions.async_execution``. + + + Example:: + The returned :class:`~core.futures.Future` object can come from + :meth:`~core.distributed.rpc.rpc_async`, + :meth:`~core.futures.Future.then`, or :class:`~core.futures.Future` + constructor. The example below shows directly using the + :class:`~core.futures.Future` returned by + :meth:`~core.futures.Future.then`. + + >>> from core.distributed import rpc + >>> + >>> # omitting setup and shutdown RPC + >>> + >>> # On all workers + >>> @rpc.functions.async_execution + >>> def async_add_chained(to, x, y, z): + >>> # This function runs on "worker1" and returns immediately when + >>> # the callback is installed through the `then(cb)` API. In the + >>> # mean time, the `rpc_async` to "worker2" can run concurrently. + >>> # When the return value of that `rpc_async` arrives at + >>> # "worker1", "worker1" will run the lambda function accordingly + >>> # and set the value for the previously returned `Future`, which + >>> # will then trigger RPC to send the result back to "worker0". + >>> return rpc.rpc_async(to, core.add, args=(x, y)).then( + >>> lambda fut: fut.wait() + z + >>> ) + >>> + >>> # On worker0 + >>> # xdoctest: +SKIP + >>> ret = rpc.rpc_sync( + >>> "worker1", + >>> async_add_chained, + >>> args=("worker2", core.ones(2), 1, 1) + >>> ) + >>> print(ret) # prints tensor([3., 3.]) + + When combined with TorchScript decorators, this decorator must be the + outmost one. + + >>> from mindnlp.core import Tensor + >>> from core.futures import Future + >>> from core.distributed import rpc + >>> + >>> # omitting setup and shutdown RPC + >>> + >>> # On all workers + >>> @core.jit.script + >>> def script_add(x: Tensor, y: Tensor) -> Tensor: + >>> return x + y + >>> + >>> @rpc.functions.async_execution + >>> @core.jit.script + >>> def async_add(to: str, x: Tensor, y: Tensor) -> Future[Tensor]: + >>> return rpc.rpc_async(to, script_add, (x, y)) + >>> + >>> # On worker0 + >>> ret = rpc.rpc_sync( + >>> "worker1", + >>> async_add, + >>> args=("worker2", core.ones(2), 1) + >>> ) + >>> print(ret) # prints tensor([2., 2.]) + + When combined with static or class method, this decorator must be the + inner one. + + >>> from core.distributed import rpc + >>> + >>> # omitting setup and shutdown RPC + >>> + >>> # On all workers + >>> class AsyncExecutionClass: + >>> + >>> @staticmethod + >>> @rpc.functions.async_execution + >>> def static_async_add(to, x, y, z): + >>> return rpc.rpc_async(to, core.add, args=(x, y)).then( + >>> lambda fut: fut.wait() + z + >>> ) + >>> + >>> @classmethod + >>> @rpc.functions.async_execution + >>> def class_async_add(cls, to, x, y, z): + >>> ret_fut = core.futures.Future() + >>> rpc.rpc_async(to, core.add, args=(x, y)).then( + >>> lambda fut: ret_fut.set_result(fut.wait() + z) + >>> ) + >>> return ret_fut + >>> + >>> @rpc.functions.async_execution + >>> def bound_async_add(self, to, x, y, z): + >>> return rpc.rpc_async(to, core.add, args=(x, y)).then( + >>> lambda fut: fut.wait() + z + >>> ) + >>> + >>> # On worker0 + >>> ret = rpc.rpc_sync( + >>> "worker1", + >>> AsyncExecutionClass.static_async_add, + >>> args=("worker2", core.ones(2), 1, 2) + >>> ) + >>> print(ret) # prints tensor([4., 4.]) + >>> + >>> ret = rpc.rpc_sync( + >>> "worker1", + >>> AsyncExecutionClass.class_async_add, + >>> args=("worker2", core.ones(2), 1, 2) + >>> ) + >>> print(ret) # prints tensor([4., 4.]) + + This decorator also works with RRef helpers, i.e., . + :meth:`core.distributed.rpc.RRef.rpc_sync`, + :meth:`core.distributed.rpc.RRef.rpc_async`, and + :meth:`core.distributed.rpc.RRef.remote`. + + >>> from core.distributed import rpc + >>> + >>> # reuse the AsyncExecutionClass class above + >>> rref = rpc.remote("worker1", AsyncExecutionClass) + >>> ret = rref.rpc_sync().static_async_add("worker2", core.ones(2), 1, 2) + >>> print(ret) # prints tensor([4., 4.]) + >>> + >>> rref = rpc.remote("worker1", AsyncExecutionClass) + >>> ret = rref.rpc_async().static_async_add("worker2", core.ones(2), 1, 2).wait() + >>> print(ret) # prints tensor([4., 4.]) + >>> + >>> rref = rpc.remote("worker1", AsyncExecutionClass) + >>> ret = rref.remote().static_async_add("worker2", core.ones(2), 1, 2).to_here() + >>> print(ret) # prints tensor([4., 4.]) + """ + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + # Can't declare and use attributes of function objects (mypy#2087) + wrapper._wrapped_async_rpc_function = fn # type: ignore[attr-defined] + return wrapper diff --git a/mindnlp/core/distributed/rpc/internal.py b/mindnlp/core/distributed/rpc/internal.py new file mode 100644 index 000000000..6c1a7e320 --- /dev/null +++ b/mindnlp/core/distributed/rpc/internal.py @@ -0,0 +1,285 @@ +# mypy: allow-untyped-defs +import collections +import copyreg +import io +import pickle +import sys +import threading +import traceback +from enum import Enum + +from mindnlp import core +from mindnlp import core.distributed as dist +from core._C._distributed_rpc import _get_current_rpc_agent + + +__all__ = ["RPCExecMode", "serialize", "deserialize", "PythonUDF", "RemoteException"] + +# Thread local tensor tables to store tensors while pickling core.Tensor +# objects +_thread_local_tensor_tables = threading.local() +_pickler = pickle.Pickler +_unpickler = pickle.Unpickler + + +class RPCExecMode(Enum): + SYNC = "sync" + ASYNC = "async" + ASYNC_JIT = "async_jit" + REMOTE = "remote" + + +class _InternalRPCPickler: + r""" + This class provides serialize() and deserialize() interfaces to serialize + data to be "binary string + tensor table" format + So for RPC python UDF function and args, non tensor data will be serialized + into regular binary string, tensor data will be put into thread local tensor + tables, this serialization format is consistent with builtin operator and args + using JIT pickler. This format will make tensor handling in C++ much easier, + e.g. attach tensor to distributed autograd graph in C++ + """ + + def __init__(self): + # Ignore type error because dispatch_table is defined in third-party package + self._dispatch_table = copyreg.dispatch_table.copy() # type: ignore[attr-defined] + self._dispatch_table[core.Tensor] = self._tensor_reducer + # Used for registering customized picklers. + self._class_reducer_dict = {} + + def _register_reducer(self, obj_class, reducer): + # For the same class, only register the reducer once. + if obj_class not in self._class_reducer_dict: + self._class_reducer_dict[obj_class] = reducer + + @classmethod + def _tensor_receiver(cls, tensor_index): + global _thread_local_tensor_tables + return _thread_local_tensor_tables.recv_tables[tensor_index] + + def _tensor_reducer(self, tensor): + global _thread_local_tensor_tables + _thread_local_tensor_tables.send_tables.append(tensor) + tensor_index = len(_thread_local_tensor_tables.send_tables) - 1 + return (_InternalRPCPickler._tensor_receiver, (tensor_index,)) + + @classmethod + def _py_rref_receiver(cls, rref_fork_data): + return dist.rpc.PyRRef._deserialize(rref_fork_data) + + def _py_rref_reducer(self, py_rref): + rref_fork_data = py_rref._serialize() + return (_InternalRPCPickler._py_rref_receiver, (rref_fork_data,)) + + def _rref_reducer(self, rref): + return self._py_rref_reducer(rref) + + @classmethod + def _script_module_receiver(cls, script_module_serialized): + """ + Given a serialized representation of a ScriptModule created with core.jit.save, + loads and returns the ScriptModule. + """ + f = io.BytesIO(script_module_serialized) + m = core.jit.load(f) + return m + + def _script_module_reducer(self, script_module): + """ + Serializes a ScriptModule. + """ + f = io.BytesIO() + core.jit.save(script_module, f) + return (_InternalRPCPickler._script_module_receiver, (f.getvalue(),)) + + def serialize(self, obj): + r""" + Serialize non tensor data into binary string, tensor data into + tensor table + """ + f = io.BytesIO() + p = _pickler(f) + p.dispatch_table = self._dispatch_table + + # rpc api could accept user picklers inheriting from _InternalRPCPickler to serialize rref, + # user picklers could have different initialization function from _InternalRPCPickler, + # but all the user picklers should call serialize() and use _rref_reducer to pickle rref + # in python. also, when _internal_rpc_pickler is imported to rpc/api.py, rpc.RRef is not + # compiled yet, it is not good place to access rpc.RRef inside _InternalRPCPickler constructor, + # so putting rref's dispatch table here + # + # The return value of a `rpc.remote(..)` call is type of `rpc.PyRRef`. + # The deserialized RRef object on an RPC receiver side is type of `rpc.PyRRef`. + # Ignore type error because dispatch_table is defined in third-party package + p.dispatch_table[dist.rpc.PyRRef] = self._py_rref_reducer # type: ignore[index] + # An RRef created locally by RRef Python constructor is type of `rpc.RRef`. + # Ignore type error because dispatch_table is defined in third-party package + p.dispatch_table[dist.rpc.RRef] = self._rref_reducer # type: ignore[index] + + # Add dispatch pickling for ScriptModule or its subclass. + if isinstance(obj, core.jit.ScriptModule): + # Ignore type error because dispatch_table is defined in third-party package + p.dispatch_table[obj.__class__] = self._script_module_reducer # type: ignore[index] + + # Install customized picklers. + for class_name in self._class_reducer_dict.keys(): + p.dispatch_table[class_name] = self._class_reducer_dict[class_name] # type: ignore[index] + + # save _thread_local_tensor_tables.send_tables if it is in nested call + global _thread_local_tensor_tables + if hasattr(_thread_local_tensor_tables, "send_tables"): + old_send_tables = _thread_local_tensor_tables.send_tables + else: + old_send_tables = None + _thread_local_tensor_tables.send_tables = [] + + p.dump(obj) + + # restore _thread_local_tensor_tables.send_tables if return + # from nested call, otherwise clean up the table + tensors = _thread_local_tensor_tables.send_tables + if old_send_tables is not None: + _thread_local_tensor_tables.send_tables = old_send_tables + else: + del _thread_local_tensor_tables.send_tables + + return (f.getvalue(), tensors) + + def deserialize(self, binary_data, tensor_table): + r""" + Deserialize binary string + tensor table to original obj + """ + # save _thread_local_tensor_tables.recv_tables if it is in nested call + global _thread_local_tensor_tables + if hasattr(_thread_local_tensor_tables, "recv_tables"): + old_recv_tables = _thread_local_tensor_tables.recv_tables + else: + old_recv_tables = None + _thread_local_tensor_tables.recv_tables = tensor_table + + try: + unpickler = _unpickler(io.BytesIO(binary_data)) + ret = unpickler.load() + except AttributeError as e: + # Occurs when function is not found on module/class during + # unpickling. + except_str = ( + str(e) + + """ Default RPC pickler does not serialize + function code. Ensure that UDFs are defined on both caller and + callee modules.""" + ) + ret = AttributeError(except_str) + # Ensure the stack trace gets preserved + ret.__cause__ = e + + # restore _thread_local_tensor_tables.recv_tables if return + # from nested call, otherwise clean up the table + if old_recv_tables is not None: + _thread_local_tensor_tables.recv_tables = old_recv_tables + else: + del _thread_local_tensor_tables.recv_tables + + return ret + + +# Create _internal_rpc_pickler only once to initialize _dispatch_table only once +_internal_rpc_pickler = _InternalRPCPickler() + + +def serialize(obj): + return _internal_rpc_pickler.serialize(obj) + + +def deserialize(binary_data, tensor_table): + return _internal_rpc_pickler.deserialize(binary_data, tensor_table) + + +def _run_function(python_udf): + r""" + This function is exclusively called from C++. + See ``torch/csrc/distributed/rpc/python_rpc_handler.cpp``. + + Runs a Python UDF and returns its return value. + Wraps any exception in ``RemoteException`` if the function raises. + """ + try: + if isinstance(python_udf, AttributeError): + raise python_udf + result = python_udf.func(*python_udf.args, **python_udf.kwargs) + except Exception as e: + # except str = exception info + traceback string + except_str = ( + f"On {_get_current_rpc_agent().get_worker_info()}:\n" + f"{repr(e)}\n{traceback.format_exc()}" + ) + print(except_str, file=sys.stderr) + result = RemoteException(except_str, type(e)) + return result + + +def _handle_exception(result): + if isinstance(result, RemoteException): + exception_msg = result.msg.encode("utf-8").decode("unicode_escape") + # We wrap exception re-creation here in case some exception classes + # cannot be constructed directly from a string. + exc = None + try: + exc = result.exception_type(exception_msg) + except BaseException as e: + raise RuntimeError( # noqa: B904 + f"Failed to create original exception type. Error msg was {str(e)}" + f" Original exception on remote side was {exception_msg}" + ) from e + + if exc is not None: + raise exc + + +def _build_rpc_profiling_key( + exec_type, func_name, current_worker_name, dst_worker_name +): + """ + Builds the key that RPC calls are profiled with using the autograd profiler. + This will be the name of the corresponding Event recorded in the profiler. + + Args: + exec_type (RPCExecMode): Type of RPC/RRef call + func_name (str): Name of function being profiled. + current_worker_name (str): Name of current worker. + dst_worker_name (str): Name of the destination worker. + + Returns: + String representing profiling key + """ + profile_key = ( + f"rpc_{exec_type.value}#{func_name}({current_worker_name} -> {dst_worker_name})" + ) + return profile_key + + +def _start_record_function(exec_type, func_name, current_worker_name, dest_worker_name): + """ + This function should be called from RPC/RRef functions to create a + RecordFunction object for profiling. This function also runs the before + callbacks that start the profiling, though the user is responsible for + running the appropriate callbacks when the function to be profiled finishes. + + Args: + exec_type (RPCExecMode): Type of RPC/RRef call + func_name (str): Name of function being profiled. + current_worker_name (str): Name of current worker. + dest_worker_name (str): Name of the destination worker. + + Returns: + An instance of `core.autograd._RecordFunction`. + """ + assert core.autograd._profiler_enabled(), "Autograd profiler should be enabled." + profile_key = f"rpc_{exec_type.value}#{str(func_name)}({current_worker_name} -> {dest_worker_name})" + rf = core.autograd._RecordFunction() # type: ignore[attr-defined] + core.autograd._run_before_callbacks(rf, profile_key) # type: ignore[attr-defined] + return rf + + +PythonUDF = collections.namedtuple("PythonUDF", ["func", "args", "kwargs"]) +RemoteException = collections.namedtuple("RemoteException", ["msg", "exception_type"]) diff --git a/mindnlp/core/distributed/rpc/options.py b/mindnlp/core/distributed/rpc/options.py new file mode 100644 index 000000000..94768806e --- /dev/null +++ b/mindnlp/core/distributed/rpc/options.py @@ -0,0 +1,175 @@ +# mypy: allow-untyped-defs +from typing import Dict, List, Optional, Union + +from mindnlp import core +from core._C._distributed_rpc import _TensorPipeRpcBackendOptionsBase + +from . import constants as rpc_contants + + +DeviceType = Union[int, str, core.device] + +__all__ = ["TensorPipeRpcBackendOptions"] + + +def _to_device(device: DeviceType) -> core.device: + device = core.device(device) + if device.type != "cuda": + raise ValueError( + "`set_devices` expect a list of CUDA devices, but got " + f"device type {device.type}." + ) + return device + + +def _to_device_map( + device_map: Dict[DeviceType, DeviceType] +) -> Dict[core.device, core.device]: + full_device_map: Dict[core.device, core.device] = {} + reverse_map: Dict[core.device, core.device] = {} + for k, v in device_map.items(): + k, v = core.device(k), core.device(v) + if v in reverse_map: + raise ValueError( + "`device_map` only supports 1-to-1 mapping, " + f"trying to map {k} and {reverse_map[v]} to {v}" + ) + full_device_map[k] = v + reverse_map[v] = k + return full_device_map + + +def _to_device_list(devices: List[DeviceType]) -> List[core.device]: + return list(map(_to_device, devices)) + + +class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase): + r""" + The backend options for + :class:`~core.distributed.rpc.TensorPipeAgent`, derived from + :class:`~core.distributed.rpc.RpcBackendOptions`. + + Args: + num_worker_threads (int, optional): The number of threads in the + thread-pool used by + :class:`~core.distributed.rpc.TensorPipeAgent` to execute + requests (default: 16). + rpc_timeout (float, optional): The default timeout, in seconds, + for RPC requests (default: 60 seconds). If the RPC has not + completed in this timeframe, an exception indicating so will + be raised. Callers can override this timeout for individual + RPCs in :meth:`~core.distributed.rpc.rpc_sync` and + :meth:`~core.distributed.rpc.rpc_async` if necessary. + init_method (str, optional): The URL to initialize the distributed + store used for rendezvous. It takes any value accepted for the + same argument of :meth:`~core.distributed.init_process_group` + (default: ``env://``). + device_maps (Dict[str, Dict], optional): Device placement mappings from + this worker to the callee. Key is the callee worker name and value + the dictionary (``Dict`` of ``int``, ``str``, or ``core.device``) + that maps this worker's devices to the callee worker's devices. + (default: ``None``) + devices (List[int, str, or ``core.device``], optional): all local + CUDA devices used by RPC agent. By Default, it will be initialized + to all local devices from its own ``device_maps`` and corresponding + devices from its peers' ``device_maps``. When processing CUDA RPC + requests, the agent will properly synchronize CUDA streams for + all devices in this ``List``. + """ + + def __init__( + self, + *, + num_worker_threads: int = rpc_contants.DEFAULT_NUM_WORKER_THREADS, + rpc_timeout: float = rpc_contants.DEFAULT_RPC_TIMEOUT_SEC, + init_method: str = rpc_contants.DEFAULT_INIT_METHOD, + device_maps: Optional[Dict[str, Dict[DeviceType, DeviceType]]] = None, + devices: Optional[List[DeviceType]] = None, + _transports: Optional[List] = None, + _channels: Optional[List] = None, + ): + full_device_maps = ( + {} + if device_maps is None + else {k: _to_device_map(v) for k, v in device_maps.items()} + ) + full_device_list = [] if devices is None else _to_device_list(devices) + super().__init__( + num_worker_threads, + _transports, + _channels, + rpc_timeout, + init_method, + full_device_maps, + full_device_list, + ) + + def set_device_map(self, to: str, device_map: Dict[DeviceType, DeviceType]): + r""" + Set device mapping between each RPC caller and callee pair. This + function can be called multiple times to incrementally add + device placement configurations. + + Args: + to (str): Callee name. + device_map (Dict of int, str, or core.device): Device placement + mappings from this worker to the callee. This map must be + invertible. + + Example: + >>> # xdoctest: +SKIP("distributed") + >>> # both workers + >>> def add(x, y): + >>> print(x) # tensor([1., 1.], device='cuda:1') + >>> return x + y, (x + y).to(2) + >>> + >>> # on worker 0 + >>> options = TensorPipeRpcBackendOptions( + >>> num_worker_threads=8, + >>> device_maps={"worker1": {0: 1}} + >>> # maps worker0's cuda:0 to worker1's cuda:1 + >>> ) + >>> options.set_device_map("worker1", {1: 2}) + >>> # maps worker0's cuda:1 to worker1's cuda:2 + >>> + >>> rpc.init_rpc( + >>> "worker0", + >>> rank=0, + >>> world_size=2, + >>> backend=rpc.BackendType.TENSORPIPE, + >>> rpc_backend_options=options + >>> ) + >>> + >>> x = core.ones(2) + >>> rets = rpc.rpc_sync("worker1", add, args=(x.to(0), 1)) + >>> # The first argument will be moved to cuda:1 on worker1. When + >>> # sending the return value back, it will follow the invert of + >>> # the device map, and hence will be moved back to cuda:0 and + >>> # cuda:1 on worker0 + >>> print(rets[0]) # tensor([2., 2.], device='cuda:0') + >>> print(rets[1]) # tensor([2., 2.], device='cuda:1') + """ + full_device_map = _to_device_map(device_map) + curr_device_maps = super().device_maps + + if to in curr_device_maps: + for k, v in full_device_map.items(): + if k in curr_device_maps[to] and v != curr_device_maps[to][k]: + raise ValueError( + "`set_device_map` only supports 1-to-1 mapping, trying" + f" to map {k} to {v} and {curr_device_maps[to][k]}" + ) + + super()._set_device_map(to, full_device_map) + + def set_devices(self, devices: List[DeviceType]): + r""" + Set local devices used by the TensorPipe RPC agent. When processing + CUDA RPC requests, the TensorPipe RPC agent will properly synchronize + CUDA streams for all devices in this ``List``. + + Args: + devices (List of int, str, or core.device): local devices used by + the TensorPipe RPC agent. + """ + self.devices = _to_device_list(devices) diff --git a/mindnlp/core/distributed/rpc/rref_proxy.py b/mindnlp/core/distributed/rpc/rref_proxy.py new file mode 100644 index 000000000..929cbfe12 --- /dev/null +++ b/mindnlp/core/distributed/rpc/rref_proxy.py @@ -0,0 +1,80 @@ +# mypy: allow-untyped-defs +from functools import partial + +from mindnlp import core +from core.futures import Future + +from . import functions, rpc_async +from .constants import UNSET_RPC_TIMEOUT + + +def _local_invoke(rref, func_name, args, kwargs): + return getattr(rref.local_value(), func_name)(*args, **kwargs) + + +@functions.async_execution +def _local_invoke_async_execution(rref, func_name, args, kwargs): + return getattr(rref.local_value(), func_name)(*args, **kwargs) + + +def _invoke_rpc(rref, rpc_api, func_name, timeout, *args, **kwargs): + def _rref_type_cont(rref_fut): + rref_type = rref_fut.value() + + _invoke_func = _local_invoke + # Bypass ScriptModules when checking for async function attribute. + bypass_type = issubclass(rref_type, core.jit.ScriptModule) or issubclass( + rref_type, core._C.ScriptModule + ) + if not bypass_type: + func = getattr(rref_type, func_name) + if hasattr(func, "_wrapped_async_rpc_function"): + _invoke_func = _local_invoke_async_execution + + return rpc_api( + rref.owner(), + _invoke_func, + args=(rref, func_name, args, kwargs), + timeout=timeout, + ) + + rref_fut = rref._get_type(timeout=timeout, blocking=False) + + if rpc_api != rpc_async: + rref_fut.wait() + return _rref_type_cont(rref_fut) + else: + # A little explanation on this. + # rpc_async returns a Future pointing to the return value of `func_name`, it returns a `Future[T]` + # Calling _rref_type_cont from the `then` lambda causes Future wrapping. IOW, `then` returns a `Future[Future[T]]` + # To address that, we return a Future that is completed with the result of the async call. + result: Future = Future() + + def _wrap_rref_type_cont(fut): + try: + _rref_type_cont(fut).then(_complete_op) + except BaseException as ex: + result.set_exception(ex) + + def _complete_op(fut): + try: + result.set_result(fut.value()) + except BaseException as ex: + result.set_exception(ex) + + rref_fut.then(_wrap_rref_type_cont) + return result + + +# This class manages proxied RPC API calls for RRefs. It is entirely used from +# C++ (see python_rpc_handler.cpp). +class RRefProxy: + def __init__(self, rref, rpc_api, timeout=UNSET_RPC_TIMEOUT): + self.rref = rref + self.rpc_api = rpc_api + self.rpc_timeout = timeout + + def __getattr__(self, func_name): + return partial( + _invoke_rpc, self.rref, self.rpc_api, func_name, self.rpc_timeout + ) diff --git a/mindnlp/core/distributed/rpc/server_process_global_profiler.py b/mindnlp/core/distributed/rpc/server_process_global_profiler.py new file mode 100644 index 000000000..5607fa489 --- /dev/null +++ b/mindnlp/core/distributed/rpc/server_process_global_profiler.py @@ -0,0 +1,183 @@ +#!/usr/bin/python3 +# mypy: allow-untyped-defs + +import itertools +from typing import List + +from mindnlp import core +from core.autograd.profiler_legacy import profile + +from . import ( + _disable_server_process_global_profiler, + _enable_server_process_global_profiler, +) + + +__all__: List[str] = [] + + +class _server_process_global_profile(profile): + """ + It has the same API as ``core.autograd.profiler.profile`` class, + except that it enables profiling on all threads running RPC server request callbacks. + + Context manager that manages autograd profiler state and holds a summary of results. + Under the hood it just records events of functions being executed in C++ and + exposes those events to Python. You can wrap any code into it and it will + only report runtime of PyTorch functions. + Note: profiler is thread local and is automatically propagated into the async tasks + + Args: + enabled (bool, optional): Setting this to False makes this context manager a no-op. + Default: ``True``. + + use_cuda (bool, optional): Enables timing of CUDA events as well using the cudaEvent API. + Adds approximately 4us of overhead to each tensor operation. + Default: ``False`` + + record_shapes (bool, optional): If shapes recording is set, information + about input dimensions will be collected. This allows one to see which + dimensions have been used under the hood and further group by them + using prof.key_averages(group_by_input_shape=True). Please note that + shape recording might skew your profiling data. It is recommended to + use separate runs with and without shape recording to validate the timing. + Most likely the skew will be negligible for bottom most events (in a case + of nested function calls). But for higher level functions the total + self cpu time might be artificially increased because of the shape + collection. + + profile_memory (bool, optional): Whether to report memory usage, default: ``False`` + + .. warning: + Enabling memory profiling incurs additional profiler overhead + + .. warning: + Due to some CUDA multiprocessing limitations (multiprocessing-cuda-note_), + one cannot use the profiler with ``use_cuda = True`` to benchmark + DataLoaders with ``num_workers > 0``. If you wish to benchmark data loading, + please use ``use_cuda = False`` or ``num_workers = 0``. + + Example: + >>> # xdoctest: +SKIP + >>> # On worker 0: + >>> from mindnlp import core + >>> from mindnlp import core.distributed.rpc as rpc + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> x, y = core.tensor(1), core.tensor(2) + >>> outer_profile_rref = rpc.remote(dst_worker_name, rpc._server_process_global_profile) + >>> outer_profile_rref.rpc_sync().__enter__() + >>> rpc.rpc_sync(dst_worker_name, core.add, (x, y)) + >>> inner_profile_rref = rpc.remote(dst_worker_name, rpc._server_process_global_profile) + >>> inner_profile_rref.rpc_sync().__enter__() + >>> rpc.rpc_sync(dst_worker_name, core.sub, (x, y)) + >>> inner_profile_rref.rpc_sync().__exit__(None, None, None) + >>> outer_profile_rref.rpc_sync().__exit__(None, None, None) + >>> print(inner_profile_rref.rpc_sync().key_averages()) + --------- --------------- --------------- --------------- --------------- --------------- --------------- + Name Self CPU total % Self CPU total CPU total % CPU total CPU time avg Number of Calls + --------- --------------- --------------- --------------- --------------- --------------- --------------- + sub 85.06% 76.275us 100.00% 89.667us 89.667us 1 + empty 14.94% 13.392us 14.94% 13.392us 13.392us 1 + --------- --------------- --------------- --------------- --------------- --------------- --------------- + Self CPU time total: 89.667us + >>> print(outer_profile_rref.rpc_sync().key_averages()) + --------- --------------- --------------- --------------- --------------- --------------- --------------- + Name Self CPU total % Self CPU total CPU total % CPU total CPU time avg Number of Calls + --------- --------------- --------------- --------------- --------------- --------------- --------------- + sub 35.65% 76.275us 41.91% 89.667us 89.667us 1 + empty 12.67% 27.101us 12.67% 27.101us 13.551us 2 + add 51.68% 110.550us 58.09% 124.259us 124.259us 1 + --------- --------------- --------------- --------------- --------------- --------------- --------------- + Self CPU time total: 213.926us + >>> rpc.shutdown() + + >>> # On worker 1: + >>> from mindnlp import core.distributed.rpc as rpc + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> # wait for worker 0 to finish work, and then shutdown. + >>> rpc.shutdown() + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def __enter__(self): + """ + Turn on server-side process-global profiling. + This enables thread-local profiler on all RPC threads running server-side request callbacks. + """ + if not self.enabled: + return + + if self.entered: # type: ignore[has-type] + raise RuntimeError("autograd profiler traces are not reentrant") + self.entered = True + + profiler_kind = ( + core.autograd.ProfilerState.CUDA + if self.use_cuda + else core.autograd.ProfilerState.CPU + ) + profiler_config = core.autograd.ProfilerConfig( + profiler_kind, + self.record_shapes, + self.profile_memory, + False, + False, + False, + core.profiler._ExperimentalConfig(), + ) + _enable_server_process_global_profiler(profiler_config) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Turn off server-side process-global profiling. + Aggregate all profiling events recorded by RPC threads. + + These attributes are assigned on exiting context. + + Attributes: + function_events (core.autograd.profiler.EventList). It's a list that has helper + methods, like 1) show record items in a pretty-print table. + 2) do averaging by grouping on keys. 3) and more. + + process_global_function_events (List[core.autograd.profiler.FunctionEvent]). + It's a list of ``FunctionEvent`` elements. Every element is a profiling result + of an RPC request handling within the profiling range. + """ + if not self.enabled: + return + + process_global_events = _disable_server_process_global_profiler() + + # Every element in this list is a thread profiling result from an RPC request handling. + process_global_function_events = [] + for thread_local_events in process_global_events: + # Parse from ``Event``s to ``FunctionEvent``s. + thread_local_function_events = ( + core.autograd.profiler_legacy._parse_legacy_records( + thread_local_events + ) + ) + thread_local_function_events.sort( + key=lambda function_event: [ + function_event.time_range.start, + -(function_event.time_range.end), + ] + ) + process_global_function_events.append(thread_local_function_events) + + flattened_function_events = list( + itertools.chain.from_iterable(process_global_function_events) + ) + self.function_events = core.autograd.profiler_util.EventList( + flattened_function_events, + use_device="cuda" if self.use_cuda else None, + profile_memory=self.profile_memory, + ) + self.function_events._build_tree() + + self.process_global_function_events = process_global_function_events + + return False diff --git a/mindnlp/core/distributed/run.py b/mindnlp/core/distributed/run.py new file mode 100644 index 000000000..7a97817d5 --- /dev/null +++ b/mindnlp/core/distributed/run.py @@ -0,0 +1,922 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Superset of ``core.distributed.launch``. + +``torchrun`` provides a superset of the functionality as ``core.distributed.launch`` +with the following additional functionalities: + +1. Worker failures are handled gracefully by restarting all workers. + +2. Worker ``RANK`` and ``WORLD_SIZE`` are assigned automatically. + +3. Number of nodes is allowed to change between minimum and maximum sizes (elasticity). + +.. note:: ``torchrun`` is a python + `console script `_ + to the main module + `core.distributed.run `_ + declared in the ``entry_points`` configuration in + `setup.py `_. + It is equivalent to invoking ``python -m core.distributed.run``. + + +Transitioning from core.distributed.launch to torchrun +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + +``torchrun`` supports the same arguments as ``core.distributed.launch`` **except** +for ``--use-env`` which is now deprecated. To migrate from ``core.distributed.launch`` +to ``torchrun`` follow these steps: + +1. If your training script is already reading ``local_rank`` from the ``LOCAL_RANK`` environment variable. + Then you need simply omit the ``--use-env`` flag, e.g.: + + +--------------------------------------------------------------------+--------------------------------------------+ + | ``core.distributed.launch`` | ``torchrun`` | + +====================================================================+============================================+ + | | | + | .. code-block:: shell-session | .. code-block:: shell-session | + | | | + | $ python -m core.distributed.launch --use-env train_script.py | $ torchrun train_script.py | + | | | + +--------------------------------------------------------------------+--------------------------------------------+ + +2. If your training script reads local rank from a ``--local-rank`` cmd argument. + Change your training script to read from the ``LOCAL_RANK`` environment variable as + demonstrated by the following code snippet: + + +-------------------------------------------------------+----------------------------------------------------+ + | ``core.distributed.launch`` | ``torchrun`` | + +=======================================================+====================================================+ + | | | + | .. code-block:: python | .. code-block:: python | + | | | + | | | + | import argparse | import os | + | parser = argparse.ArgumentParser() | local_rank = int(os.environ["LOCAL_RANK"]) | + | parser.add_argument("--local-rank", type=int) | | + | args = parser.parse_args() | | + | | | + | local_rank = args.local_rank | | + | | | + +-------------------------------------------------------+----------------------------------------------------+ + +.. versionchanged:: 2.0.0 + + The launcher will pass the ``--local-rank=`` argument to your script. + From PyTorch 2.0.0 onwards, the dashed ``--local-rank`` is preferred over the + previously used underscored ``--local_rank``. + + For backward compatibility, it may be necessary for users to handle both + cases in their argument parsing code. This means including both ``"--local-rank"`` + and ``"--local_rank"`` in the argument parser. If only ``"--local_rank"`` is + provided, the launcher will trigger an error: "error: unrecognized arguments: + --local-rank=". For training code that only supports PyTorch 2.0.0+, + including ``"--local-rank"`` should be sufficient. + + :: + + >>> # xdoctest: +SKIP + >>> import argparse + >>> parser = argparse.ArgumentParser() + >>> parser.add_argument("--local-rank", "--local_rank", type=int) + >>> args = parser.parse_args() + +The aformentioned changes suffice to migrate from ``core.distributed.launch`` to ``torchrun``. +To take advantage of new features such as elasticity, fault-tolerance, and error reporting of ``torchrun`` +please refer to: + +* :ref:`elastic_train_script` for more information on authoring training scripts that are ``torchrun`` compliant. +* the rest of this page for more information on the features of ``torchrun``. + + +Usage +-------- + +Single-node multi-worker +++++++++++++++++++++++++++++++ + +:: + + torchrun + --standalone + --nnodes=1 + --nproc-per-node=$NUM_TRAINERS + YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...) + +Stacked single-node multi-worker ++++++++++++++++++++++++++++++++++++ + +To run multiple instances (separate jobs) of single-node, multi-worker on the +same host, we need to make sure that each instance (job) is +setup on different ports to avoid port conflicts (or worse, two jobs being merged +as a single job). To do this you have to run with ``--rdzv-backend=c10d`` +and specify a different port by setting ``--rdzv-endpoint=localhost:$PORT_k``. +For ``--nodes=1``, its often convenient to let ``torchrun`` pick a free random +port automatically instead of manually assigning different ports for each run. + +:: + + torchrun + --rdzv-backend=c10d + --rdzv-endpoint=localhost:0 + --nnodes=1 + --nproc-per-node=$NUM_TRAINERS + YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...) + + +Fault tolerant (fixed sized number of workers, no elasticity, tolerates 3 failures) +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +:: + + torchrun + --nnodes=$NUM_NODES + --nproc-per-node=$NUM_TRAINERS + --max-restarts=3 + --rdzv-id=$JOB_ID + --rdzv-backend=c10d + --rdzv-endpoint=$HOST_NODE_ADDR + YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...) + +``HOST_NODE_ADDR``, in form [:] (e.g. node1.example.com:29400), specifies the node and +the port on which the C10d rendezvous backend should be instantiated and hosted. It can be any +node in your training cluster, but ideally you should pick a node that has a high bandwidth. + +.. note:: + If no port number is specified ``HOST_NODE_ADDR`` defaults to 29400. + +Elastic (``min=1``, ``max=4``, tolerates up to 3 membership changes or failures) ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +:: + + torchrun + --nnodes=1:4 + --nproc-per-node=$NUM_TRAINERS + --max-restarts=3 + --rdzv-id=$JOB_ID + --rdzv-backend=c10d + --rdzv-endpoint=$HOST_NODE_ADDR + YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...) + +``HOST_NODE_ADDR``, in form [:] (e.g. node1.example.com:29400), specifies the node and +the port on which the C10d rendezvous backend should be instantiated and hosted. It can be any +node in your training cluster, but ideally you should pick a node that has a high bandwidth. + +.. note:: + If no port number is specified ``HOST_NODE_ADDR`` defaults to 29400. + +Note on rendezvous backend +------------------------------ + +For multi-node training you need to specify: + +1. ``--rdzv-id``: A unique job id (shared by all nodes participating in the job) +2. ``--rdzv-backend``: An implementation of + :py:class:`core.distributed.elastic.rendezvous.RendezvousHandler` +3. ``--rdzv-endpoint``: The endpoint where the rendezvous backend is running; usually in form + ``host:port``. + +Currently ``c10d`` (recommended), ``etcd-v2``, and ``etcd`` (legacy) rendezvous backends are +supported out of the box. To use ``etcd-v2`` or ``etcd``, setup an etcd server with the ``v2`` api +enabled (e.g. ``--enable-v2``). + +.. warning:: + ``etcd-v2`` and ``etcd`` rendezvous use etcd API v2. You MUST enable the v2 API on the etcd + server. Our tests use etcd v3.4.3. + +.. warning:: + For etcd-based rendezvous we recommend using ``etcd-v2`` over ``etcd`` which is functionally + equivalent, but uses a revised implementation. ``etcd`` is in maintenance mode and will be + removed in a future version. + +Definitions +-------------- + +1. ``Node`` - A physical instance or a container; maps to the unit that the job manager works with. + +2. ``Worker`` - A worker in the context of distributed training. + +3. ``WorkerGroup`` - The set of workers that execute the same function (e.g. trainers). + +4. ``LocalWorkerGroup`` - A subset of the workers in the worker group running on the same node. + +5. ``RANK`` - The rank of the worker within a worker group. + +6. ``WORLD_SIZE`` - The total number of workers in a worker group. + +7. ``LOCAL_RANK`` - The rank of the worker within a local worker group. + +8. ``LOCAL_WORLD_SIZE`` - The size of the local worker group. + +9. ``rdzv_id`` - A user-defined id that uniquely identifies the worker group for a job. This id is + used by each node to join as a member of a particular worker group. + +9. ``rdzv_backend`` - The backend of the rendezvous (e.g. ``c10d``). This is typically a strongly + consistent key-value store. + +10. ``rdzv_endpoint`` - The rendezvous backend endpoint; usually in form ``:``. + +A ``Node`` runs ``LOCAL_WORLD_SIZE`` workers which comprise a ``LocalWorkerGroup``. The union of +all ``LocalWorkerGroups`` in the nodes in the job comprise the ``WorkerGroup``. + +Environment Variables +---------------------- + +The following environment variables are made available to you in your script: + +1. ``LOCAL_RANK`` - The local rank. + +2. ``RANK`` - The global rank. + +3. ``GROUP_RANK`` - The rank of the worker group. A number between 0 and ``max_nnodes``. When + running a single worker group per node, this is the rank of the node. + +4. ``ROLE_RANK`` - The rank of the worker across all the workers that have the same role. The role + of the worker is specified in the ``WorkerSpec``. + +5. ``LOCAL_WORLD_SIZE`` - The local world size (e.g. number of workers running locally); equals to + ``--nproc-per-node`` specified on ``torchrun``. + +6. ``WORLD_SIZE`` - The world size (total number of workers in the job). + +7. ``ROLE_WORLD_SIZE`` - The total number of workers that was launched with the same role specified + in ``WorkerSpec``. + +8. ``MASTER_ADDR`` - The FQDN of the host that is running worker with rank 0; used to initialize + the Torch Distributed backend. + +9. ``MASTER_PORT`` - The port on the ``MASTER_ADDR`` that can be used to host the C10d TCP store. + +10. ``TORCHELASTIC_RESTART_COUNT`` - The number of worker group restarts so far. + +11. ``TORCHELASTIC_MAX_RESTARTS`` - The configured maximum number of restarts. + +12. ``TORCHELASTIC_RUN_ID`` - Equal to the rendezvous ``run_id`` (e.g. unique job id). + +13. ``PYTHON_EXEC`` - System executable override. If provided, the python user script will + use the value of ``PYTHON_EXEC`` as executable. The `sys.executable` is used by default. + +Deployment +------------ + +1. (Not needed for the C10d backend) Start the rendezvous backend server and get the endpoint (to be + passed as ``--rdzv-endpoint`` to the launcher script) + +2. Single-node multi-worker: Start the launcher on the host to start the agent process which + creates and monitors a local worker group. + +3. Multi-node multi-worker: Start the launcher with the same arguments on all the nodes + participating in training. + +When using a job/cluster manager the entry point command to the multi-node job should be this +launcher. + +Failure Modes +--------------- + +1. Worker failure: For a training job with ``n`` workers, if ``k<=n`` workers fail all workers + are stopped and restarted up to ``max_restarts``. + +2. Agent failure: An agent failure results in a local worker group failure. It is up to the job + manager to fail the entire job (gang semantics) or attempt to replace the node. Both behaviors + are supported by the agent. + +3. Node failure: Same as agent failure. + +Membership Changes +-------------------- + +1. Node departure (scale-down): The agent is notified of the departure, all existing workers are + stopped, a new ``WorkerGroup`` is formed, and all workers are started with a new ``RANK`` and + ``WORLD_SIZE``. + +2. Node arrival (scale-up): The new node is admitted to the job, all existing workers are stopped, + a new ``WorkerGroup`` is formed, and all workers are started with a new ``RANK`` and + ``WORLD_SIZE``. + +Important Notices +-------------------- + +1. This utility and multi-process distributed (single-node or + multi-node) GPU training currently only achieves the best performance using + the NCCL distributed backend. Thus NCCL backend is the recommended backend to + use for GPU training. + +2. The environment variables necessary to initialize a Torch process group are provided to you by + this module, no need for you to pass ``RANK`` manually. To initialize a process group in your + training script, simply run: + +:: + + >>> # xdoctest: +SKIP("stub") + >>> from mindnlp import core.distributed as dist + >>> dist.init_process_group(backend="gloo|nccl") + +3. In your training program, you can either use regular distributed functions + or use :func:`core.nn.parallel.DistributedDataParallel` module. If your + training program uses GPUs for training and you would like to use + :func:`core.nn.parallel.DistributedDataParallel` module, + here is how to configure it. + +:: + + local_rank = int(os.environ["LOCAL_RANK"]) + model = core.nn.parallel.DistributedDataParallel(model, + device_ids=[local_rank], + output_device=local_rank) + +Please ensure that ``device_ids`` argument is set to be the only GPU device id +that your code will be operating on. This is generally the local rank of the +process. In other words, the ``device_ids`` needs to be ``[int(os.environ("LOCAL_RANK"))]``, +and ``output_device`` needs to be ``int(os.environ("LOCAL_RANK"))`` in order to use this +utility + + +4. On failures or membership changes ALL surviving workers are killed immediately. Make sure to + checkpoint your progress. The frequency of checkpoints should depend on your job's tolerance + for lost work. + +5. This module only supports homogeneous ``LOCAL_WORLD_SIZE``. That is, it is assumed that all + nodes run the same number of local workers (per role). + +6. ``RANK`` is NOT stable. Between restarts, the local workers on a node can be assigned a + different range of ranks than before. NEVER hard code any assumptions about the stable-ness of + ranks or some correlation between ``RANK`` and ``LOCAL_RANK``. + +7. When using elasticity (``min_size!=max_size``) DO NOT hard code assumptions about + ``WORLD_SIZE`` as the world size can change as nodes are allowed to leave and join. + +8. It is recommended for your script to have the following structure: + +:: + + def main(): + load_checkpoint(checkpoint_path) + initialize() + train() + + def train(): + for batch in iter(dataset): + train_step(batch) + + if should_checkpoint: + save_checkpoint(checkpoint_path) + +9. (Recommended) On worker errors, this tool will summarize the details of the error + (e.g. time, rank, host, pid, traceback, etc). On each node, the first error (by timestamp) + is heuristically reported as the "Root Cause" error. To get tracebacks as part of this + error summary print out, you must decorate your main entrypoint function in your + training script as shown in the example below. If not decorated, then the summary + will not include the traceback of the exception and will only contain the exitcode. + For details on torchelastic error handling see: https://pycore.org/docs/stable/elastic/errors.html + +:: + + from core.distributed.elastic.multiprocessing.errors import record + + @record + def main(): + # do train + pass + + if __name__ == "__main__": + main() + +""" +import os +import sys +import uuid +from argparse import ArgumentParser, REMAINDER +from importlib import metadata +from typing import Callable, List, Optional, Set, Tuple, Type, Union + +from mindnlp import core +from core.distributed.argparse_util import check_env, env +from core.distributed.elastic.multiprocessing import DefaultLogsSpecs, LogsSpecs, Std +from core.distributed.elastic.multiprocessing.errors import record +from core.distributed.elastic.rendezvous.utils import _parse_rendezvous_config +from core.distributed.elastic.utils import macros +from core.distributed.elastic.utils.logging import get_logger +from core.distributed.launcher.api import elastic_launch, LaunchConfig +from core.utils.backend_registration import _get_custom_mod_func + + +logger = get_logger(__name__) + + +def get_args_parser() -> ArgumentParser: + """Parse the command line options.""" + parser = ArgumentParser(description="Torch Distributed Elastic Training Launcher") + + # + # Worker/node size related arguments. + # + + parser.add_argument( + "--nnodes", + action=env, + type=str, + default="1:1", + help="Number of nodes, or the range of nodes in form :.", + ) + parser.add_argument( + "--nproc-per-node", + "--nproc_per_node", + action=env, + type=str, + default="1", + help="Number of workers per node; supported values: [auto, cpu, gpu, int].", + ) + + # + # Rendezvous related arguments + # + + parser.add_argument( + "--rdzv-backend", + "--rdzv_backend", + action=env, + type=str, + default="static", + help="Rendezvous backend.", + ) + parser.add_argument( + "--rdzv-endpoint", + "--rdzv_endpoint", + action=env, + type=str, + default="", + help="Rendezvous backend endpoint; usually in form :.", + ) + parser.add_argument( + "--rdzv-id", + "--rdzv_id", + action=env, + type=str, + default="none", + help="User-defined group id.", + ) + parser.add_argument( + "--rdzv-conf", + "--rdzv_conf", + action=env, + type=str, + default="", + help="Additional rendezvous configuration (=,=,...).", + ) + parser.add_argument( + "--standalone", + action=check_env, + help="Start a local standalone rendezvous backend that is represented by a C10d TCP store " + "on a free port. Useful when launching single-node, multi-worker job. If specified " + "--rdzv-backend, --rdzv-endpoint, --rdzv-id are auto-assigned and any explicitly set values " + "are ignored.", + ) + + # + # User-code launch related arguments. + # + + parser.add_argument( + "--max-restarts", + "--max_restarts", + action=env, + type=int, + default=0, + help="Maximum number of worker group restarts before failing.", + ) + parser.add_argument( + "--monitor-interval", + "--monitor_interval", + action=env, + type=float, + default=0.1, + help="Interval, in seconds, to monitor the state of workers.", + ) + parser.add_argument( + "--start-method", + "--start_method", + action=env, + type=str, + default="spawn", + choices=["spawn", "fork", "forkserver"], + help="Multiprocessing start method to use when creating workers.", + ) + parser.add_argument( + "--role", + action=env, + type=str, + default="default", + help="User-defined role for the workers.", + ) + parser.add_argument( + "-m", + "--module", + action=check_env, + help="Change each process to interpret the launch script as a Python module, executing " + "with the same behavior as 'python -m'.", + ) + parser.add_argument( + "--no-python", + "--no_python", + action=check_env, + help="Skip prepending the training script with 'python' - just execute it directly. Useful " + "when the script is not a Python script.", + ) + + parser.add_argument( + "--run-path", + "--run_path", + action=check_env, + help="Run the training script with runpy.run_path in the same interpreter." + " Script must be provided as an abs path (e.g. /abs/path/script.py)." + " Takes precedence over --no-python.", + ) + parser.add_argument( + "--log-dir", + "--log_dir", + action=env, + type=str, + default=None, + help="Base directory to use for log files (e.g. /var/log/torch/elastic). The same " + "directory is re-used for multiple runs (a unique job-level sub-directory is created with " + "rdzv_id as the prefix).", + ) + parser.add_argument( + "-r", + "--redirects", + action=env, + type=str, + default="0", + help="Redirect std streams into a log file in the log directory (e.g. [-r 3] redirects " + "both stdout+stderr for all workers, [-r 0:1,1:2] redirects stdout for local rank 0 and " + "stderr for local rank 1).", + ) + parser.add_argument( + "-t", + "--tee", + action=env, + type=str, + default="0", + help="Tee std streams into a log file and also to console (see --redirects for format).", + ) + + parser.add_argument( + "--local-ranks-filter", + "--local_ranks_filter", + action=env, + type=str, + default="", + help="Only show logs from specified ranks in console (e.g. [--local_ranks_filter=0,1,2] will " + "only show logs from rank 0, 1 and 2). This will only apply to stdout and stderr, not to" + "log files saved via --redirect or --tee", + ) + + # + # Backwards compatible parameters with caffe2.distributed.launch. + # + + parser.add_argument( + "--node-rank", + "--node_rank", + type=int, + action=env, + default=0, + help="Rank of the node for multi-node distributed training.", + ) + parser.add_argument( + "--master-addr", + "--master_addr", + default="127.0.0.1", + type=str, + action=env, + help="Address of the master node (rank 0) that only used for static rendezvous. It should " + "be either the IP address or the hostname of rank 0. For single node multi-proc training " + "the --master-addr can simply be 127.0.0.1; IPv6 should have the pattern " + "`[0:0:0:0:0:0:0:1]`.", + ) + parser.add_argument( + "--master-port", + "--master_port", + default=29500, + type=int, + action=env, + help="Port on the master node (rank 0) to be used for communication during distributed " + "training. It is only used for static rendezvous.", + ) + parser.add_argument( + "--local-addr", + "--local_addr", + default=None, + type=str, + action=env, + help="Address of the local node. If specified, will use the given address for connection. " + "Else, will look up the local node address instead. Else, it will be default to local " + "machine's FQDN.", + ) + + parser.add_argument( + "--logs-specs", + "--logs_specs", + default=None, + type=str, + help="torchrun.logs_specs group entrypoint name, value must be type of LogsSpecs. " + "Can be used to override custom logging behavior.", + ) + + # + # Positional arguments. + # + + parser.add_argument( + "training_script", + type=str, + help="Full path to the (single GPU) training program/script to be launched in parallel, " + "followed by all the arguments for the training script.", + ) + + # Rest from the training program. + parser.add_argument("training_script_args", nargs=REMAINDER) + + return parser + + +def parse_args(args): + parser = get_args_parser() + return parser.parse_args(args) + + +def parse_min_max_nnodes(nnodes: str): + arr = nnodes.split(":") + + if len(arr) == 1: + min_nodes = max_nodes = int(arr[0]) + elif len(arr) == 2: + min_nodes = int(arr[0]) + max_nodes = int(arr[1]) + else: + raise RuntimeError(f'nnodes={nnodes} is not in "MIN:MAX" format') # noqa: E231 + + return min_nodes, max_nodes + + +def determine_local_world_size(nproc_per_node: str): + try: + logger.info("Using nproc_per_node=%s.", nproc_per_node) + return int(nproc_per_node) + except ValueError as e: + if nproc_per_node == "cpu": + num_proc = os.cpu_count() + device_type = "cpu" + elif nproc_per_node == "gpu": + if not core.cuda.is_available(): + raise ValueError("Cuda is not available.") from e + device_type = "gpu" + num_proc = core.cuda.device_count() + elif nproc_per_node == core._C._get_privateuse1_backend_name(): + if not _get_custom_mod_func("is_available")(): + raise ValueError(f"{nproc_per_node} is not available.") from e + device_type = nproc_per_node + num_proc = _get_custom_mod_func("device_count")() + elif nproc_per_node == "auto": + if core.cuda.is_available(): + num_proc = core.cuda.device_count() + device_type = "gpu" + elif ( + hasattr(torch, core._C._get_privateuse1_backend_name()) + and _get_custom_mod_func("is_available")() + ): + num_proc = _get_custom_mod_func("device_count")() + device_type = core._C._get_privateuse1_backend_name() + else: + num_proc = os.cpu_count() + device_type = "cpu" + else: + raise ValueError( + f"Unsupported nproc_per_node value: {nproc_per_node}" + ) from e + + logger.info( + "Using nproc_per_node=%s, setting nproc_per_node to %s since the instance has %s %s", + nproc_per_node, + num_proc, + num_proc, + device_type, + ) + return num_proc + + +def get_rdzv_endpoint(args): + if args.rdzv_backend == "static" and not args.rdzv_endpoint: + return f"{args.master_addr}:{args.master_port}" # noqa: E231 + return args.rdzv_endpoint + + +def get_use_env(args) -> bool: + """ + Retrieve ``use_env`` from the args. + + ``use_env`` is a legacy argument, if ``use_env`` is False, the + ``--node-rank`` argument will be transferred to all worker processes. + ``use_env`` is only used by the ``core.distributed.launch`` and will + be deprecated in future releases. + """ + if not hasattr(args, "use_env"): + return True + return args.use_env + + +def _get_logs_specs_class(logs_specs_name: Optional[str]) -> Type[LogsSpecs]: + """ + Attemps to load `torchrun.logs_spec` entrypoint with key of `logs_specs_name` param. + Provides plugin mechanism to provide custom implementation of LogsSpecs. + + Returns `DefaultLogsSpecs` when logs_spec_name is None. + Raises ValueError when entrypoint for `logs_spec_name` can't be found in entrypoints. + """ + logs_specs_cls = None + if logs_specs_name is not None: + eps = metadata.entry_points() + if hasattr(eps, "select"): # >= 3.10 + group = eps.select(group="torchrun.logs_specs") + if group.select(name=logs_specs_name): + logs_specs_cls = group[logs_specs_name].load() + + elif specs := eps.get("torchrun.logs_specs"): # < 3.10 + if entrypoint_list := [ep for ep in specs if ep.name == logs_specs_name]: + logs_specs_cls = entrypoint_list[0].load() + + if logs_specs_cls is None: + raise ValueError( + f"Could not find entrypoint under 'torchrun.logs_specs[{logs_specs_name}]' key" + ) + + logger.info( + "Using logs_spec '%s' mapped to %s", logs_specs_name, str(logs_specs_cls) + ) + else: + logs_specs_cls = DefaultLogsSpecs + + return logs_specs_cls + + +def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str]]: + # If ``args`` not passed, defaults to ``sys.argv[:1]`` + min_nodes, max_nodes = parse_min_max_nnodes(args.nnodes) + assert 0 < min_nodes <= max_nodes + assert args.max_restarts >= 0 + + if ( + hasattr(args, "master_addr") + and args.rdzv_backend != "static" + and not args.rdzv_endpoint + ): + logger.warning( + "master_addr is only used for static rdzv_backend and when rdzv_endpoint " + "is not specified." + ) + + nproc_per_node = determine_local_world_size(args.nproc_per_node) + if "OMP_NUM_THREADS" not in os.environ and nproc_per_node > 1: + omp_num_threads = 1 + logger.warning( + "\n*****************************************\n" + "Setting OMP_NUM_THREADS environment variable for each process to be " + "%s in default, to avoid your system being overloaded, " + "please further tune the variable for optimal performance in " + "your application as needed. \n" + "*****************************************", + omp_num_threads, + ) + # This env variable will be passed down to the subprocesses + os.environ["OMP_NUM_THREADS"] = str(omp_num_threads) + + log_line_prefix_template = os.getenv("TORCHELASTIC_LOG_LINE_PREFIX_TEMPLATE") + + rdzv_configs = _parse_rendezvous_config(args.rdzv_conf) + + if args.rdzv_backend == "static": + rdzv_configs["rank"] = args.node_rank + + rdzv_endpoint = get_rdzv_endpoint(args) + + ranks: Optional[Set[int]] = None + if args.local_ranks_filter: + try: + ranks = set(map(int, args.local_ranks_filter.split(","))) + assert ranks + except Exception as e: + raise ValueError( + "--local_ranks_filter must be a comma-separated list of integers e.g. --local_ranks_filter=0,1,2" + ) from e + + logs_specs_cls: Type[LogsSpecs] = _get_logs_specs_class(args.logs_specs) + logs_specs = logs_specs_cls( + log_dir=args.log_dir, + redirects=Std.from_str(args.redirects), + tee=Std.from_str(args.tee), + local_ranks_filter=ranks, + ) + + config = LaunchConfig( + min_nodes=min_nodes, + max_nodes=max_nodes, + nproc_per_node=nproc_per_node, + run_id=args.rdzv_id, + role=args.role, + rdzv_endpoint=rdzv_endpoint, + rdzv_backend=args.rdzv_backend, + rdzv_configs=rdzv_configs, + max_restarts=args.max_restarts, + monitor_interval=args.monitor_interval, + start_method=args.start_method, + log_line_prefix_template=log_line_prefix_template, + local_addr=args.local_addr, + logs_specs=logs_specs, + ) + + with_python = not args.no_python + cmd: Union[Callable, str] + cmd_args = [] + use_env = get_use_env(args) + if args.run_path: + cmd = run_script_path + cmd_args.append(args.training_script) + else: + if with_python: + cmd = os.getenv("PYTHON_EXEC", sys.executable) + cmd_args.append("-u") + if args.module: + cmd_args.append("-m") + cmd_args.append(args.training_script) + else: + if args.module: + raise ValueError( + "Don't use both the '--no-python' flag" + " and the '--module' flag at the same time." + ) + cmd = args.training_script + if not use_env: + cmd_args.append(f"--local-rank={macros.local_rank}") + cmd_args.extend(args.training_script_args) + + return config, cmd, cmd_args + + +def run_script_path(training_script: str, *training_script_args: str): + """ + Run the provided `training_script` from within this interpreter. + + Usage: `script_as_function("/abs/path/to/script.py", "--arg1", "val1")` + """ + import runpy + import sys + + sys.argv = [training_script] + [*training_script_args] + runpy.run_path(sys.argv[0], run_name="__main__") + + +def run(args): + core.multiprocessing._set_thread_name("pt_elastic") + + if args.standalone: + args.rdzv_backend = "c10d" + args.rdzv_endpoint = "localhost:0" + args.rdzv_id = str(uuid.uuid4()) + logger.info( + "\n**************************************\n" + "Rendezvous info:\n" + "--rdzv-backend=%s " + "--rdzv-endpoint=%s " + "--rdzv-id=%s\n" + "**************************************\n", + args.rdzv_backend, + args.rdzv_endpoint, + args.rdzv_id, + ) + + config, cmd, cmd_args = config_from_args(args) + elastic_launch( + config=config, + entrypoint=cmd, + )(*cmd_args) + + +@record +def main(args=None): + args = parse_args(args) + run(args) + + +if __name__ == "__main__": + main() diff --git a/mindnlp/core/distributed/tensor/__init__.py b/mindnlp/core/distributed/tensor/__init__.py new file mode 100644 index 000000000..82f67afae --- /dev/null +++ b/mindnlp/core/distributed/tensor/__init__.py @@ -0,0 +1,4 @@ +Replicate = None +DTensor = None +Placement = None +Shard = None diff --git a/mindnlp/core/distributed/tensor/parallel/__init__.py b/mindnlp/core/distributed/tensor/parallel/__init__.py new file mode 100644 index 000000000..4f316b054 --- /dev/null +++ b/mindnlp/core/distributed/tensor/parallel/__init__.py @@ -0,0 +1,2 @@ +ColwiseParallel = None +RowwiseParallel = None diff --git a/mindnlp/core/distributed/utils.py b/mindnlp/core/distributed/utils.py new file mode 100644 index 000000000..d53051b6e --- /dev/null +++ b/mindnlp/core/distributed/utils.py @@ -0,0 +1,389 @@ +# mypy: allow-untyped-defs +import dataclasses +import traceback +from typing import ( + Any, + Callable, + Container, + Dict, + List, + Optional, + OrderedDict, + overload, + Set, + Tuple, + TypeVar, +) + +from mindnlp import core +from mindnlp import core.distributed as dist +from mindnlp.core import nn +from core.nn.utils.rnn import PackedSequence + + +__all__ = [] # type: ignore[var-annotated] + + +def _pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[Tuple[Any, ...], Tuple[str, ...]]: + """ + Turn argument list into separate key list and value list (unpack_kwargs does the opposite). + + Inspiration: https://github.com/facebookresearch/fairscale/blob/eeb6684/fairscale/internal/containers.py#L70 + Usage:: + + kwarg_keys, flat_args = pack_kwargs(1, 2, a=3, b=4) + assert kwarg_keys == ("a", "b") + assert flat_args == (1, 2, 3, 4) + args, kwargs = unpack_kwargs(kwarg_keys, flat_args) + assert args == (1, 2) + assert kwargs == {"a": 3, "b": 4} + Returns: + Tuple[Tuple[Any, ...], Tuple[str, ...]]: The first tuple element gives + gives both positional args and kwarg values, where the positional args + proceed kwarg values and kwarg values are ordered consistently with the + kwarg keys. The second tuple element gives the kwarg keys. + The second tuple element's length is at most the first tuple element's length. + """ + kwarg_keys: List[str] = [] + flat_args: List[Any] = list(args) + for k, v in kwargs.items(): + kwarg_keys.append(k) + flat_args.append(v) + + return tuple(flat_args), tuple(kwarg_keys) + + +def _cast_forward_inputs( + dtype: Optional[core.dtype], + *args: Any, + **kwargs: Any, +) -> Tuple[Any, Any]: + """ + Cast floating point tensors in ``args`` and ``kwargs`` to ``input_dtype``. + + This respects the existing ``requires_grad`` on the tensors. + """ + if dtype is None: + return args, kwargs + + def cast_fn(x: core.Tensor) -> core.Tensor: + if not core.is_floating_point(x) or x.dtype == dtype: + return x + return x.to(dtype) + + return (_apply_to_tensors(cast_fn, args), _apply_to_tensors(cast_fn, kwargs)) + + +def _unpack_kwargs( + flat_args: Tuple[Any, ...], kwarg_keys: Tuple[str, ...] +) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: + """See _pack_kwargs.""" + assert len(kwarg_keys) <= len( + flat_args + ), f"too many keys {len(kwarg_keys)} vs. {len(flat_args)}" + if len(kwarg_keys) == 0: + return flat_args, {} + args = flat_args[: -len(kwarg_keys)] + kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys) :])) + return args, kwargs + + +S = TypeVar("S", dict, list, tuple) +T = TypeVar("T", core.Tensor, PackedSequence) + + +@overload +def _recursive_to( + inputs: S, target_device: core.device, use_side_stream_for_tensor_copies: bool +) -> List[S]: + ... + + +@overload +def _recursive_to( + inputs: T, target_device: core.device, use_side_stream_for_tensor_copies: bool +) -> Tuple[T]: + ... + + +def _recursive_to(inputs, target_device, use_side_stream_for_tensor_copies): + r"""Recursively moves input to the target_device.""" + + def to_map(obj): + if isinstance(obj, (core.Tensor, PackedSequence)): + device = obj.data.device if isinstance(obj, PackedSequence) else obj.device + if device == target_device: + return (obj,) + if not use_side_stream_for_tensor_copies: + return (obj.to(target_device),) + else: + # If the custom module is not registered to torch, stream is not used for acceleration + device_mod = getattr(torch, device.type, None) + if device.type == "cpu" or device_mod is None: + return (obj.to(target_device),) + + from core.nn.parallel._functions import _get_stream + + # Perform CPU -> target_device copies in a background stream. This code is + # motivated from similar logic in torch/nn/parallel/_functions.py + stream = _get_stream(target_device) + with device_mod.stream(stream): + output = obj.to(target_device) + # synchronize with the copy stream + with device_mod.device(target_device.index): + current_stream = device_mod.current_stream() + # Sync the current stream with the copy stream + current_stream.wait_stream(stream) + # Ensure tensor memory is not reused until work on + # main stream is complete + if isinstance(obj, PackedSequence): + output.data.record_stream(current_stream) # type: ignore[arg-type] + else: + assert isinstance(output, core.Tensor) + output.record_stream(current_stream) # type: ignore[arg-type] + return (output,) + + from core.nn.parallel.scatter_gather import _is_namedtuple + + if _is_namedtuple(obj): + return [type(obj)(*args) for args in zip(*map(to_map, obj))] + if isinstance(obj, tuple) and len(obj) > 0: + return list(zip(*map(to_map, obj))) + if isinstance(obj, list) and len(obj) > 0: + return [list(i) for i in zip(*map(to_map, obj))] + if isinstance(obj, dict) and len(obj) > 0: + return [type(obj)(i) for i in zip(*map(to_map, obj.items()))] + return [obj] + + # Avoid reference cycle + try: + res = to_map(inputs) + finally: + to_map = None # type: ignore[assignment] + return res + + +def _p_assert(cond: Any, s: str, raise_assertion_error: bool = True) -> None: + """Alternate to ``assert`` when in the backward context to print the error message ``s`` since otherwise, it is swallowed.""" + if not cond: + print(s) + traceback.print_stack() + if raise_assertion_error: + raise AssertionError(s) + + +def _alloc_storage(tensor: core.Tensor, size: core.Size) -> None: + """ + Allocate storage for ``tensor`` with the given size. + + Returns: + bool: ``True`` if this method allocated storage and ``False`` if the + storage was already allocated. + """ + with core.no_grad(): + if not core.distributed._functional_collectives.is_torchdynamo_compiling(): + already_allocated = tensor._typed_storage()._size() == size.numel() + if not already_allocated: + tensor_storage_size = tensor._typed_storage()._size() + _p_assert( + tensor_storage_size == 0, + "Tensor storage should have been resized to be 0 but got PLACEHOLDEr", + ) + tensor._typed_storage()._resize_(size.numel()) + + +def _free_storage(tensor: core.Tensor): + """ + Frees the underlying storage of ``tensor``. + + Returns: + bool: ``True`` if the method freed the storage and ``False`` if the + storage was already freed. + """ + with core.no_grad(): + if not core.distributed._functional_collectives.is_torchdynamo_compiling(): + already_freed = tensor._typed_storage()._size() == 0 + if not already_freed: + _p_assert( + tensor.storage_offset() == 0, + "Freeing a tensor's storage is unsafe when it is not the sole occupant\n" + f"storage offset: {tensor.storage_offset()}\n" + f"storage size: {tensor._typed_storage()._size()}\n" + f"tensor shape: {tensor.shape}", + ) + tensor._typed_storage()._resize_(0) + + +Q = TypeVar("Q") +R = TypeVar("R", dict, list, tuple, set, OrderedDict, PackedSequence, Any) + + +@overload +def _apply_to_tensors(fn: Callable[[core.Tensor], Q], container: core.Tensor) -> Q: + ... + + +@overload +def _apply_to_tensors(fn: Callable[[core.Tensor], Any], container: R) -> R: + ... + + +def _apply_to_tensors(fn, container): + """Recursively apply to all tensor in different kinds of container types.""" + + def apply(x): + from core.nn.parallel.scatter_gather import _is_namedtuple + + if isinstance(x, core.Tensor): + return fn(x) + elif hasattr(x, "__dataclass_fields__"): + dc = dataclasses.replace(x) + changes = { + f.name: apply(getattr(dc, f.name)) for f in dataclasses.fields(dc) + } + return dataclasses.replace(dc, **changes) + elif isinstance(x, OrderedDict): + od = x.__class__() + for key, value in x.items(): + od[key] = apply(value) + return od + elif isinstance(x, PackedSequence): + apply(x.data) + return x + elif isinstance(x, dict): + return {key: apply(value) for key, value in x.items()} + elif _is_namedtuple(x): + res = (apply(el) for el in x) + return type(x)(*res) + elif isinstance(x, (list, tuple, set)): + return type(x)(apply(el) for el in x) + else: + return x + + return apply(container) + + +def _to_kwargs( + inputs: Tuple[Any, ...], + kwargs: Optional[Dict[str, Any]], + target_device: core.device, + use_side_stream_for_tensor_copies: bool, +) -> Tuple[Tuple[Any, ...], Tuple[Dict[str, Any], ...]]: + moved_inputs = ( + _recursive_to(inputs, target_device, use_side_stream_for_tensor_copies) + if inputs + else [] + ) + moved_kwargs = ( + _recursive_to(kwargs, target_device, use_side_stream_for_tensor_copies) + if kwargs + else [] + ) + if len(moved_inputs) < len(moved_kwargs): + moved_inputs.extend([() for _ in range(len(moved_kwargs) - len(inputs))]) + elif len(moved_kwargs) < len(moved_inputs): + moved_kwargs.extend([{} for _ in range(len(moved_inputs) - len(moved_kwargs))]) + return tuple(moved_inputs), tuple(moved_kwargs) + + +def _verify_param_shape_across_processes( + process_group: dist.ProcessGroup, + tensors: List[core.Tensor], + logger: Optional["dist.Logger"] = None, +): + return dist._verify_params_across_processes(process_group, tensors, logger) + + +def _sync_module_states( + module: nn.Module, + process_group: dist.ProcessGroup, + broadcast_bucket_size: int, + src: int, + params_and_buffers_to_ignore: Container[str], + broadcast_buffers: bool = True, +) -> None: + """ + Sync ``module``'s parameters and buffers state. + + Syncs ``module``'s parameters and buffers state so that all ranks contain + the same module state across all ranks. Note that this API assumes that all + parameter shapes are consistent before running the synchronization. This can + be checked with ``_verify_param_shape_across_processes``. + """ + module_states: List[core.Tensor] = [] + for name, param in module.named_parameters(): + if name not in params_and_buffers_to_ignore: + module_states.append(param.detach()) + + if broadcast_buffers: + for name, buffer in module.named_buffers(): + if name not in params_and_buffers_to_ignore: + module_states.append(buffer.detach()) + + _sync_params_and_buffers(process_group, module_states, broadcast_bucket_size, src) + + +def _sync_params_and_buffers( + process_group: dist.ProcessGroup, + module_states: List[core.Tensor], + broadcast_bucket_size: int, + src: int, +) -> None: + """Synchronize ``module_states`` (list of tensors) across all processes by broadcasting them from rank 0.""" + if len(module_states) > 0: + dist._broadcast_coalesced( + process_group, module_states, broadcast_bucket_size, src + ) + + +def _replace_by_prefix( + state_dict: Dict[str, Any], + old_prefix: str, + new_prefix: str, +) -> None: + """ + Replace all keys that match a given old_prefix with a new_prefix (in-place). + + Usage:: + + state_dict = {"layer.xyz": core.tensor(1)} + replace_by_prefix_(state_dict, "layer.", "module.layer.") + assert state_dict == {"module.layer.xyz": core.tensor(1)} + """ + if old_prefix == new_prefix: + raise ValueError("old_prefix and new_prefix must be distinct") + for key in list(state_dict.keys()): + if not key.startswith(old_prefix): + continue + new_key = new_prefix + key[len(old_prefix) :] + state_dict[new_key] = state_dict[key] + del state_dict[key] + + +def _data_ptr_allocated(tensor: core.Tensor) -> bool: + return tensor.untyped_storage().data_ptr() > 0 + + +def _get_root_modules(modules: List[nn.Module]) -> List[nn.Module]: + """ + Returns the modules in ``modules`` that are root modules (i.e. + parent-less) with respect to the set ``modules``. In other words, these + are the modules in ``modules`` that are the not child of any other + module in ``modules``. + """ + root_modules: List[nn.Module] = [] + module_to_modules: Dict[nn.Module, Set[nn.Module]] = { + module: set(module.modules()) for module in modules + } + for candidate_module in modules: + is_root_module = True + for module, _modules in module_to_modules.items(): + is_child_module = ( + candidate_module is not module and candidate_module in _modules + ) + if is_child_module: + is_root_module = False + break + if is_root_module: + root_modules.append(candidate_module) + return root_modules diff --git a/mindnlp/core/distributions/__init__.py b/mindnlp/core/distributions/__init__.py new file mode 100644 index 000000000..f012e1020 --- /dev/null +++ b/mindnlp/core/distributions/__init__.py @@ -0,0 +1,12 @@ +"""distributions""" +from .bernoulli import Bernoulli +from .categorical import Categorical +from .distribution import Distribution +from .independent import Independent +from .negative_binomial import NegativeBinomial +from .normal import Normal +from .studentT import StudentT +from .transformed_distribution import TransformedDistribution +from .transforms import * +from .relaxed_categorical import * +from .relaxed_bernoulli import * diff --git a/mindnlp/core/distributions/bernoulli.py b/mindnlp/core/distributions/bernoulli.py new file mode 100644 index 000000000..8e333e62f --- /dev/null +++ b/mindnlp/core/distributions/bernoulli.py @@ -0,0 +1,134 @@ +"""bernoulli""" +# mypy: allow-untyped-defs +# pylint: disable=method-hidden +from numbers import Number + +from math import nan +from . import constraints +from .exp_family import ExponentialFamily +from .utils import ( + broadcast_all, + lazy_property, + logits_to_probs, + probs_to_logits, +) +from .. import ops +from ..autograd import no_grad +from ..nn.functional import binary_cross_entropy_with_logits + + +__all__ = ["Bernoulli"] + + +class Bernoulli(ExponentialFamily): + r""" + Creates a Bernoulli distribution parameterized by :attr:`probs` + or :attr:`logits` (but not both). + + Samples are binary (0 or 1). They take the value `1` with probability `p` + and `0` with probability `1 - p`. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Bernoulli(core.tensor([0.3])) + >>> m.sample() # 30% chance 1; 70% chance 0 + tensor([ 0.]) + + Args: + probs (Number, Tensor): the probability of sampling `1` + logits (Number, Tensor): the log-odds of sampling `1` + """ + arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} + support = constraints.boolean + has_enumerate_support = True + _mean_carrier_measure = 0 + + def __init__(self, probs=None, logits=None, validate_args=None): + if (probs is None) == (logits is None): + raise ValueError( + "Either `probs` or `logits` must be specified, but not both." + ) + if probs is not None: + is_scalar = isinstance(probs, Number) + (self.probs,) = broadcast_all(probs) + else: + is_scalar = isinstance(logits, Number) + (self.logits,) = broadcast_all(logits) + self._param = self.probs if probs is not None else self.logits + if is_scalar: + batch_shape = () + else: + batch_shape = self._param.shape + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Bernoulli, _instance) + if "probs" in self.__dict__: + new.probs = self.probs.broadcast_to(batch_shape) + new._param = new.probs + if "logits" in self.__dict__: + new.logits = self.logits.broadcast_to(batch_shape) + new._param = new.logits + super(Bernoulli, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def _new(self, *args, **kwargs): + return self._param.new(*args, **kwargs) + + @property + def mean(self): + return self.probs + + @property + def mode(self): + mode = (self.probs >= 0.5).to(self.probs) + mode[self.probs == 0.5] = nan + return mode + + @property + def variance(self): + return self.probs * (1 - self.probs) + + @lazy_property + def logits(self): + return probs_to_logits(self.probs, is_binary=True) + + @lazy_property + def probs(self): + return logits_to_probs(self.logits, is_binary=True) + + @property + def param_shape(self): + return self._param.shape + + def sample(self, sample_shape=()): + shape = self._extended_shape(sample_shape) + with no_grad(): + return ops.bernoulli(self.probs.broadcast_to(shape)) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + logits, value = broadcast_all(self.logits, value) + return -binary_cross_entropy_with_logits(logits, value, reduction="none") + + def entropy(self): + return binary_cross_entropy_with_logits( + self.logits, self.probs, reduction="none" + ) + + def enumerate_support(self, expand=True): + values = ops.arange(2, dtype=self._param.dtype) + values = values.view((-1,) + (1,) * len(self._batch_shape)) + if expand: + values = values.broadcast_to((-1,) + self._batch_shape) + return values + + @property + def _natural_params(self): + return (ops.logit(self.probs),) + + def _log_normalizer(self, x): + return ops.log1p(ops.exp(x)) diff --git a/mindnlp/core/distributions/categorical.py b/mindnlp/core/distributions/categorical.py new file mode 100644 index 000000000..549cd0799 --- /dev/null +++ b/mindnlp/core/distributions/categorical.py @@ -0,0 +1,157 @@ +"""categorical""" +# mypy: allow-untyped-defs +# pylint: disable=method-hidden, invalid-overridden-method +from math import nan +import mindspore +from .. import ops +from . import constraints +from .distribution import Distribution +from .utils import lazy_property, logits_to_probs, probs_to_logits + + +__all__ = ["Categorical"] + + +class Categorical(Distribution): + r""" + Creates a categorical distribution parameterized by either :attr:`probs` or + :attr:`logits` (but not both). + + .. note:: + It is equivalent to the distribution that :func:`core.multinomial` + samples from. + + Samples are integers from :math:`\{0, \ldots, K-1\}` where `K` is ``probs.size(-1)``. + + If `probs` is 1-dimensional with length-`K`, each element is the relative probability + of sampling the class at that index. + + If `probs` is N-dimensional, the first N-1 dimensions are treated as a batch of + relative probability vectors. + + .. note:: The `probs` argument must be non-negative, finite and have a non-zero sum, + and it will be normalized to sum to 1 along the last dimension. :attr:`probs` + will return this normalized value. + The `logits` argument will be interpreted as unnormalized log probabilities + and can therefore be any real number. It will likewise be normalized so that + the resulting probabilities sum to 1 along the last dimension. :attr:`logits` + will return this normalized value. + + See also: :func:`core.multinomial` + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Categorical(core.tensor([ 0.25, 0.25, 0.25, 0.25 ])) + >>> m.sample() # equal probability of 0, 1, 2, 3 + tensor(3) + + Args: + probs (Tensor): event probabilities + logits (Tensor): event log probabilities (unnormalized) + """ + arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} + has_enumerate_support = True + + def __init__(self, probs=None, logits=None, validate_args=None): + if (probs is None) == (logits is None): + raise ValueError( + "Either `probs` or `logits` must be specified, but not both." + ) + if probs is not None: + if probs.ndim < 1: + raise ValueError("`probs` parameter must be at least one-dimensional.") + self.probs = probs / probs.sum(-1, keepdim=True) + else: + if logits.ndim < 1: + raise ValueError("`logits` parameter must be at least one-dimensional.") + # Normalize + self.logits = logits - ops.logsumexp(logits, dim=-1, keepdim=True) + self._param = self.probs if probs is not None else self.logits + self._num_events = self._param.shape[-1] + batch_shape = ( + self._param.shape[:-1] if self._param.ndimension() > 1 else () + ) + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Categorical, _instance) + param_shape = batch_shape + (self._num_events,) + if "probs" in self.__dict__: + new.probs = self.probs.broadcast_to(param_shape) + new._param = new.probs + if "logits" in self.__dict__: + new.logits = self.logits.broadcast_to(param_shape) + new._param = new.logits + new._num_events = self._num_events + super(Categorical, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def _new(self, *args, **kwargs): + return self._param.new(*args, **kwargs) + + @constraints.dependent_property(is_discrete=True, event_dim=0) + def support(self): + return constraints.integer_interval(0, self._num_events - 1) + + @lazy_property + def logits(self): + return probs_to_logits(self.probs) + + @lazy_property + def probs(self): + return logits_to_probs(self.logits) + + @property + def param_shape(self): + return self._param.shape + + @property + def mean(self): + return ops.full( + self._extended_shape(), + nan, + dtype=self.probs.dtype, + ) + + @property + def mode(self): + return self.probs.argmax(axis=-1) + + @property + def variance(self): + return ops.full( + self._extended_shape(), + nan, + dtype=self.probs.dtype, + ) + + def sample(self, sample_shape=()): + if not isinstance(sample_shape, tuple): + sample_shape = tuple(sample_shape) + probs_2d = self.probs.reshape(-1, self._num_events) + samples_2d = ops.multinomial(probs_2d, sample_shape.numel(), True).T + return samples_2d.reshape(self._extended_shape(sample_shape)) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + value = value.long().unsqueeze(-1) + value, log_pmf = ops.broadcast_tensors(value, self.logits) + value = value[..., :1] + return ops.gather(log_pmf, -1, value).squeeze(-1) + + def entropy(self): + min_real = ops.finfo(self.logits.dtype).min + logits = ops.clamp(self.logits, min=min_real) + p_log_p = logits * self.probs + return -p_log_p.sum(-1) + + def enumerate_support(self, expand=True): + num_events = self._num_events + values = ops.arange(num_events, dtype=mindspore.int64) + values = values.view((-1,) + (1,) * len(self._batch_shape)) + if expand: + values = values.broadcast_to((-1,) + self._batch_shape) + return values diff --git a/mindnlp/core/distributions/chi2.py b/mindnlp/core/distributions/chi2.py new file mode 100644 index 000000000..e3aa5d333 --- /dev/null +++ b/mindnlp/core/distributions/chi2.py @@ -0,0 +1,36 @@ +"""chi2""" +# mypy: allow-untyped-defs +from . import constraints +from .gamma import Gamma + + +__all__ = ["Chi2"] + + +class Chi2(Gamma): + r""" + Creates a Chi-squared distribution parameterized by shape parameter :attr:`df`. + This is exactly equivalent to ``Gamma(alpha=0.5*df, beta=0.5)`` + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Chi2(core.tensor([1.0])) + >>> m.sample() # Chi2 distributed with shape df=1 + tensor([ 0.1046]) + + Args: + df (float or Tensor): shape parameter of the distribution + """ + arg_constraints = {"df": constraints.positive} + + def __init__(self, df, validate_args=None): + super().__init__(0.5 * df, 0.5, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Chi2, _instance) + return super().expand(batch_shape, new) + + @property + def df(self): + return self.concentration * 2 diff --git a/mindnlp/core/distributions/constraints.py b/mindnlp/core/distributions/constraints.py new file mode 100644 index 000000000..460e83182 --- /dev/null +++ b/mindnlp/core/distributions/constraints.py @@ -0,0 +1,671 @@ +# mypy: allow-untyped-defs +r""" +The following constraints are implemented: + +- ``constraints.boolean`` +- ``constraints.cat`` +- ``constraints.corr_cholesky`` +- ``constraints.dependent`` +- ``constraints.greater_than(lower_bound)`` +- ``constraints.greater_than_eq(lower_bound)`` +- ``constraints.independent(constraint, reinterpreted_batch_ndims)`` +- ``constraints.integer_interval(lower_bound, upper_bound)`` +- ``constraints.interval(lower_bound, upper_bound)`` +- ``constraints.less_than(upper_bound)`` +- ``constraints.lower_cholesky`` +- ``constraints.lower_triangular`` +- ``constraints.multinomial`` +- ``constraints.nonnegative`` +- ``constraints.nonnegative_integer`` +- ``constraints.one_hot`` +- ``constraints.positive_integer`` +- ``constraints.positive`` +- ``constraints.positive_semidefinite`` +- ``constraints.positive_definite`` +- ``constraints.real_vector`` +- ``constraints.real`` +- ``constraints.simplex`` +- ``constraints.symmetric`` +- ``constraints.stack`` +- ``constraints.square`` +- ``constraints.symmetric`` +- ``constraints.unit_interval`` +""" + +import mindspore +from .. import ops + + +__all__ = [ + "Constraint", + "boolean", + "cat", + "corr_cholesky", + "dependent", + "dependent_property", + "greater_than", + "greater_than_eq", + "independent", + "integer_interval", + "interval", + "half_open_interval", + "is_dependent", + "less_than", + "lower_cholesky", + "lower_triangular", + "multinomial", + "nonnegative", + "nonnegative_integer", + "one_hot", + "positive", + "positive_semidefinite", + "positive_definite", + "positive_integer", + "real", + "real_vector", + "simplex", + "square", + "stack", + "symmetric", + "unit_interval", +] + + +class Constraint: + """ + Abstract base class for constraints. + + A constraint object represents a region over which a variable is valid, + e.g. within which a variable can be optimized. + + Attributes: + is_discrete (bool): Whether constrained space is discrete. + Defaults to False. + event_dim (int): Number of rightmost dimensions that together define + an event. The :meth:`check` method will remove this many dimensions + when computing validity. + """ + + is_discrete = False # Default to continuous. + event_dim = 0 # Default to univariate. + + def check(self, value): + """ + Returns a byte tensor of ``sample_shape + batch_shape`` indicating + whether each event in value satisfies this constraint. + """ + raise NotImplementedError + + def __repr__(self): + return self.__class__.__name__[1:] + "()" + + +class _Dependent(Constraint): + """ + Placeholder for variables whose support depends on other variables. + These variables obey no simple coordinate-wise constraints. + + Args: + is_discrete (bool): Optional value of ``.is_discrete`` in case this + can be computed statically. If not provided, access to the + ``.is_discrete`` attribute will raise a NotImplementedError. + event_dim (int): Optional value of ``.event_dim`` in case this + can be computed statically. If not provided, access to the + ``.event_dim`` attribute will raise a NotImplementedError. + """ + + def __init__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented): + self._is_discrete = is_discrete + self._event_dim = event_dim + super().__init__() + + @property + def is_discrete(self): + if self._is_discrete is NotImplemented: + raise NotImplementedError(".is_discrete cannot be determined statically") + return self._is_discrete + + @property + def event_dim(self): + if self._event_dim is NotImplemented: + raise NotImplementedError(".event_dim cannot be determined statically") + return self._event_dim + + def __call__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented): + """ + Support for syntax to customize static attributes:: + + constraints.dependent(is_discrete=True, event_dim=1) + """ + if is_discrete is NotImplemented: + is_discrete = self._is_discrete + if event_dim is NotImplemented: + event_dim = self._event_dim + return _Dependent(is_discrete=is_discrete, event_dim=event_dim) + + def check(self, x): + raise ValueError("Cannot determine validity of dependent constraint") + + +def is_dependent(constraint): + """ + Checks if ``constraint`` is a ``_Dependent`` object. + + Args: + constraint : A ``Constraint`` object. + + Returns: + ``bool``: True if ``constraint`` can be refined to the type ``_Dependent``, False otherwise. + + """ + return isinstance(constraint, _Dependent) + + +class _DependentProperty(property, _Dependent): + """ + Decorator that extends @property to act like a `Dependent` constraint when + called on a class and act like a property when called on an object. + + Example:: + + class Uniform(Distribution): + def __init__(self, low, high): + self.low = low + self.high = high + @constraints.dependent_property(is_discrete=False, event_dim=0) + def support(self): + return constraints.interval(self.low, self.high) + + Args: + fn (Callable): The function to be decorated. + is_discrete (bool): Optional value of ``.is_discrete`` in case this + can be computed statically. If not provided, access to the + ``.is_discrete`` attribute will raise a NotImplementedError. + event_dim (int): Optional value of ``.event_dim`` in case this + can be computed statically. If not provided, access to the + ``.event_dim`` attribute will raise a NotImplementedError. + """ + + def __init__( + self, fn=None, *, is_discrete=NotImplemented, event_dim=NotImplemented + ): + super().__init__(fn) + self._is_discrete = is_discrete + self._event_dim = event_dim + + def __call__(self, fn): + """ + Support for syntax to customize static attributes:: + + @constraints.dependent_property(is_discrete=True, event_dim=1) + def support(self): + ... + """ + return _DependentProperty( + fn, is_discrete=self._is_discrete, event_dim=self._event_dim + ) + + +class _IndependentConstraint(Constraint): + """ + Wraps a constraint by aggregating over ``reinterpreted_batch_ndims``-many + dims in :meth:`check`, so that an event is valid only if all its + independent entries are valid. + """ + + def __init__(self, base_constraint, reinterpreted_batch_ndims): + assert isinstance(base_constraint, Constraint) + assert isinstance(reinterpreted_batch_ndims, int) + assert reinterpreted_batch_ndims >= 0 + self.base_constraint = base_constraint + self.reinterpreted_batch_ndims = reinterpreted_batch_ndims + super().__init__() + + @property + def is_discrete(self): + return self.base_constraint.is_discrete + + @property + def event_dim(self): + return self.base_constraint.event_dim + self.reinterpreted_batch_ndims + + def check(self, value): + result = self.base_constraint.check(value) + if result.ndim < self.reinterpreted_batch_ndims: + expected = self.base_constraint.event_dim + self.reinterpreted_batch_ndims + raise ValueError( + f"Expected value.ndim >= {expected} but got {value.ndim}" + ) + result = result.reshape( + result.shape[: result.ndim - self.reinterpreted_batch_ndims] + (-1,) + ) + result = result.all(-1) + return result + + def __repr__(self): + return f"{self.__class__.__name__[1:]}({repr(self.base_constraint)}, {self.reinterpreted_batch_ndims})" + + +class _Boolean(Constraint): + """ + Constrain to the two values `{0, 1}`. + """ + + is_discrete = True + + def check(self, value): + out = ((value == 0).int() | (value == 1).int()).bool() + return out + + +class _OneHot(Constraint): + """ + Constrain to one-hot vectors. + """ + + is_discrete = True + event_dim = 1 + + def check(self, value): + is_boolean = (value == 0) | (value == 1) + is_normalized = value.sum(-1).eq(1) + return is_boolean.all(-1) & is_normalized + + +class _IntegerInterval(Constraint): + """ + Constrain to an integer interval `[lower_bound, upper_bound]`. + """ + + is_discrete = True + + def __init__(self, lower_bound, upper_bound): + self.lower_bound = lower_bound + self.upper_bound = upper_bound + super().__init__() + + def check(self, value): + return ( + (value % 1 == 0).int() & (self.lower_bound <= value).int() & (value <= self.upper_bound).int() + ).bool() + + def __repr__(self): + fmt_string = self.__class__.__name__[1:] + fmt_string += ( + f"(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})" + ) + return fmt_string + + +class _IntegerLessThan(Constraint): + """ + Constrain to an integer interval `(-inf, upper_bound]`. + """ + + is_discrete = True + + def __init__(self, upper_bound): + self.upper_bound = upper_bound + super().__init__() + + def check(self, value): + return (value % 1 == 0) & (value <= self.upper_bound) + + def __repr__(self): + fmt_string = self.__class__.__name__[1:] + fmt_string += f"(upper_bound={self.upper_bound})" + return fmt_string + + +class _IntegerGreaterThan(Constraint): + """ + Constrain to an integer interval `[lower_bound, inf)`. + """ + + is_discrete = True + + def __init__(self, lower_bound): + self.lower_bound = lower_bound + super().__init__() + + def check(self, value): + return (value % 1 == 0) & (value >= self.lower_bound) + + def __repr__(self): + fmt_string = self.__class__.__name__[1:] + fmt_string += f"(lower_bound={self.lower_bound})" + return fmt_string + + +class _Real(Constraint): + """ + Trivially constrain to the extended real line `[-inf, inf]`. + """ + + def check(self, value): + # False for NANs. + return value == value # pylint: disable=comparison-with-itself + + +class _GreaterThan(Constraint): + """ + Constrain to a real half line `(lower_bound, inf]`. + """ + + def __init__(self, lower_bound): + self.lower_bound = lower_bound + super().__init__() + + def check(self, value): + return self.lower_bound < value + + def __repr__(self): + fmt_string = self.__class__.__name__[1:] + fmt_string += f"(lower_bound={self.lower_bound})" + return fmt_string + + +class _GreaterThanEq(Constraint): + """ + Constrain to a real half line `[lower_bound, inf)`. + """ + + def __init__(self, lower_bound): + self.lower_bound = lower_bound + super().__init__() + + def check(self, value): + return self.lower_bound <= value + + def __repr__(self): + fmt_string = self.__class__.__name__[1:] + fmt_string += f"(lower_bound={self.lower_bound})" + return fmt_string + + +class _LessThan(Constraint): + """ + Constrain to a real half line `[-inf, upper_bound)`. + """ + + def __init__(self, upper_bound): + self.upper_bound = upper_bound + super().__init__() + + def check(self, value): + return value < self.upper_bound + + def __repr__(self): + fmt_string = self.__class__.__name__[1:] + fmt_string += f"(upper_bound={self.upper_bound})" + return fmt_string + + +class _Interval(Constraint): + """ + Constrain to a real interval `[lower_bound, upper_bound]`. + """ + + def __init__(self, lower_bound, upper_bound): + self.lower_bound = lower_bound + self.upper_bound = upper_bound + super().__init__() + + def check(self, value): + return (self.lower_bound <= value) & (value <= self.upper_bound) + + def __repr__(self): + fmt_string = self.__class__.__name__[1:] + fmt_string += ( + f"(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})" + ) + return fmt_string + + +class _HalfOpenInterval(Constraint): + """ + Constrain to a real interval `[lower_bound, upper_bound)`. + """ + + def __init__(self, lower_bound, upper_bound): + self.lower_bound = lower_bound + self.upper_bound = upper_bound + super().__init__() + + def check(self, value): + return (self.lower_bound <= value) & (value < self.upper_bound) + + def __repr__(self): + fmt_string = self.__class__.__name__[1:] + fmt_string += ( + f"(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})" + ) + return fmt_string + + +class _Simplex(Constraint): + """ + Constrain to the unit simplex in the innermost (rightmost) dimension. + Specifically: `x >= 0` and `x.sum(-1) == 1`. + """ + + event_dim = 1 + + def check(self, value): + return ops.all(value >= 0, dim=-1) & ((value.sum(-1) - 1).abs() < 1e-6) + + +class _Multinomial(Constraint): + """ + Constrain to nonnegative integer values summing to at most an upper bound. + + Note due to limitations of the Multinomial distribution, this currently + checks the weaker condition ``value.sum(-1) <= upper_bound``. In the future + this may be strengthened to ``value.sum(-1) == upper_bound``. + """ + + is_discrete = True + event_dim = 1 + + def __init__(self, upper_bound): + self.upper_bound = upper_bound + + def check(self, x): + return (x >= 0).all(dim=-1) & (x.sum(dim=-1) <= self.upper_bound) + + +class _LowerTriangular(Constraint): + """ + Constrain to lower-triangular square matrices. + """ + + event_dim = 2 + + def check(self, value): + value_tril = value.tril() + return (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0] + + +class _LowerCholesky(Constraint): + """ + Constrain to lower-triangular square matrices with positive diagonals. + """ + + event_dim = 2 + + def check(self, value): + value_tril = value.tril() + lower_triangular = ( + (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0] + ) + + positive_diagonal = (value.diagonal(dim1=-2, dim2=-1) > 0).min(-1)[0] + return lower_triangular & positive_diagonal + + +class _CorrCholesky(Constraint): + """ + Constrain to lower-triangular square matrices with positive diagonals and each + row vector being of unit length. + """ + + event_dim = 2 + + def check(self, value): + tol = ( + float(ops.finfo(value.dtype).eps) * value.shape[-1] * 10 + ) # 10 is an adjustable fudge factor + row_norm = ops.norm(value.detach(), p=2, dim=-1) + unit_row_norm = (row_norm - 1.0).abs().le(tol).all(dim=-1) + return _LowerCholesky().check(value) & unit_row_norm + + +class _Square(Constraint): + """ + Constrain to square matrices. + """ + + event_dim = 2 + + def check(self, value): + return ops.full( + size=value.shape[:-2], + fill_value=(value.shape[-2] == value.shape[-1]), + dtype=mindspore.bool_, + ) + + +class _Symmetric(_Square): + """ + Constrain to Symmetric square matrices. + """ + + def check(self, value): + square_check = super().check(value) + if not square_check.all(): + return square_check + return ops.isclose(value, value.mT, atol=1e-6).all(-2).all(-1) + + +class _PositiveSemidefinite(_Symmetric): + """ + Constrain to positive-semidefinite matrices. + """ + + def check(self, value): + sym_check = super().check(value) + if not sym_check.all(): + return sym_check + return ops.eigvalsh(value).ge(0).all(-1) + + +class _PositiveDefinite(_Symmetric): + """ + Constrain to positive-definite matrices. + """ + + def check(self, value): + sym_check = super().check(value) + if not sym_check.all(): + return sym_check + return ops.linalg.cholesky_ex(value).info.eq(0) + + +class _Cat(Constraint): + """ + Constraint functor that applies a sequence of constraints + `cseq` at the submatrices at dimension `dim`, + each of size `lengths[dim]`, in a way compatible with :func:`ops.cat`. + """ + + def __init__(self, cseq, dim=0, lengths=None): + assert all(isinstance(c, Constraint) for c in cseq) + self.cseq = list(cseq) + if lengths is None: + lengths = [1] * len(self.cseq) + self.lengths = list(lengths) + assert len(self.lengths) == len(self.cseq) + self.dim = dim + super().__init__() + + @property + def is_discrete(self): + return any(c.is_discrete for c in self.cseq) + + @property + def event_dim(self): + return max(c.event_dim for c in self.cseq) + + def check(self, value): + assert -value.ndim <= self.dim < value.ndim + checks = [] + start = 0 + for constr, length in zip(self.cseq, self.lengths): + v = value.narrow(self.dim, start, length) + checks.append(constr.check(v)) + start = start + length # avoid += for jit compat + return ops.cat(checks, self.dim) + + +class _Stack(Constraint): + """ + Constraint functor that applies a sequence of constraints + `cseq` at the submatrices at dimension `dim`, + in a way compatible with :func:`ops.stack`. + """ + + def __init__(self, cseq, dim=0): + assert all(isinstance(c, Constraint) for c in cseq) + self.cseq = list(cseq) + self.dim = dim + super().__init__() + + @property + def is_discrete(self): + return any(c.is_discrete for c in self.cseq) + + @property + def event_dim(self): + dim = max(c.event_dim for c in self.cseq) + if self.dim + dim < 0: + dim += 1 + return dim + + def check(self, value): + assert -value.ndim <= self.dim < value.ndim + vs = [value.select(self.dim, i) for i in range(value.shape[self.dim])] + return ops.stack( + [constr.check(v) for v, constr in zip(vs, self.cseq)], self.dim + ) + + +# Public interface. +dependent = _Dependent() +dependent_property = _DependentProperty +independent = _IndependentConstraint +boolean = _Boolean() +one_hot = _OneHot() +nonnegative_integer = _IntegerGreaterThan(0) +positive_integer = _IntegerGreaterThan(1) +integer_interval = _IntegerInterval +real = _Real() +real_vector = independent(real, 1) +positive = _GreaterThan(0.0) +nonnegative = _GreaterThanEq(0.0) +greater_than = _GreaterThan +greater_than_eq = _GreaterThanEq +less_than = _LessThan +multinomial = _Multinomial +unit_interval = _Interval(0.0, 1.0) +interval = _Interval +half_open_interval = _HalfOpenInterval +simplex = _Simplex() +lower_triangular = _LowerTriangular() +lower_cholesky = _LowerCholesky() +corr_cholesky = _CorrCholesky() +square = _Square() +symmetric = _Symmetric() +positive_semidefinite = _PositiveSemidefinite() +positive_definite = _PositiveDefinite() +cat = _Cat +stack = _Stack diff --git a/mindnlp/core/distributions/distribution.py b/mindnlp/core/distributions/distribution.py new file mode 100644 index 000000000..1d972ebfa --- /dev/null +++ b/mindnlp/core/distributions/distribution.py @@ -0,0 +1,368 @@ +"""distribution base class""" +# mypy: allow-untyped-defs +import warnings +import builtins +from typing import Any, Dict, Optional, Tuple, Union, List +from typing_extensions import deprecated + +import mindspore +from .. import ops +from ..autograd import no_grad + +from . import constraints +from .utils import lazy_property + + +__all__ = ["Distribution"] + +_size = Union[List[builtins.int], Tuple[builtins.int, ...]] + + +class Distribution: + r""" + Distribution is the abstract base class for probability distributions. + """ + + has_rsample = False + has_enumerate_support = False + _validate_args = __debug__ + + + @staticmethod + def set_default_validate_args(value: bool) -> None: + """ + Sets whether validation is enabled or disabled. + + The default behavior mimics Python's ``assert`` statement: validation + is on by default, but is disabled if Python is run in optimized mode + (via ``python -O``). Validation may be expensive, so you may want to + disable it once a model is working. + + Args: + value (bool): Whether to enable validation. + """ + if value not in [True, False]: + raise ValueError + Distribution._validate_args = value + + + def __init__( + self, + batch_shape: tuple = (), + event_shape: tuple = (), + validate_args: Optional[bool] = None, + ): + self._batch_shape = batch_shape + self._event_shape = event_shape + if validate_args is not None: + self._validate_args = validate_args + if self._validate_args: + try: + arg_constraints = self.arg_constraints + except NotImplementedError: + arg_constraints = {} + warnings.warn( + f"{self.__class__} does not define `arg_constraints`. " + + "Please set `arg_constraints = {}` or initialize the distribution " + + "with `validate_args=False` to turn off validation." + ) + for param, constraint in arg_constraints.items(): + if constraints.is_dependent(constraint): + continue # skip constraints that cannot be checked + if param not in self.__dict__ and isinstance( + getattr(type(self), param), lazy_property + ): + continue # skip checking lazily-constructed args + value = getattr(self, param) + valid = constraint.check(value) + if not valid.all(): + raise ValueError( + f"Expected parameter {param} " + f"({type(value).__name__} of shape {tuple(value.shape)}) " + f"of distribution {repr(self)} " + f"to satisfy the constraint {repr(constraint)}, " + f"but found invalid values:\n{value}" + ) + super().__init__() + + + def expand(self, batch_shape: tuple, _instance=None): + """ + Returns a new distribution instance (or populates an existing instance + provided by a derived class) with batch dimensions expanded to + `batch_shape`. This method calls :class:`~mindspore.Tensor.expand` on + the distribution's parameters. As such, this does not allocate new + memory for the expanded distribution instance. Additionally, + this does not repeat any args checking or parameter broadcasting in + `__init__.py`, when an instance is first created. + + Args: + batch_shape (tuple): the desired expanded size. + _instance: new instance provided by subclasses that + need to override `.expand`. + + Returns: + New distribution instance with batch dimensions expanded to + `batch_size`. + """ + raise NotImplementedError + + + @property + def batch_shape(self) -> tuple: + """ + Returns the shape over which parameters are batched. + """ + return self._batch_shape + + @property + def event_shape(self) -> tuple: + """ + Returns the shape of a single sample (without batching). + """ + return self._event_shape + + @property + def arg_constraints(self) -> Dict[str, constraints.Constraint]: + """ + Returns a dictionary from argument names to + :class:`~core.distributions.constraints.Constraint` objects that + should be satisfied by each argument of this distribution. Args that + are not tensors need not appear in this dict. + """ + raise NotImplementedError + + @property + def support(self) -> Optional[Any]: + """ + Returns a :class:`~core.distributions.constraints.Constraint` object + representing this distribution's support. + """ + raise NotImplementedError + + @property + def mean(self) -> mindspore.Tensor: + """ + Returns the mean of the distribution. + """ + raise NotImplementedError + + @property + def mode(self) -> mindspore.Tensor: + """ + Returns the mode of the distribution. + """ + raise NotImplementedError(f"{self.__class__} does not implement mode") + + @property + def variance(self) -> mindspore.Tensor: + """ + Returns the variance of the distribution. + """ + raise NotImplementedError + + @property + def stddev(self) -> mindspore.Tensor: + """ + Returns the standard deviation of the distribution. + """ + return self.variance.sqrt() + + + def sample(self, sample_shape: tuple = ()) -> mindspore.Tensor: + """ + Generates a sample_shape shaped sample or sample_shape shaped batch of + samples if the distribution parameters are batched. + """ + with no_grad(): + return self.rsample(sample_shape) + + + + def rsample(self, sample_shape: tuple = ()) -> mindspore.Tensor: + """ + Generates a sample_shape shaped reparameterized sample or sample_shape + shaped batch of reparameterized samples if the distribution parameters + are batched. + """ + raise NotImplementedError + + + + @deprecated( + "`sample_n(n)` will be deprecated. Use `sample((n,))` instead.", + category=FutureWarning, + ) + def sample_n(self, n: int) -> mindspore.Tensor: + """ + Generates n samples or n batches of samples if the distribution + parameters are batched. + """ + return self.sample(tuple((n,))) + + + + def log_prob(self, value: mindspore.Tensor) -> mindspore.Tensor: + """ + Returns the log of the probability density/mass function evaluated at + `value`. + + Args: + value (Tensor): + """ + raise NotImplementedError + + + + def cdf(self, value: mindspore.Tensor) -> mindspore.Tensor: + """ + Returns the cumulative density/mass function evaluated at + `value`. + + Args: + value (Tensor): + """ + raise NotImplementedError + + + + def icdf(self, value: mindspore.Tensor) -> mindspore.Tensor: + """ + Returns the inverse cumulative density/mass function evaluated at + `value`. + + Args: + value (Tensor): + """ + raise NotImplementedError + + + + def enumerate_support(self, expand: bool = True) -> mindspore.Tensor: + """ + Returns tensor containing all values supported by a discrete + distribution. The result will enumerate over dimension 0, so the shape + of the result will be `(cardinality,) + batch_shape + event_shape` + (where `event_shape = ()` for univariate distributions). + + Note that this enumerates over all batched tensors in lock-step + `[[0, 0], [1, 1], ...]`. With `expand=False`, enumeration happens + along dim 0, but with the remaining batch dimensions being + singleton dimensions, `[[0], [1], ..`. + + To iterate over the full Cartesian product use + `itertools.product(m.enumerate_support())`. + + Args: + expand (bool): whether to expand the support over the + batch dims to match the distribution's `batch_shape`. + + Returns: + Tensor iterating over dimension 0. + """ + raise NotImplementedError + + + + def entropy(self) -> mindspore.Tensor: + """ + Returns entropy of distribution, batched over batch_shape. + + Returns: + Tensor of shape batch_shape. + """ + raise NotImplementedError + + + + def perplexity(self) -> mindspore.Tensor: + """ + Returns perplexity of distribution, batched over batch_shape. + + Returns: + Tensor of shape batch_shape. + """ + return ops.exp(self.entropy()) + + + def _extended_shape(self, sample_shape: _size = ()) -> Tuple[int, ...]: + """ + Returns the size of the sample returned by the distribution, given + a `sample_shape`. Note, that the batch and event shapes of a distribution + instance are fixed at the time of construction. If this is empty, the + returned shape is upcast to (1,). + + Args: + sample_shape (tuple): the size of the sample to be drawn. + """ + if not isinstance(sample_shape, tuple): + sample_shape = tuple(sample_shape) + return tuple(sample_shape + self._batch_shape + self._event_shape) + + def _validate_sample(self, value: mindspore.Tensor) -> None: + """ + Argument validation for distribution methods such as `log_prob`, + `cdf` and `icdf`. The rightmost dimensions of a value to be + scored via these methods must agree with the distribution's batch + and event shapes. + + Args: + value (Tensor): the tensor whose log probability is to be + computed by the `log_prob` method. + Raises + ValueError: when the rightmost dimensions of `value` do not match the + distribution's batch and event shapes. + """ + if not isinstance(value, mindspore.Tensor): + raise ValueError("The value argument to log_prob must be a Tensor") + + event_dim_start = len(value.shape) - len(self._event_shape) + if value.shape[event_dim_start:] != self._event_shape: + raise ValueError( + f"The right-most size of value must match event_shape: {value.shape} vs {self._event_shape}." + ) + + actual_shape = value.shape + expected_shape = self._batch_shape + self._event_shape + for i, j in zip(reversed(actual_shape), reversed(expected_shape)): + if i != 1 and j != 1 and i != j: + raise ValueError( + f"Value is not broadcastable with batch_shape+event_shape: {actual_shape} vs {expected_shape}." + ) + try: + support = self.support + except NotImplementedError: + warnings.warn( + f"{self.__class__} does not define `support` to enable " + + "sample validation. Please initialize the distribution with " + + "`validate_args=False` to turn off validation." + ) + return + assert support is not None + valid = support.check(value) + if not valid.all(): + raise ValueError( + "Expected value argument " + f"({type(value).__name__} of shape {tuple(value.shape)}) " + f"to be within the support ({repr(support)}) " + f"of the distribution {repr(self)}, " + f"but found invalid values:\n{value}" + ) + + def _get_checked_instance(self, cls, _instance=None): + if _instance is None and type(self).__init__ != cls.__init__: + raise NotImplementedError( + f"Subclass {self.__class__.__name__} of {cls.__name__} that defines a custom __init__ method " + "must also define a custom .expand() method." + ) + return self.__new__(type(self)) if _instance is None else _instance # pylint: disable=no-value-for-parameter + + def __repr__(self) -> str: + param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__] + args_string = ", ".join( + [ + f"{p}: {self.__dict__[p] if self.__dict__[p].numel() == 1 else self.__dict__[p].shape}" + for p in param_names + ] + ) + return self.__class__.__name__ + "(" + args_string + ")" diff --git a/mindnlp/core/distributions/exp_family.py b/mindnlp/core/distributions/exp_family.py new file mode 100644 index 000000000..d8d79676a --- /dev/null +++ b/mindnlp/core/distributions/exp_family.py @@ -0,0 +1,65 @@ +"""exp family""" +# mypy: allow-untyped-defs +import mindspore +from .distribution import Distribution + + +__all__ = ["ExponentialFamily"] + + +class ExponentialFamily(Distribution): + r""" + ExponentialFamily is the abstract base class for probability distributions belonging to an + exponential family, whose probability mass/density function has the form is defined below + + .. math:: + + p_{F}(x; \theta) = \exp(\langle t(x), \theta\rangle - F(\theta) + k(x)) + + where :math:`\theta` denotes the natural parameters, :math:`t(x)` denotes the sufficient statistic, + :math:`F(\theta)` is the log normalizer function for a given family and :math:`k(x)` is the carrier + measure. + + Note: + This class is an intermediary between the `Distribution` class and distributions which belong + to an exponential family mainly to check the correctness of the `.entropy()` and analytic KL + divergence methods. We use this class to compute the entropy and KL divergence using the AD + framework and Bregman divergences (courtesy of: Frank Nielsen and Richard Nock, Entropies and + Cross-entropies of Exponential Families). + """ + + @property + def _natural_params(self): + """ + Abstract method for natural parameters. Returns a tuple of Tensors based + on the distribution + """ + raise NotImplementedError + + def _log_normalizer(self, *natural_params): + """ + Abstract method for log normalizer function. Returns a log normalizer based on + the distribution and input + """ + raise NotImplementedError + + @property + def _mean_carrier_measure(self): + """ + Abstract method for expected carrier measure, which is required for computing + entropy. + """ + raise NotImplementedError + + def entropy(self): + """ + Method to compute the entropy using Bregman divergence of the log normalizer. + """ + result = -self._mean_carrier_measure + nparams = [p.requires_grad_() for p in self._natural_params] + lg_normal = self._log_normalizer(*nparams) + gradients = mindspore.grad(lg_normal.sum(), nparams) + result += lg_normal + for np, g in zip(nparams, gradients): + result -= (np * g).reshape(self._batch_shape + (-1,)).sum(-1) + return result diff --git a/mindnlp/core/distributions/gamma.py b/mindnlp/core/distributions/gamma.py new file mode 100644 index 000000000..4e7d6f998 --- /dev/null +++ b/mindnlp/core/distributions/gamma.py @@ -0,0 +1,113 @@ +"""gamma""" +# mypy: allow-untyped-defs +from numbers import Number + +from .. import ops +from . import constraints +from .exp_family import ExponentialFamily +from .utils import broadcast_all + + +__all__ = ["Gamma"] + + +class Gamma(ExponentialFamily): + r""" + Creates a Gamma distribution parameterized by shape :attr:`concentration` and :attr:`rate`. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Gamma(core.tensor([1.0]), core.tensor([1.0])) + >>> m.sample() # Gamma distributed with concentration=1 and rate=1 + tensor([ 0.1046]) + + Args: + concentration (float or Tensor): shape parameter of the distribution + (often referred to as alpha) + rate (float or Tensor): rate = 1 / scale of the distribution + (often referred to as beta) + """ + arg_constraints = { + "concentration": constraints.positive, + "rate": constraints.positive, + } + support = constraints.nonnegative + has_rsample = True + _mean_carrier_measure = 0 + + @property + def mean(self): + return self.concentration / self.rate + + @property + def mode(self): + return ((self.concentration - 1) / self.rate).clamp(min=0) + + @property + def variance(self): + return self.concentration / self.rate.pow(2) + + def __init__(self, concentration, rate, validate_args=None): + self.concentration, self.rate = broadcast_all(concentration, rate) + if isinstance(concentration, Number) and isinstance(rate, Number): + batch_shape = () + else: + batch_shape = self.concentration.shape + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Gamma, _instance) + new.concentration = self.concentration.broadcast_to(batch_shape) + new.rate = self.rate.broadcast_to(batch_shape) + super(Gamma, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def rsample(self, sample_shape=()): + shape = self._extended_shape(sample_shape) + + if shape == (): # pylint: disable=use-implicit-booleaness-not-comparison + sample_shape = (1,) + else: + sample_shape = shape + value = ops.gamma(sample_shape, self.concentration, self.rate) + + if shape == (): # pylint: disable=use-implicit-booleaness-not-comparison + value = ops.squeeze(value) + + value = value.clamp( + min=float(ops.finfo(value.dtype).tiny) + ) # do not record in autograd graph + return value + + def log_prob(self, value): + value = ops.as_tensor(value, dtype=self.rate.dtype) + if self._validate_args: + self._validate_sample(value) + return ( + ops.xlogy(self.concentration, self.rate) + + ops.xlogy(self.concentration - 1, value) + - self.rate * value + - ops.lgamma(self.concentration) + ) + + def entropy(self): + return ( + self.concentration + - ops.log(self.rate) + + ops.lgamma(self.concentration) + + (1.0 - self.concentration) * ops.digamma(self.concentration) + ) + + @property + def _natural_params(self): + return (self.concentration - 1, -self.rate) + + def _log_normalizer(self, x, y): + return ops.lgamma(x + 1) + (x + 1) * ops.log(-y.reciprocal()) + + def cdf(self, value): + if self._validate_args: + self._validate_sample(value) + return ops.igamma(self.concentration, self.rate * value) diff --git a/mindnlp/core/distributions/independent.py b/mindnlp/core/distributions/independent.py new file mode 100644 index 000000000..a65512c55 --- /dev/null +++ b/mindnlp/core/distributions/independent.py @@ -0,0 +1,126 @@ +"""independent""" +# mypy: allow-untyped-defs +from typing import Dict + +from . import constraints +from .distribution import Distribution +from .utils import _sum_rightmost + + +__all__ = ["Independent"] + + +class Independent(Distribution): + r""" + Reinterprets some of the batch dims of a distribution as event dims. + + This is mainly useful for changing the shape of the result of + :meth:`log_prob`. For example to create a diagonal Normal distribution with + the same shape as a Multivariate Normal distribution (so they are + interchangeable), you can:: + + >>> from core.distributions.multivariate_normal import MultivariateNormal + >>> from core.distributions.normal import Normal + >>> loc = core.zeros(3) + >>> scale = core.ones(3) + >>> mvn = MultivariateNormal(loc, scale_tril=core.diag(scale)) + >>> [mvn.batch_shape, mvn.event_shape] + [core.Size([]), core.Size([3])] + >>> normal = Normal(loc, scale) + >>> [normal.batch_shape, normal.event_shape] + [core.Size([3]), core.Size([])] + >>> diagn = Independent(normal, 1) + >>> [diagn.batch_shape, diagn.event_shape] + [core.Size([]), core.Size([3])] + + Args: + base_distribution (core.distributions.distribution.Distribution): a + base distribution + reinterpreted_batch_ndims (int): the number of batch dims to + reinterpret as event dims + """ + arg_constraints: Dict[str, constraints.Constraint] = {} + + def __init__( + self, base_distribution, reinterpreted_batch_ndims, validate_args=None + ): + if reinterpreted_batch_ndims > len(base_distribution.batch_shape): + raise ValueError( + "Expected reinterpreted_batch_ndims <= len(base_distribution.batch_shape), " + f"actual {reinterpreted_batch_ndims} vs {len(base_distribution.batch_shape)}" + ) + shape = base_distribution.batch_shape + base_distribution.event_shape + event_dim = reinterpreted_batch_ndims + len(base_distribution.event_shape) + batch_shape = shape[: len(shape) - event_dim] + event_shape = shape[len(shape) - event_dim :] + self.base_dist = base_distribution + self.reinterpreted_batch_ndims = reinterpreted_batch_ndims + super().__init__(batch_shape, event_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Independent, _instance) + new.base_dist = self.base_dist.expand( + batch_shape + self.event_shape[: self.reinterpreted_batch_ndims] + ) + new.reinterpreted_batch_ndims = self.reinterpreted_batch_ndims + super(Independent, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) + new._validate_args = self._validate_args + return new + + @property + def has_rsample(self): + return self.base_dist.has_rsample + + @property + def has_enumerate_support(self): + if self.reinterpreted_batch_ndims > 0: + return False + return self.base_dist.has_enumerate_support + + @constraints.dependent_property + def support(self): # pylint: disable=invalid-overridden-method + result = self.base_dist.support + if self.reinterpreted_batch_ndims: + result = constraints.independent(result, self.reinterpreted_batch_ndims) + return result + + @property + def mean(self): + return self.base_dist.mean + + @property + def mode(self): + return self.base_dist.mode + + @property + def variance(self): + return self.base_dist.variance + + def sample(self, sample_shape=()): + return self.base_dist.sample(sample_shape) + + def rsample(self, sample_shape=()): + return self.base_dist.rsample(sample_shape) + + def log_prob(self, value): + log_prob = self.base_dist.log_prob(value) + return _sum_rightmost(log_prob, self.reinterpreted_batch_ndims) + + def entropy(self): + entropy = self.base_dist.entropy() + return _sum_rightmost(entropy, self.reinterpreted_batch_ndims) + + def enumerate_support(self, expand=True): + if self.reinterpreted_batch_ndims > 0: + raise NotImplementedError( + "Enumeration over cartesian product is not implemented" + ) + return self.base_dist.enumerate_support(expand=expand) + + def __repr__(self): + return ( + self.__class__.__name__ + + f"({self.base_dist}, {self.reinterpreted_batch_ndims})" + ) diff --git a/mindnlp/core/distributions/negative_binomial.py b/mindnlp/core/distributions/negative_binomial.py new file mode 100644 index 000000000..4e18ec8f8 --- /dev/null +++ b/mindnlp/core/distributions/negative_binomial.py @@ -0,0 +1,138 @@ +"""negative binomial""" +# mypy: allow-untyped-defs +# pylint: disable=method-hidden +from .. import ops +from ..autograd import no_grad +from ..nn import functional as F +from . import constraints +from .distribution import Distribution +from .gamma import Gamma +from .utils import ( + broadcast_all, + lazy_property, + logits_to_probs, + probs_to_logits, +) + + +__all__ = ["NegativeBinomial"] + + +class NegativeBinomial(Distribution): + r""" + Creates a Negative Binomial distribution, i.e. distribution + of the number of successful independent and identical Bernoulli trials + before :attr:`total_count` failures are achieved. The probability + of success of each Bernoulli trial is :attr:`probs`. + + Args: + total_count (float or Tensor): non-negative number of negative Bernoulli + trials to stop, although the distribution is still valid for real + valued count + probs (Tensor): Event probabilities of success in the half open interval [0, 1) + logits (Tensor): Event log-odds for probabilities of success + """ + arg_constraints = { + "total_count": constraints.greater_than_eq(0), + "probs": constraints.half_open_interval(0.0, 1.0), + "logits": constraints.real, + } + support = constraints.nonnegative_integer + + def __init__(self, total_count, probs=None, logits=None, validate_args=None): + if (probs is None) == (logits is None): + raise ValueError( + "Either `probs` or `logits` must be specified, but not both." + ) + if probs is not None: + ( + self.total_count, + self.probs, + ) = broadcast_all(total_count, probs) + self.total_count = self.total_count.type_as(self.probs) + else: + ( + self.total_count, + self.logits, + ) = broadcast_all(total_count, logits) + self.total_count = self.total_count.type_as(self.logits) + + self._param = self.probs if probs is not None else self.logits + batch_shape = self._param.shape + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(NegativeBinomial, _instance) + new.total_count = self.total_count.expand(batch_shape) + if "probs" in self.__dict__: + new.probs = self.probs.expand(batch_shape) + new._param = new.probs + if "logits" in self.__dict__: + new.logits = self.logits.expand(batch_shape) + new._param = new.logits + super(NegativeBinomial, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def _new(self, *args, **kwargs): + return self._param.new(*args, **kwargs) + + @property + def mean(self): + return self.total_count * ops.exp(self.logits) + + @property + def mode(self): + return ((self.total_count - 1) * self.logits.exp()).floor().clamp(min=0.0) + + @property + def variance(self): + return self.mean / ops.sigmoid(-self.logits) + + @lazy_property + def logits(self): + return probs_to_logits(self.probs, is_binary=True) + + @lazy_property + def probs(self): + return logits_to_probs(self.logits, is_binary=True) + + @property + def param_shape(self): + return self._param.shape + + @lazy_property + def _gamma(self): + # Note we avoid validating because self.total_count can be zero. + return Gamma( + concentration=self.total_count, + rate=ops.exp(-self.logits), + validate_args=False, + ) + + def sample(self, sample_shape=()): + with no_grad(): + rate = self._gamma.sample(sample_shape=sample_shape) + return ops.poisson(rate) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + + log_unnormalized_prob = self.total_count * F.logsigmoid( + -self.logits + ) + value * F.logsigmoid(self.logits) + + log_normalization = ( + -ops.lgamma(self.total_count + value) + + ops.lgamma(1.0 + value) + + ops.lgamma(self.total_count) + ) + # The case self.total_count == 0 and value == 0 has probability 1 but + # lgamma(0) is infinite. Handle this case separately using a function + # that does not modify tensors in place to allow Jit compilation. + log_normalization = log_normalization.masked_fill( + self.total_count + value == 0.0, 0.0 + ) + + return log_unnormalized_prob - log_normalization diff --git a/mindnlp/core/distributions/normal.py b/mindnlp/core/distributions/normal.py new file mode 100644 index 000000000..76287a9cb --- /dev/null +++ b/mindnlp/core/distributions/normal.py @@ -0,0 +1,112 @@ +"""normal""" +# mypy: allow-untyped-defs +import math +from numbers import Number, Real + +from .. import ops +from ..autograd import no_grad +from . import constraints +from .exp_family import ExponentialFamily +from .utils import _standard_normal, broadcast_all + + +__all__ = ["Normal"] + + +class Normal(ExponentialFamily): + r""" + Creates a normal (also called Gaussian) distribution parameterized by + :attr:`loc` and :attr:`scale`. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Normal(core.tensor([0.0]), core.tensor([1.0])) + >>> m.sample() # normally distributed with loc=0 and scale=1 + tensor([ 0.1046]) + + Args: + loc (float or Tensor): mean of the distribution (often referred to as mu) + scale (float or Tensor): standard deviation of the distribution + (often referred to as sigma) + """ + arg_constraints = {"loc": constraints.real, "scale": constraints.positive} + support = constraints.real + has_rsample = True + _mean_carrier_measure = 0 + + @property + def mean(self): + return self.loc + + @property + def mode(self): + return self.loc + + @property + def stddev(self): + return self.scale + + @property + def variance(self): + return self.stddev.pow(2) + + def __init__(self, loc, scale, validate_args=None): + self.loc, self.scale = broadcast_all(loc, scale) + if isinstance(loc, Number) and isinstance(scale, Number): + batch_shape = () + else: + batch_shape = self.loc.shape + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Normal, _instance) + new.loc = self.loc.expand(batch_shape) + new.scale = self.scale.expand(batch_shape) + super(Normal, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def sample(self, sample_shape=()): + shape = self._extended_shape(sample_shape) + with no_grad(): + return ops.normal(self.loc.expand(shape), self.scale.expand(shape)) + + def rsample(self, sample_shape=()): + shape = self._extended_shape(sample_shape) + eps = _standard_normal(shape, dtype=self.loc.dtype) + return self.loc + eps * self.scale + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + # compute the variance + var = self.scale**2 + log_scale = ( + math.log(self.scale) if isinstance(self.scale, Real) else self.scale.log() + ) + return ( + -((value - self.loc) ** 2) / (2 * var) + - log_scale + - math.log(math.sqrt(2 * math.pi)) + ) + + def cdf(self, value): + if self._validate_args: + self._validate_sample(value) + return 0.5 * ( + 1 + ops.erf((value - self.loc) * self.scale.reciprocal() / math.sqrt(2)) + ) + + def icdf(self, value): + return self.loc + self.scale * ops.erfinv(2 * value - 1) * math.sqrt(2) + + def entropy(self): + return 0.5 + 0.5 * math.log(2 * math.pi) + ops.log(self.scale) + + @property + def _natural_params(self): + return (self.loc / self.scale.pow(2), -0.5 * self.scale.pow(2).reciprocal()) + + def _log_normalizer(self, x, y): + return -0.25 * x.pow(2) / y + 0.5 * ops.log(-math.pi / y) diff --git a/mindnlp/core/distributions/relaxed_bernoulli.py b/mindnlp/core/distributions/relaxed_bernoulli.py new file mode 100644 index 000000000..b020ce27e --- /dev/null +++ b/mindnlp/core/distributions/relaxed_bernoulli.py @@ -0,0 +1,159 @@ +# mypy: allow-untyped-defs +# pylint: disable=method-hidden +"""RelaxedBernoulli""" +from numbers import Number + +from .. import ops +from . import constraints +from .distribution import Distribution +from .transformed_distribution import TransformedDistribution +from .transforms import SigmoidTransform +from .utils import ( + broadcast_all, + clamp_probs, + lazy_property, + logits_to_probs, + probs_to_logits, +) + +__all__ = ["LogitRelaxedBernoulli", "RelaxedBernoulli"] + + + +class LogitRelaxedBernoulli(Distribution): + r""" + Creates a LogitRelaxedBernoulli distribution parameterized by :attr:`probs` + or :attr:`logits` (but not both), which is the logit of a RelaxedBernoulli + distribution. + + Samples are logits of values in (0, 1). See [1] for more details. + + Args: + temperature (Tensor): relaxation temperature + probs (Number, Tensor): the probability of sampling `1` + logits (Number, Tensor): the log-odds of sampling `1` + + [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random + Variables (Maddison et al., 2017) + + [2] Categorical Reparametrization with Gumbel-Softmax + (Jang et al., 2017) + """ + arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} + support = constraints.real + + def __init__(self, temperature, probs=None, logits=None, validate_args=None): + self.temperature = temperature + if (probs is None) == (logits is None): + raise ValueError( + "Either `probs` or `logits` must be specified, but not both." + ) + if probs is not None: + is_scalar = isinstance(probs, Number) + (self.probs,) = broadcast_all(probs) + else: + is_scalar = isinstance(logits, Number) + (self.logits,) = broadcast_all(logits) + self._param = self.probs if probs is not None else self.logits + if is_scalar: + batch_shape = () + else: + batch_shape = self._param.shape + super().__init__(batch_shape, validate_args=validate_args) + + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(LogitRelaxedBernoulli, _instance) + new.temperature = self.temperature + if "probs" in self.__dict__: + new.probs = self.probs.broadcast_to(batch_shape) + new._param = new.probs + if "logits" in self.__dict__: + new.logits = self.logits.broadcast_to(batch_shape) + new._param = new.logits + super(LogitRelaxedBernoulli, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + + def _new(self, *args, **kwargs): + return self._param.new(*args, **kwargs) + + @lazy_property + def logits(self): + return probs_to_logits(self.probs, is_binary=True) + + @lazy_property + def probs(self): + return logits_to_probs(self.logits, is_binary=True) + + @property + def param_shape(self): + return self._param.shape + + + def rsample(self, sample_shape=()): + shape = self._extended_shape(sample_shape) + probs = clamp_probs(self.probs.broadcast_to(shape)) + uniforms = clamp_probs( + ops.rand(shape, dtype=probs.dtype) + ) + return ( + uniforms.log() - (-uniforms).log1p() + probs.log() - (-probs).log1p() + ) / self.temperature + + + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + logits, value = broadcast_all(self.logits, value) + diff = logits - value.mul(self.temperature) + return self.temperature.log() + diff - 2 * diff.exp().log1p() + + +class RelaxedBernoulli(TransformedDistribution): + r""" + Creates a RelaxedBernoulli distribution, parametrized by + :attr:`temperature`, and either :attr:`probs` or :attr:`logits` + (but not both). This is a relaxed version of the `Bernoulli` distribution, + so the values are in (0, 1), and has reparametrizable samples. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = RelaxedBernoulli(core.tensor([2.2]), + ... core.tensor([0.1, 0.2, 0.3, 0.99])) + >>> m.sample() + tensor([ 0.2951, 0.3442, 0.8918, 0.9021]) + + Args: + temperature (Tensor): relaxation temperature + probs (Number, Tensor): the probability of sampling `1` + logits (Number, Tensor): the log-odds of sampling `1` + """ + arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} + support = constraints.unit_interval + has_rsample = True + + def __init__(self, temperature, probs=None, logits=None, validate_args=None): + base_dist = LogitRelaxedBernoulli(temperature, probs, logits) + super().__init__(base_dist, SigmoidTransform(), validate_args=validate_args) + + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(RelaxedBernoulli, _instance) + return super().expand(batch_shape, _instance=new) + + + @property + def temperature(self): + return self.base_dist.temperature + + @property + def logits(self): + return self.base_dist.logits + + @property + def probs(self): + return self.base_dist.probs diff --git a/mindnlp/core/distributions/relaxed_categorical.py b/mindnlp/core/distributions/relaxed_categorical.py new file mode 100644 index 000000000..c749c7cd8 --- /dev/null +++ b/mindnlp/core/distributions/relaxed_categorical.py @@ -0,0 +1,144 @@ +"""RelaxedCategorical""" + +# mypy: allow-untyped-defs +from .. import ops +from . import constraints +from .categorical import Categorical +from .distribution import Distribution +from .transformed_distribution import TransformedDistribution +from .transforms import ExpTransform +from .utils import broadcast_all, clamp_probs + +__all__ = ["ExpRelaxedCategorical", "RelaxedOneHotCategorical"] + + +class ExpRelaxedCategorical(Distribution): + r""" + Creates a ExpRelaxedCategorical parameterized by + :attr:`temperature`, and either :attr:`probs` or :attr:`logits` (but not both). + Returns the log of a point in the simplex. Based on the interface to + :class:`OneHotCategorical`. + + Implementation based on [1]. + + See also: :func:`distributions.OneHotCategorical` + + Args: + temperature (Tensor): relaxation temperature + probs (Tensor): event probabilities + logits (Tensor): unnormalized log probability for each event + + [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables + (Maddison et al., 2017) + + [2] Categorical Reparametrization with Gumbel-Softmax + (Jang et al., 2017) + """ + arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} + support = ( + constraints.real_vector + ) # The true support is actually a submanifold of this. + has_rsample = True + + def __init__(self, temperature, probs=None, logits=None, validate_args=None): + self._categorical = Categorical(probs, logits) + self.temperature = temperature + batch_shape = self._categorical.batch_shape + event_shape = self._categorical.param_shape[-1:] + super().__init__(batch_shape, event_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(ExpRelaxedCategorical, _instance) + new.temperature = self.temperature + new._categorical = self._categorical.expand(batch_shape) + super(ExpRelaxedCategorical, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) + new._validate_args = self._validate_args + return new + + def _new(self, *args, **kwargs): + return self._categorical._new(*args, **kwargs) + + @property + def param_shape(self): + return self._categorical.param_shape + + @property + def logits(self): + return self._categorical.logits + + @property + def probs(self): + return self._categorical.probs + + def rsample(self, sample_shape=()): + shape = self._extended_shape(sample_shape) + uniforms = clamp_probs( + ops.rand(shape, dtype=self.logits.dtype) + ) + gumbels = -((-(uniforms.log())).log()) + scores = (self.logits + gumbels) / self.temperature + return scores - ops.logsumexp(scores, dim=-1, keepdim=True) + + def log_prob(self, value): + K = self._categorical._num_events + if self._validate_args: + self._validate_sample(value) + logits, value = broadcast_all(self.logits, value) + log_scale = ops.full_like( + self.temperature, float(K) + ).lgamma() - self.temperature.log().mul(-(K - 1)) + score = logits - value.mul(self.temperature) + score = (score - ops.logsumexp(score, dim=-1, keepdim=True)).sum(-1) + return score + log_scale + + + +class RelaxedOneHotCategorical(TransformedDistribution): + r""" + Creates a RelaxedOneHotCategorical distribution parametrized by + :attr:`temperature`, and either :attr:`probs` or :attr:`logits`. + This is a relaxed version of the :class:`OneHotCategorical` distribution, so + its samples are on simplex, and are reparametrizable. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = RelaxedOneHotCategorical(mindspore.tensor([2.2]), + ... mindspore.tensor([0.1, 0.2, 0.3, 0.4])) + >>> m.sample() + tensor([ 0.1294, 0.2324, 0.3859, 0.2523]) + + Args: + temperature (Tensor): relaxation temperature + probs (Tensor): event probabilities + logits (Tensor): unnormalized log probability for each event + """ + arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} + support = constraints.simplex + has_rsample = True + + def __init__(self, temperature, probs=None, logits=None, validate_args=None): + base_dist = ExpRelaxedCategorical( + temperature, probs, logits, validate_args=validate_args + ) + super().__init__(base_dist, ExpTransform(), validate_args=validate_args) + + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(RelaxedOneHotCategorical, _instance) + return super().expand(batch_shape, _instance=new) + + + @property + def temperature(self): + return self.base_dist.temperature + + @property + def logits(self): + return self.base_dist.logits + + @property + def probs(self): + return self.base_dist.probs diff --git a/mindnlp/core/distributions/studentT.py b/mindnlp/core/distributions/studentT.py new file mode 100644 index 000000000..3bbdb2559 --- /dev/null +++ b/mindnlp/core/distributions/studentT.py @@ -0,0 +1,119 @@ +"""studentT""" +# mypy: allow-untyped-defs +import math +from math import inf, nan + +from .. import ops +from . import constraints +from .chi2 import Chi2 +from .distribution import Distribution +from .utils import _standard_normal, broadcast_all + + +__all__ = ["StudentT"] + + +class StudentT(Distribution): + r""" + Creates a Student's t-distribution parameterized by degree of + freedom :attr:`df`, mean :attr:`loc` and scale :attr:`scale`. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = StudentT(core.tensor([2.0])) + >>> m.sample() # Student's t-distributed with degrees of freedom=2 + tensor([ 0.1046]) + + Args: + df (float or Tensor): degrees of freedom + loc (float or Tensor): mean of the distribution + scale (float or Tensor): scale of the distribution + """ + arg_constraints = { + "df": constraints.positive, + "loc": constraints.real, + "scale": constraints.positive, + } + support = constraints.real + has_rsample = True + + @property + def mean(self): + m = self.loc.copy() + m[self.df <= 1] = nan + return m + + @property + def mode(self): + return self.loc + + @property + def variance(self): + m = self.df.copy() + m[self.df > 2] = ( + self.scale[self.df > 2].pow(2) + * self.df[self.df > 2] + / (self.df[self.df > 2] - 2) + ) + m[(self.df <= 2) & (self.df > 1)] = inf + m[self.df <= 1] = nan + return m + + def __init__(self, df, loc=0.0, scale=1.0, validate_args=None): + self.df, self.loc, self.scale = broadcast_all(df, loc, scale) + self._chi2 = Chi2(self.df) + batch_shape = self.df.shape + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(StudentT, _instance) + new.df = self.df.expand(batch_shape) + new.loc = self.loc.expand(batch_shape) + new.scale = self.scale.expand(batch_shape) + new._chi2 = self._chi2.expand(batch_shape) + super(StudentT, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def rsample(self, sample_shape=()): + # NOTE: This does not agree with scipy implementation as much as other distributions. + # (see https://github.com/fritzo/notebooks/blob/master/debug-student-t.ipynb). Using DoubleTensor + # parameters seems to help. + + # X ~ Normal(0, 1) + # Z ~ Chi2(df) + # Y = X / sqrt(Z / df) ~ StudentT(df) + shape = self._extended_shape(sample_shape) + X = _standard_normal(shape, dtype=self.df.dtype) + Z = self._chi2.rsample(sample_shape) + Y = X * ops.rsqrt(Z / self.df) + return self.loc + self.scale * Y + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + y = (value - self.loc) / self.scale + Z = ( + self.scale.log() + + 0.5 * self.df.log() + + 0.5 * math.log(math.pi) + + ops.lgamma(0.5 * self.df) + - ops.lgamma(0.5 * (self.df + 1.0)) + ) + return -0.5 * (self.df + 1.0) * ops.log1p(y**2.0 / self.df) - Z + + def entropy(self): + lbeta = ( + ops.lgamma(0.5 * self.df) + + math.lgamma(0.5) + - ops.lgamma(0.5 * (self.df + 1)) + ) + return ( + self.scale.log() + + 0.5 + * (self.df + 1) + * (ops.digamma(0.5 * (self.df + 1)) - ops.digamma(0.5 * self.df)) + + 0.5 * self.df.log() + + lbeta + ) diff --git a/mindnlp/core/distributions/transformed_distribution.py b/mindnlp/core/distributions/transformed_distribution.py new file mode 100644 index 000000000..b2b3293c3 --- /dev/null +++ b/mindnlp/core/distributions/transformed_distribution.py @@ -0,0 +1,215 @@ +"""transformed distribution""" +# mypy: allow-untyped-defs +from typing import Dict + +from ..autograd import no_grad +from . import constraints +from .distribution import Distribution +from .independent import Independent +from .transforms import ComposeTransform, Transform +from .utils import _sum_rightmost + + +__all__ = ["TransformedDistribution"] + + +class TransformedDistribution(Distribution): + r""" + Extension of the Distribution class, which applies a sequence of Transforms + to a base distribution. Let f be the composition of transforms applied:: + + X ~ BaseDistribution + Y = f(X) ~ TransformedDistribution(BaseDistribution, f) + log p(Y) = log p(X) + log |det (dX/dY)| + + Note that the ``.event_shape`` of a :class:`TransformedDistribution` is the + maximum shape of its base distribution and its transforms, since transforms + can introduce correlations among events. + + An example for the usage of :class:`TransformedDistribution` would be:: + + # Building a Logistic Distribution + # X ~ Uniform(0, 1) + # f = a + b * logit(X) + # Y ~ f(X) ~ Logistic(a, b) + base_distribution = Uniform(0, 1) + transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)] + logistic = TransformedDistribution(base_distribution, transforms) + + For more examples, please look at the implementations of + :class:`~core.distributions.gumbel.Gumbel`, + :class:`~core.distributions.half_cauchy.HalfCauchy`, + :class:`~core.distributions.half_normal.HalfNormal`, + :class:`~core.distributions.log_normal.LogNormal`, + :class:`~core.distributions.pareto.Pareto`, + :class:`~core.distributions.weibull.Weibull`, + :class:`~core.distributions.relaxed_bernoulli.RelaxedBernoulli` and + :class:`~core.distributions.relaxed_categorical.RelaxedOneHotCategorical` + """ + arg_constraints: Dict[str, constraints.Constraint] = {} + + def __init__(self, base_distribution, transforms, validate_args=None): + if isinstance(transforms, Transform): + self.transforms = [ + transforms, + ] + elif isinstance(transforms, list): + if not all(isinstance(t, Transform) for t in transforms): + raise ValueError( + "transforms must be a Transform or a list of Transforms" + ) + self.transforms = transforms + else: + raise ValueError( + f"transforms must be a Transform or list, but was {transforms}" + ) + + # Reshape base_distribution according to transforms. + base_shape = base_distribution.batch_shape + base_distribution.event_shape + base_event_dim = len(base_distribution.event_shape) + transform = ComposeTransform(self.transforms) + if len(base_shape) < transform.domain.event_dim: + raise ValueError( + f"base_distribution needs to have shape with size at least {transform.domain.event_dim}, but got {base_shape}." + ) + forward_shape = transform.forward_shape(base_shape) + expanded_base_shape = transform.inverse_shape(forward_shape) + if base_shape != expanded_base_shape: + base_batch_shape = expanded_base_shape[ + : len(expanded_base_shape) - base_event_dim + ] + base_distribution = base_distribution.expand(base_batch_shape) + reinterpreted_batch_ndims = transform.domain.event_dim - base_event_dim + if reinterpreted_batch_ndims > 0: + base_distribution = Independent( + base_distribution, reinterpreted_batch_ndims + ) + self.base_dist = base_distribution + + # Compute shapes. + transform_change_in_event_dim = ( + transform.codomain.event_dim - transform.domain.event_dim + ) + event_dim = max( + transform.codomain.event_dim, # the transform is coupled + base_event_dim + transform_change_in_event_dim, # the base dist is coupled + ) + assert len(forward_shape) >= event_dim + cut = len(forward_shape) - event_dim + batch_shape = forward_shape[:cut] + event_shape = forward_shape[cut:] + super().__init__(batch_shape, event_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(TransformedDistribution, _instance) + shape = batch_shape + self.event_shape + for t in reversed(self.transforms): + shape = t.inverse_shape(shape) + base_batch_shape = shape[: len(shape) - len(self.base_dist.event_shape)] + new.base_dist = self.base_dist.expand(base_batch_shape) + new.transforms = self.transforms + super(TransformedDistribution, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) + new._validate_args = self._validate_args + return new + + @constraints.dependent_property(is_discrete=False) + def support(self): # pylint: disable=invalid-overridden-method + if not self.transforms: + return self.base_dist.support + support = self.transforms[-1].codomain + if len(self.event_shape) > support.event_dim: + support = constraints.independent( + support, len(self.event_shape) - support.event_dim + ) + return support + + @property + def has_rsample(self): + return self.base_dist.has_rsample + + def sample(self, sample_shape=()): + """ + Generates a sample_shape shaped sample or sample_shape shaped batch of + samples if the distribution parameters are batched. Samples first from + base distribution and applies `transform()` for every transform in the + list. + """ + with no_grad(): + x = self.base_dist.sample(sample_shape) + for transform in self.transforms: + x = transform(x) + return x + + def rsample(self, sample_shape=()): + """ + Generates a sample_shape shaped reparameterized sample or sample_shape + shaped batch of reparameterized samples if the distribution parameters + are batched. Samples first from base distribution and applies + `transform()` for every transform in the list. + """ + x = self.base_dist.rsample(sample_shape) + for transform in self.transforms: + x = transform(x) + return x + + def log_prob(self, value): + """ + Scores the sample by inverting the transform(s) and computing the score + using the score of the base distribution and the log abs det jacobian. + """ + if self._validate_args: + self._validate_sample(value) + event_dim = len(self.event_shape) + log_prob = 0.0 + y = value + for transform in reversed(self.transforms): + x = transform.inv(y) + event_dim += transform.domain.event_dim - transform.codomain.event_dim + log_prob = log_prob - _sum_rightmost( + transform.log_abs_det_jacobian(x, y), + event_dim - transform.domain.event_dim, + ) + y = x + + log_prob = log_prob + _sum_rightmost( + self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape) + ) + return log_prob + + def _monotonize_cdf(self, value): + """ + This conditionally flips ``value -> 1-value`` to ensure :meth:`cdf` is + monotone increasing. + """ + sign = 1 + for transform in self.transforms: + sign = sign * transform.sign + if isinstance(sign, int) and sign == 1: + return value + return sign * (value - 0.5) + 0.5 + + def cdf(self, value): + """ + Computes the cumulative distribution function by inverting the + transform(s) and computing the score of the base distribution. + """ + for transform in self.transforms[::-1]: + value = transform.inv(value) + if self._validate_args: + self.base_dist._validate_sample(value) + value = self.base_dist.cdf(value) + value = self._monotonize_cdf(value) + return value + + def icdf(self, value): + """ + Computes the inverse cumulative distribution function using + transform(s) and computing the score of the base distribution. + """ + value = self._monotonize_cdf(value) + value = self.base_dist.icdf(value) + for transform in self.transforms: + value = transform(value) + return value diff --git a/mindnlp/core/distributions/transforms.py b/mindnlp/core/distributions/transforms.py new file mode 100644 index 000000000..0f9c8a8ee --- /dev/null +++ b/mindnlp/core/distributions/transforms.py @@ -0,0 +1,1249 @@ +"""transforms""" +# pylint: disable=invalid-overridden-method +# mypy: allow-untyped-defs +import functools +import math +import numbers +import operator +import weakref +from typing import List + +from .. import ops +from ..nn import functional as F +from . import constraints +from .utils import ( + _sum_rightmost, + broadcast_all, + lazy_property, + tril_matrix_to_vec, + vec_to_tril_matrix, +) +from ..nn.functional import pad, softplus + + +__all__ = [ + "AbsTransform", + "AffineTransform", + "CatTransform", + "ComposeTransform", + "CorrCholeskyTransform", + "CumulativeDistributionTransform", + "ExpTransform", + "IndependentTransform", + "LowerCholeskyTransform", + "PositiveDefiniteTransform", + "PowerTransform", + "ReshapeTransform", + "SigmoidTransform", + "SoftplusTransform", + "TanhTransform", + "SoftmaxTransform", + "StackTransform", + "StickBreakingTransform", + "Transform", + "identity_transform", +] + + +class Transform: + """ + Abstract class for invertable transformations with computable log + det jacobians. They are primarily used in + :class:`core.distributions.TransformedDistribution`. + + Caching is useful for transforms whose inverses are either expensive or + numerically unstable. Note that care must be taken with memoized values + since the autograd graph may be reversed. For example while the following + works with or without caching:: + + y = t(x) + t.log_abs_det_jacobian(x, y).backward() # x will receive gradients. + + However the following will error when caching due to dependency reversal:: + + y = t(x) + z = t.inv(y) + grad(z.sum(), [y]) # error because z is x + + Derived classes should implement one or both of :meth:`_call` or + :meth:`_inverse`. Derived classes that set `bijective=True` should also + implement :meth:`log_abs_det_jacobian`. + + Args: + cache_size (int): Size of cache. If zero, no caching is done. If one, + the latest single value is cached. Only 0 and 1 are supported. + + Attributes: + domain (:class:`~core.distributions.constraints.Constraint`): + The constraint representing valid inputs to this transform. + codomain (:class:`~core.distributions.constraints.Constraint`): + The constraint representing valid outputs to this transform + which are inputs to the inverse transform. + bijective (bool): Whether this transform is bijective. A transform + ``t`` is bijective iff ``t.inv(t(x)) == x`` and + ``t(t.inv(y)) == y`` for every ``x`` in the domain and ``y`` in + the codomain. Transforms that are not bijective should at least + maintain the weaker pseudoinverse properties + ``t(t.inv(t(x)) == t(x)`` and ``t.inv(t(t.inv(y))) == t.inv(y)``. + sign (int or Tensor): For bijective univariate transforms, this + should be +1 or -1 depending on whether transform is monotone + increasing or decreasing. + """ + + bijective = False + domain: constraints.Constraint + codomain: constraints.Constraint + + def __init__(self, cache_size=0): + self._cache_size = cache_size + self._inv = None + if cache_size == 0: + pass # default behavior + elif cache_size == 1: + self._cached_x_y = None, None + else: + raise ValueError("cache_size must be 0 or 1") + super().__init__() + + def __getstate__(self): + state = self.__dict__.copy() + state["_inv"] = None + return state + + @property + def event_dim(self): + if self.domain.event_dim == self.codomain.event_dim: + return self.domain.event_dim + raise ValueError("Please use either .domain.event_dim or .codomain.event_dim") + + @property + def inv(self): + """ + Returns the inverse :class:`Transform` of this transform. + This should satisfy ``t.inv.inv is t``. + """ + inv = None + if self._inv is not None: + inv = self._inv() + if inv is None: + inv = _InverseTransform(self) + self._inv = weakref.ref(inv) + return inv + + @property + def sign(self): + """ + Returns the sign of the determinant of the Jacobian, if applicable. + In general this only makes sense for bijective transforms. + """ + raise NotImplementedError + + def with_cache(self, cache_size=1): + if self._cache_size == cache_size: + return self + if type(self).__init__ is Transform.__init__: + return type(self)(cache_size=cache_size) + raise NotImplementedError(f"{type(self)}.with_cache is not implemented") + + def __eq__(self, other): + return self is other + + def __ne__(self, other): + # Necessary for Python2 + return not self.__eq__(other) + + def __call__(self, x): + """ + Computes the transform `x => y`. + """ + if self._cache_size == 0: + return self._call(x) + x_old, y_old = self._cached_x_y + if x is x_old: + return y_old + y = self._call(x) + self._cached_x_y = x, y + return y + + def _inv_call(self, y): + """ + Inverts the transform `y => x`. + """ + if self._cache_size == 0: + return self._inverse(y) + x_old, y_old = self._cached_x_y + if y is y_old: + return x_old + x = self._inverse(y) + self._cached_x_y = x, y + return x + + def _call(self, x): + """ + Abstract method to compute forward transformation. + """ + raise NotImplementedError + + def _inverse(self, y): + """ + Abstract method to compute inverse transformation. + """ + raise NotImplementedError + + def log_abs_det_jacobian(self, x, y): + """ + Computes the log det jacobian `log |dy/dx|` given input and output. + """ + raise NotImplementedError + + def __repr__(self): + return self.__class__.__name__ + "()" + + def forward_shape(self, shape): + """ + Infers the shape of the forward computation, given the input shape. + Defaults to preserving shape. + """ + return shape + + def inverse_shape(self, shape): + """ + Infers the shapes of the inverse computation, given the output shape. + Defaults to preserving shape. + """ + return shape + + +class _InverseTransform(Transform): + """ + Inverts a single :class:`Transform`. + This class is private; please instead use the ``Transform.inv`` property. + """ + + def __init__(self, transform: Transform): + super().__init__(cache_size=transform._cache_size) + self._inv: Transform = transform + + @constraints.dependent_property(is_discrete=False) + def domain(self): + assert self._inv is not None + return self._inv.codomain + + @constraints.dependent_property(is_discrete=False) + def codomain(self): + assert self._inv is not None + return self._inv.domain + + @property + def bijective(self): + assert self._inv is not None + return self._inv.bijective + + @property + def sign(self): + assert self._inv is not None + return self._inv.sign + + @property + def inv(self): + return self._inv + + def with_cache(self, cache_size=1): + assert self._inv is not None + return self.inv.with_cache(cache_size).inv + + def __eq__(self, other): + if not isinstance(other, _InverseTransform): + return False + assert self._inv is not None + return self._inv == other._inv + + def __repr__(self): + return f"{self.__class__.__name__}({repr(self._inv)})" + + def __call__(self, x): + assert self._inv is not None + return self._inv._inv_call(x) + + def log_abs_det_jacobian(self, x, y): + assert self._inv is not None + return -self._inv.log_abs_det_jacobian(y, x) + + def forward_shape(self, shape): + return self._inv.inverse_shape(shape) + + def inverse_shape(self, shape): + return self._inv.forward_shape(shape) + + +class ComposeTransform(Transform): + """ + Composes multiple transforms in a chain. + The transforms being composed are responsible for caching. + + Args: + parts (list of :class:`Transform`): A list of transforms to compose. + cache_size (int): Size of cache. If zero, no caching is done. If one, + the latest single value is cached. Only 0 and 1 are supported. + """ + + def __init__(self, parts: List[Transform], cache_size=0): + if cache_size: + parts = [part.with_cache(cache_size) for part in parts] + super().__init__(cache_size=cache_size) + self.parts = parts + + def __eq__(self, other): + if not isinstance(other, ComposeTransform): + return False + return self.parts == other.parts + + @constraints.dependent_property(is_discrete=False) + def domain(self): + if not self.parts: + return constraints.real + domain = self.parts[0].domain + # Adjust event_dim to be maximum among all parts. + event_dim = self.parts[-1].codomain.event_dim + for part in reversed(self.parts): + event_dim += part.domain.event_dim - part.codomain.event_dim + event_dim = max(event_dim, part.domain.event_dim) + assert event_dim >= domain.event_dim + if event_dim > domain.event_dim: + domain = constraints.independent(domain, event_dim - domain.event_dim) + return domain + + @constraints.dependent_property(is_discrete=False) + def codomain(self): + if not self.parts: + return constraints.real + codomain = self.parts[-1].codomain + # Adjust event_dim to be maximum among all parts. + event_dim = self.parts[0].domain.event_dim + for part in self.parts: + event_dim += part.codomain.event_dim - part.domain.event_dim + event_dim = max(event_dim, part.codomain.event_dim) + assert event_dim >= codomain.event_dim + if event_dim > codomain.event_dim: + codomain = constraints.independent(codomain, event_dim - codomain.event_dim) + return codomain + + @lazy_property + def bijective(self): + return all(p.bijective for p in self.parts) + + @lazy_property + def sign(self): + sign = 1 + for p in self.parts: + sign = sign * p.sign + return sign + + @property + def inv(self): + inv = None + if self._inv is not None: + inv = self._inv() + if inv is None: + inv = ComposeTransform([p.inv for p in reversed(self.parts)]) + self._inv = weakref.ref(inv) + inv._inv = weakref.ref(self) + return inv + + def with_cache(self, cache_size=1): + if self._cache_size == cache_size: + return self + return ComposeTransform(self.parts, cache_size=cache_size) + + def __call__(self, x): + for part in self.parts: + x = part(x) + return x + + def log_abs_det_jacobian(self, x, y): + if not self.parts: + return ops.zeros_like(x) + + # Compute intermediates. This will be free if parts[:-1] are all cached. + xs = [x] + for part in self.parts[:-1]: + xs.append(part(xs[-1])) + xs.append(y) + + terms = [] + event_dim = self.domain.event_dim + for part, x, y in zip(self.parts, xs[:-1], xs[1:]): + terms.append( + _sum_rightmost( + part.log_abs_det_jacobian(x, y), event_dim - part.domain.event_dim + ) + ) + event_dim += part.codomain.event_dim - part.domain.event_dim + return functools.reduce(operator.add, terms) + + def forward_shape(self, shape): + for part in self.parts: + shape = part.forward_shape(shape) + return shape + + def inverse_shape(self, shape): + for part in reversed(self.parts): + shape = part.inverse_shape(shape) + return shape + + def __repr__(self): + fmt_string = self.__class__.__name__ + "(\n " + fmt_string += ",\n ".join([p.__repr__() for p in self.parts]) + fmt_string += "\n)" + return fmt_string + + +identity_transform = ComposeTransform([]) + + +class IndependentTransform(Transform): + """ + Wrapper around another transform to treat + ``reinterpreted_batch_ndims``-many extra of the right most dimensions as + dependent. This has no effect on the forward or backward transforms, but + does sum out ``reinterpreted_batch_ndims``-many of the rightmost dimensions + in :meth:`log_abs_det_jacobian`. + + Args: + base_transform (:class:`Transform`): A base transform. + reinterpreted_batch_ndims (int): The number of extra rightmost + dimensions to treat as dependent. + """ + + def __init__(self, base_transform, reinterpreted_batch_ndims, cache_size=0): + super().__init__(cache_size=cache_size) + self.base_transform = base_transform.with_cache(cache_size) + self.reinterpreted_batch_ndims = reinterpreted_batch_ndims + + def with_cache(self, cache_size=1): + if self._cache_size == cache_size: + return self + return IndependentTransform( + self.base_transform, self.reinterpreted_batch_ndims, cache_size=cache_size + ) + + @constraints.dependent_property(is_discrete=False) + def domain(self): + return constraints.independent( + self.base_transform.domain, self.reinterpreted_batch_ndims + ) + + @constraints.dependent_property(is_discrete=False) + def codomain(self): + return constraints.independent( + self.base_transform.codomain, self.reinterpreted_batch_ndims + ) + + @property + def bijective(self): + return self.base_transform.bijective + + @property + def sign(self): + return self.base_transform.sign + + def _call(self, x): + if x.dim() < self.domain.event_dim: + raise ValueError("Too few dimensions on input") + return self.base_transform(x) + + def _inverse(self, y): + if y.dim() < self.codomain.event_dim: + raise ValueError("Too few dimensions on input") + return self.base_transform.inv(y) + + def log_abs_det_jacobian(self, x, y): + result = self.base_transform.log_abs_det_jacobian(x, y) + result = _sum_rightmost(result, self.reinterpreted_batch_ndims) + return result + + def __repr__(self): + return f"{self.__class__.__name__}({repr(self.base_transform)}, {self.reinterpreted_batch_ndims})" + + def forward_shape(self, shape): + return self.base_transform.forward_shape(shape) + + def inverse_shape(self, shape): + return self.base_transform.inverse_shape(shape) + + +class ReshapeTransform(Transform): + """ + Unit Jacobian transform to reshape the rightmost part of a tensor. + + Note that ``in_shape`` and ``out_shape`` must have the same number of + elements, just as for :meth:`mindspore.Tensor.reshape`. + + Arguments: + in_shape: The input event shape. + out_shape: The output event shape. + """ + + bijective = True + + def __init__(self, in_shape, out_shape, cache_size=0): + self.in_shape = in_shape + self.out_shape = out_shape + if self.in_shape.numel() != self.out_shape.numel(): + raise ValueError("in_shape, out_shape have different numbers of elements") + super().__init__(cache_size=cache_size) + + @constraints.dependent_property + def domain(self): + return constraints.independent(constraints.real, len(self.in_shape)) + + @constraints.dependent_property + def codomain(self): + return constraints.independent(constraints.real, len(self.out_shape)) + + def with_cache(self, cache_size=1): + if self._cache_size == cache_size: + return self + return ReshapeTransform(self.in_shape, self.out_shape, cache_size=cache_size) + + def _call(self, x): + batch_shape = x.shape[: x.dim() - len(self.in_shape)] + return x.reshape(batch_shape + self.out_shape) + + def _inverse(self, y): + batch_shape = y.shape[: y.dim() - len(self.out_shape)] + return y.reshape(batch_shape + self.in_shape) + + def log_abs_det_jacobian(self, x, y): + batch_shape = x.shape[: x.dim() - len(self.in_shape)] + return x.new_zeros(batch_shape) + + def forward_shape(self, shape): + if len(shape) < len(self.in_shape): + raise ValueError("Too few dimensions on input") + cut = len(shape) - len(self.in_shape) + if shape[cut:] != self.in_shape: + raise ValueError( + f"Shape mismatch: expected {shape[cut:]} but got {self.in_shape}" + ) + return shape[:cut] + self.out_shape + + def inverse_shape(self, shape): + if len(shape) < len(self.out_shape): + raise ValueError("Too few dimensions on input") + cut = len(shape) - len(self.out_shape) + if shape[cut:] != self.out_shape: + raise ValueError( + f"Shape mismatch: expected {shape[cut:]} but got {self.out_shape}" + ) + return shape[:cut] + self.in_shape + + +class ExpTransform(Transform): + r""" + Transform via the mapping :math:`y = \exp(x)`. + """ + domain = constraints.real + codomain = constraints.positive + bijective = True + sign = +1 + + def __eq__(self, other): + return isinstance(other, ExpTransform) + + def _call(self, x): + return x.exp() + + def _inverse(self, y): + return y.log() + + def log_abs_det_jacobian(self, x, y): + return x + + +class PowerTransform(Transform): + r""" + Transform via the mapping :math:`y = x^{\text{exponent}}`. + """ + domain = constraints.positive + codomain = constraints.positive + bijective = True + + def __init__(self, exponent, cache_size=0): + super().__init__(cache_size=cache_size) + (self.exponent,) = broadcast_all(exponent) + + def with_cache(self, cache_size=1): + if self._cache_size == cache_size: + return self + return PowerTransform(self.exponent, cache_size=cache_size) + + @lazy_property + def sign(self): + return self.exponent.sign() + + def __eq__(self, other): + if not isinstance(other, PowerTransform): + return False + return self.exponent.eq(other.exponent).all().item() + + def _call(self, x): + return x.pow(self.exponent) + + def _inverse(self, y): + return y.pow(1 / self.exponent) + + def log_abs_det_jacobian(self, x, y): + return (self.exponent * y / x).abs().log() + + def forward_shape(self, shape): + return ops.broadcast_shapes(shape, getattr(self.exponent, "shape", ())) + + def inverse_shape(self, shape): + return ops.broadcast_shapes(shape, getattr(self.exponent, "shape", ())) + + +def _clipped_sigmoid(x): + finfo = ops.finfo(x.dtype) + return ops.clamp(ops.sigmoid(x), min=finfo.tiny, max=1.0 - finfo.eps) + + +class SigmoidTransform(Transform): + r""" + Transform via the mapping :math:`y = \frac{1}{1 + \exp(-x)}` and :math:`x = \text{logit}(y)`. + """ + domain = constraints.real + codomain = constraints.unit_interval + bijective = True + sign = +1 + + def __eq__(self, other): + return isinstance(other, SigmoidTransform) + + def _call(self, x): + return _clipped_sigmoid(x) + + def _inverse(self, y): + finfo = ops.finfo(y.dtype) + y = y.clamp(min=finfo.tiny, max=1.0 - finfo.eps) + return y.log() - (-y).log1p() + + def log_abs_det_jacobian(self, x, y): + return -F.softplus(-x) - F.softplus(x) + + +class SoftplusTransform(Transform): + r""" + Transform via the mapping :math:`\text{Softplus}(x) = \log(1 + \exp(x))`. + The implementation reverts to the linear function when :math:`x > 20`. + """ + domain = constraints.real + codomain = constraints.positive + bijective = True + sign = +1 + + def __eq__(self, other): + return isinstance(other, SoftplusTransform) + + def _call(self, x): + return softplus(x) + + def _inverse(self, y): + return (-y).expm1().neg().log() + y + + def log_abs_det_jacobian(self, x, y): + return -softplus(-x) + + +class TanhTransform(Transform): + r""" + Transform via the mapping :math:`y = \tanh(x)`. + + It is equivalent to + ``` + ComposeTransform([AffineTransform(0., 2.), SigmoidTransform(), AffineTransform(-1., 2.)]) + ``` + However this might not be numerically stable, thus it is recommended to use `TanhTransform` + instead. + + Note that one should use `cache_size=1` when it comes to `NaN/Inf` values. + + """ + domain = constraints.real + codomain = constraints.interval(-1.0, 1.0) + bijective = True + sign = +1 + + def __eq__(self, other): + return isinstance(other, TanhTransform) + + def _call(self, x): + return x.tanh() + + def _inverse(self, y): + # We do not clamp to the boundary here as it may degrade the performance of certain algorithms. + # one should use `cache_size=1` instead + return ops.atanh(y) + + def log_abs_det_jacobian(self, x, y): + # We use a formula that is more numerically stable, see details in the following link + # https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/bijectors/tanh.py#L69-L80 + return 2.0 * (math.log(2.0) - x - softplus(-2.0 * x)) + + +class AbsTransform(Transform): + r""" + Transform via the mapping :math:`y = |x|`. + """ + domain = constraints.real + codomain = constraints.positive + + def __eq__(self, other): + return isinstance(other, AbsTransform) + + def _call(self, x): + return x.abs() + + def _inverse(self, y): + return y + + +class AffineTransform(Transform): + r""" + Transform via the pointwise affine mapping :math:`y = \text{loc} + \text{scale} \times x`. + + Args: + loc (Tensor or float): Location parameter. + scale (Tensor or float): Scale parameter. + event_dim (int): Optional size of `event_shape`. This should be zero + for univariate random variables, 1 for distributions over vectors, + 2 for distributions over matrices, etc. + """ + bijective = True + + def __init__(self, loc, scale, event_dim=0, cache_size=0): + super().__init__(cache_size=cache_size) + self.loc = loc + self.scale = scale + self._event_dim = event_dim + + @property + def event_dim(self): + return self._event_dim + + @constraints.dependent_property(is_discrete=False) + def domain(self): + if self.event_dim == 0: + return constraints.real + return constraints.independent(constraints.real, self.event_dim) + + @constraints.dependent_property(is_discrete=False) + def codomain(self): + if self.event_dim == 0: + return constraints.real + return constraints.independent(constraints.real, self.event_dim) + + def with_cache(self, cache_size=1): + if self._cache_size == cache_size: + return self + return AffineTransform( + self.loc, self.scale, self.event_dim, cache_size=cache_size + ) + + def __eq__(self, other): + if not isinstance(other, AffineTransform): + return False + + if isinstance(self.loc, numbers.Number) and isinstance( + other.loc, numbers.Number + ): + if self.loc != other.loc: + return False + else: + if not (self.loc == other.loc).all().item(): + return False + + if isinstance(self.scale, numbers.Number) and isinstance( + other.scale, numbers.Number + ): + if self.scale != other.scale: + return False + else: + if not (self.scale == other.scale).all().item(): + return False + + return True + + @property + def sign(self): + if isinstance(self.scale, numbers.Real): + return 1 if float(self.scale) > 0 else -1 if float(self.scale) < 0 else 0 + return self.scale.sign() + + def _call(self, x): + return self.loc + self.scale * x + + def _inverse(self, y): + return (y - self.loc) / self.scale + + def log_abs_det_jacobian(self, x, y): + shape = x.shape + scale = self.scale + if isinstance(scale, numbers.Real): + result = ops.full_like(x, math.log(abs(scale))) + else: + result = ops.abs(scale).log() + if self.event_dim: + result_size = result.size()[: -self.event_dim] + (-1,) + result = result.view(result_size).sum(-1) + shape = shape[: -self.event_dim] + return result.expand(shape) + + def forward_shape(self, shape): + return ops.broadcast_shapes( + shape, getattr(self.loc, "shape", ()), getattr(self.scale, "shape", ()) + ) + + def inverse_shape(self, shape): + return ops.broadcast_shapes( + shape, getattr(self.loc, "shape", ()), getattr(self.scale, "shape", ()) + ) + + +class CorrCholeskyTransform(Transform): + r""" + Transforms an uncontrained real vector :math:`x` with length :math:`D*(D-1)/2` into the + Cholesky factor of a D-dimension correlation matrix. This Cholesky factor is a lower + triangular matrix with positive diagonals and unit Euclidean norm for each row. + The transform is processed as follows: + + 1. First we convert x into a lower triangular matrix in row order. + 2. For each row :math:`X_i` of the lower triangular part, we apply a *signed* version of + class :class:`StickBreakingTransform` to transform :math:`X_i` into a + unit Euclidean length vector using the following steps: + - Scales into the interval :math:`(-1, 1)` domain: :math:`r_i = \tanh(X_i)`. + - Transforms into an unsigned domain: :math:`z_i = r_i^2`. + - Applies :math:`s_i = StickBreakingTransform(z_i)`. + - Transforms back into signed domain: :math:`y_i = sign(r_i) * \sqrt{s_i}`. + """ + domain = constraints.real_vector + codomain = constraints.corr_cholesky + bijective = True + + def _call(self, x): + x = ops.tanh(x) + eps = ops.finfo(x.dtype).eps + x = x.clamp(min=-1 + eps, max=1 - eps) + r = vec_to_tril_matrix(x, diag=-1) + # apply stick-breaking on the squared values + # Note that y = sign(r) * sqrt(z * z1m_cumprod) + # = (sign(r) * sqrt(z)) * sqrt(z1m_cumprod) = r * sqrt(z1m_cumprod) + z = r**2 + z1m_cumprod_sqrt = (1 - z).sqrt().cumprod(-1) + # Diagonal elements must be 1. + r = r + ops.eye(r.shape[-1], dtype=r.dtype) + y = r * pad(z1m_cumprod_sqrt[..., :-1], [1, 0], value=1) + return y + + def _inverse(self, y): + # inverse stick-breaking + # See: https://mc-stan.org/docs/2_18/reference-manual/cholesky-factors-of-correlation-matrices-1.html + y_cumsum = 1 - ops.cumsum(y * y, dim=-1) + y_cumsum_shifted = pad(y_cumsum[..., :-1], [1, 0], value=1) + y_vec = tril_matrix_to_vec(y, diag=-1) + y_cumsum_vec = tril_matrix_to_vec(y_cumsum_shifted, diag=-1) + t = y_vec / (y_cumsum_vec).sqrt() + # inverse of tanh + x = (t.log1p() - t.neg().log1p()) / 2 + return x + + def log_abs_det_jacobian(self, x, y, intermediates=None): + # Because domain and codomain are two spaces with different dimensions, determinant of + # Jacobian is not well-defined. We return `log_abs_det_jacobian` of `x` and the + # flattened lower triangular part of `y`. + + # See: https://mc-stan.org/docs/2_18/reference-manual/cholesky-factors-of-correlation-matrices-1.html + y1m_cumsum = 1 - (y * y).cumsum(dim=-1) + # by taking diagonal=-2, we don't need to shift z_cumprod to the right + # also works for 2 x 2 matrix + y1m_cumsum_tril = tril_matrix_to_vec(y1m_cumsum, diag=-2) + stick_breaking_logdet = 0.5 * (y1m_cumsum_tril).log().sum(-1) + tanh_logdet = -2 * (x + softplus(-2 * x) - math.log(2.0)).sum(dim=-1) + return stick_breaking_logdet + tanh_logdet + + def forward_shape(self, shape): + # Reshape from (..., N) to (..., D, D). + if len(shape) < 1: + raise ValueError("Too few dimensions on input") + N = shape[-1] + D = round((0.25 + 2 * N) ** 0.5 + 0.5) + if D * (D - 1) // 2 != N: + raise ValueError("Input is not a flattend lower-diagonal number") + return shape[:-1] + (D, D) + + def inverse_shape(self, shape): + # Reshape from (..., D, D) to (..., N). + if len(shape) < 2: + raise ValueError("Too few dimensions on input") + if shape[-2] != shape[-1]: + raise ValueError("Input is not square") + D = shape[-1] + N = D * (D - 1) // 2 + return shape[:-2] + (N,) + + +class SoftmaxTransform(Transform): + r""" + Transform from unconstrained space to the simplex via :math:`y = \exp(x)` then + normalizing. + + This is not bijective and cannot be used for HMC. However this acts mostly + coordinate-wise (except for the final normalization), and thus is + appropriate for coordinate-wise optimization algorithms. + """ + domain = constraints.real_vector + codomain = constraints.simplex + + def __eq__(self, other): + return isinstance(other, SoftmaxTransform) + + def _call(self, x): + logprobs = x + probs = (logprobs - logprobs.max(-1, True)[0]).exp() + return probs / probs.sum(-1, True) + + def _inverse(self, y): + probs = y + return probs.log() + + def forward_shape(self, shape): + if len(shape) < 1: + raise ValueError("Too few dimensions on input") + return shape + + def inverse_shape(self, shape): + if len(shape) < 1: + raise ValueError("Too few dimensions on input") + return shape + + +class StickBreakingTransform(Transform): + """ + Transform from unconstrained space to the simplex of one additional + dimension via a stick-breaking process. + + This transform arises as an iterated sigmoid transform in a stick-breaking + construction of the `Dirichlet` distribution: the first logit is + transformed via sigmoid to the first probability and the probability of + everything else, and then the process recurses. + + This is bijective and appropriate for use in HMC; however it mixes + coordinates together and is less appropriate for optimization. + """ + + domain = constraints.real_vector + codomain = constraints.simplex + bijective = True + + def __eq__(self, other): + return isinstance(other, StickBreakingTransform) + + def _call(self, x): + offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1) + z = _clipped_sigmoid(x - offset.log()) + z_cumprod = (1 - z).cumprod(-1) + y = pad(z, [0, 1], value=1) * pad(z_cumprod, [1, 0], value=1) + return y + + def _inverse(self, y): + y_crop = y[..., :-1] + offset = y.shape[-1] - y.new_ones(y_crop.shape[-1]).cumsum(-1) + sf = 1 - y_crop.cumsum(-1) + # we clamp to make sure that sf is positive which sometimes does not + # happen when y[-1] ~ 0 or y[:-1].sum() ~ 1 + sf = ops.clamp(sf, min=ops.finfo(y.dtype).tiny) + x = y_crop.log() - sf.log() + offset.log() + return x + + def log_abs_det_jacobian(self, x, y): + offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1) + x = x - offset.log() + # use the identity 1 - sigmoid(x) = exp(-x) * sigmoid(x) + detJ = (-x + F.logsigmoid(x) + y[..., :-1].log()).sum(-1) + return detJ + + def forward_shape(self, shape): + if len(shape) < 1: + raise ValueError("Too few dimensions on input") + return shape[:-1] + (shape[-1] + 1,) + + def inverse_shape(self, shape): + if len(shape) < 1: + raise ValueError("Too few dimensions on input") + return shape[:-1] + (shape[-1] - 1,) + + +class LowerCholeskyTransform(Transform): + """ + Transform from unconstrained matrices to lower-triangular matrices with + nonnegative diagonal entries. + + This is useful for parameterizing positive definite matrices in terms of + their Cholesky factorization. + """ + + domain = constraints.independent(constraints.real, 2) + codomain = constraints.lower_cholesky + + def __eq__(self, other): + return isinstance(other, LowerCholeskyTransform) + + def _call(self, x): + return x.tril(-1) + x.diagonal(dim1=-2, dim2=-1).exp().diag_embed() + + def _inverse(self, y): + return y.tril(-1) + y.diagonal(dim1=-2, dim2=-1).log().diag_embed() + + +class PositiveDefiniteTransform(Transform): + """ + Transform from unconstrained matrices to positive-definite matrices. + """ + + domain = constraints.independent(constraints.real, 2) + codomain = constraints.positive_definite # type: ignore[assignment] + + def __eq__(self, other): + return isinstance(other, PositiveDefiniteTransform) + + def _call(self, x): + x = LowerCholeskyTransform()(x) + return x @ x.mT + + def _inverse(self, y): + y = ops.linalg.cholesky(y) + return LowerCholeskyTransform().inv(y) + + +class CatTransform(Transform): + """ + Transform functor that applies a sequence of transforms `tseq` + component-wise to each submatrix at `dim`, of length `lengths[dim]`, + in a way compatible with :func:`ops.cat`. + + Example:: + + x0 = ops.cat([ops.range(1, 10), ops.range(1, 10)], dim=0) + x = ops.cat([x0, x0], dim=0) + t0 = CatTransform([ExpTransform(), identity_transform], dim=0, lengths=[10, 10]) + t = CatTransform([t0, t0], dim=0, lengths=[20, 20]) + y = t(x) + """ + + transforms: List[Transform] + + def __init__(self, tseq, dim=0, lengths=None, cache_size=0): + assert all(isinstance(t, Transform) for t in tseq) + if cache_size: + tseq = [t.with_cache(cache_size) for t in tseq] + super().__init__(cache_size=cache_size) + self.transforms = list(tseq) + if lengths is None: + lengths = [1] * len(self.transforms) + self.lengths = list(lengths) + assert len(self.lengths) == len(self.transforms) + self.dim = dim + + @lazy_property + def event_dim(self): + return max(t.event_dim for t in self.transforms) + + @lazy_property + def length(self): + return sum(self.lengths) + + def with_cache(self, cache_size=1): + if self._cache_size == cache_size: + return self + return CatTransform(self.transforms, self.dim, self.lengths, cache_size) + + def _call(self, x): + assert -x.dim() <= self.dim < x.dim() + assert x.size(self.dim) == self.length + yslices = [] + start = 0 + for trans, length in zip(self.transforms, self.lengths): + xslice = x.narrow(self.dim, start, length) + yslices.append(trans(xslice)) + start = start + length # avoid += for jit compat + return ops.cat(yslices, dim=self.dim) + + def _inverse(self, y): + assert -y.dim() <= self.dim < y.dim() + assert y.size(self.dim) == self.length + xslices = [] + start = 0 + for trans, length in zip(self.transforms, self.lengths): + yslice = y.narrow(self.dim, start, length) + xslices.append(trans.inv(yslice)) + start = start + length # avoid += for jit compat + return ops.cat(xslices, dim=self.dim) + + def log_abs_det_jacobian(self, x, y): + assert -x.dim() <= self.dim < x.dim() + assert x.size(self.dim) == self.length + assert -y.dim() <= self.dim < y.dim() + assert y.size(self.dim) == self.length + logdetjacs = [] + start = 0 + for trans, length in zip(self.transforms, self.lengths): + xslice = x.narrow(self.dim, start, length) + yslice = y.narrow(self.dim, start, length) + logdetjac = trans.log_abs_det_jacobian(xslice, yslice) + if trans.event_dim < self.event_dim: + logdetjac = _sum_rightmost(logdetjac, self.event_dim - trans.event_dim) + logdetjacs.append(logdetjac) + start = start + length # avoid += for jit compat + # Decide whether to concatenate or sum. + dim = self.dim + if dim >= 0: + dim = dim - x.dim() + dim = dim + self.event_dim + if dim < 0: + return ops.cat(logdetjacs, dim=dim) + else: + return sum(logdetjacs) + + @property + def bijective(self): + return all(t.bijective for t in self.transforms) + + @constraints.dependent_property + def domain(self): + return constraints.cat( + [t.domain for t in self.transforms], self.dim, self.lengths + ) + + @constraints.dependent_property + def codomain(self): + return constraints.cat( + [t.codomain for t in self.transforms], self.dim, self.lengths + ) + + +class StackTransform(Transform): + """ + Transform functor that applies a sequence of transforms `tseq` + component-wise to each submatrix at `dim` + in a way compatible with :func:`ops.stack`. + + Example:: + + x = ops.stack([ops.range(1, 10), ops.range(1, 10)], dim=1) + t = StackTransform([ExpTransform(), identity_transform], dim=1) + y = t(x) + """ + + transforms: List[Transform] + + def __init__(self, tseq, dim=0, cache_size=0): + assert all(isinstance(t, Transform) for t in tseq) + if cache_size: + tseq = [t.with_cache(cache_size) for t in tseq] + super().__init__(cache_size=cache_size) + self.transforms = list(tseq) + self.dim = dim + + def with_cache(self, cache_size=1): + if self._cache_size == cache_size: + return self + return StackTransform(self.transforms, self.dim, cache_size) + + def _slice(self, z): + return [z.select(self.dim, i) for i in range(z.size(self.dim))] + + def _call(self, x): + assert -x.dim() <= self.dim < x.dim() + assert x.size(self.dim) == len(self.transforms) + yslices = [] + for xslice, trans in zip(self._slice(x), self.transforms): + yslices.append(trans(xslice)) + return ops.stack(yslices, dim=self.dim) + + def _inverse(self, y): + assert -y.dim() <= self.dim < y.dim() + assert y.size(self.dim) == len(self.transforms) + xslices = [] + for yslice, trans in zip(self._slice(y), self.transforms): + xslices.append(trans.inv(yslice)) + return ops.stack(xslices, dim=self.dim) + + def log_abs_det_jacobian(self, x, y): + assert -x.dim() <= self.dim < x.dim() + assert x.size(self.dim) == len(self.transforms) + assert -y.dim() <= self.dim < y.dim() + assert y.size(self.dim) == len(self.transforms) + logdetjacs = [] + yslices = self._slice(y) + xslices = self._slice(x) + for xslice, yslice, trans in zip(xslices, yslices, self.transforms): + logdetjacs.append(trans.log_abs_det_jacobian(xslice, yslice)) + return ops.stack(logdetjacs, dim=self.dim) + + @property + def bijective(self): + return all(t.bijective for t in self.transforms) + + @constraints.dependent_property + def domain(self): + return constraints.stack([t.domain for t in self.transforms], self.dim) + + @constraints.dependent_property + def codomain(self): + return constraints.stack([t.codomain for t in self.transforms], self.dim) + + +class CumulativeDistributionTransform(Transform): + """ + Transform via the cumulative distribution function of a probability distribution. + + Args: + distribution (Distribution): Distribution whose cumulative distribution function to use for + the transformation. + + Example:: + + # Construct a Gaussian copula from a multivariate normal. + base_dist = MultivariateNormal( + loc=ops.zeros(2), + scale_tril=LKJCholesky(2).sample(), + ) + transform = CumulativeDistributionTransform(Normal(0, 1)) + copula = TransformedDistribution(base_dist, [transform]) + """ + + bijective = True + codomain = constraints.unit_interval + sign = +1 + + def __init__(self, distribution, cache_size=0): + super().__init__(cache_size=cache_size) + self.distribution = distribution + + @property + def domain(self): + return self.distribution.support + + def _call(self, x): + return self.distribution.cdf(x) + + def _inverse(self, y): + return self.distribution.icdf(y) + + def log_abs_det_jacobian(self, x, y): + return self.distribution.log_prob(x) + + def with_cache(self, cache_size=1): + if self._cache_size == cache_size: + return self + return CumulativeDistributionTransform(self.distribution, cache_size=cache_size) diff --git a/mindnlp/core/distributions/utils.py b/mindnlp/core/distributions/utils.py new file mode 100644 index 000000000..596f1d68b --- /dev/null +++ b/mindnlp/core/distributions/utils.py @@ -0,0 +1,194 @@ +"""distribution utils""" +# mypy: allow-untyped-defs +from functools import update_wrapper +from numbers import Number +from typing import Any, Dict + +import mindspore +from .. import ops +from ..autograd import enable_grad +from .._bind import get_default_dtype +from ..nn import functional as F + + +euler_constant = 0.57721566490153286060 # Euler Mascheroni Constant + +__all__ = [ + "broadcast_all", + "logits_to_probs", + "clamp_probs", + "probs_to_logits", + "lazy_property", + "tril_matrix_to_vec", + "vec_to_tril_matrix", +] + + +def broadcast_all(*values): + r""" + Given a list of values (possibly containing numbers), returns a list where each + value is broadcasted based on the following rules: + - `core.*Tensor` instances are broadcasted as per :ref:`_broadcasting-semantics`. + - numbers.Number instances (scalars) are upcast to tensors having + the same size and type as the first tensor passed to `values`. If all the + values are scalars, then they are upcasted to scalar Tensors. + + Args: + values (list of `numbers.Number`, `core.*Tensor` or objects implementing __torch_function__) + + Raises: + ValueError: if any of the values is not a `numbers.Number` instance, + a `core.*Tensor` instance, or an instance implementing __torch_function__ + """ + if not all(isinstance(v, (mindspore.Tensor, Number)) for v in values): + raise ValueError( + "Input arguments must all be instances of numbers.Number, " + "mindspore.Tensor or objects implementing __torch_function__." + ) + if not all(isinstance(v, mindspore.Tensor) for v in values): + options: Dict[str, Any] = {"dtype": get_default_dtype()} + for value in values: + if isinstance(value, mindspore.Tensor): + options = {"dtype": value.dtype} + break + new_values = [ + v if isinstance(v, mindspore.Tensor) else mindspore.tensor(v, **options) for v in values + ] + return ops.broadcast_tensors(*new_values) + return ops.broadcast_tensors(*values) + + +def _standard_normal(shape, dtype): + return ops.normal(size = shape).to(dtype) + + +def _sum_rightmost(value, dim): + r""" + Sum out ``dim`` many rightmost dimensions of a given tensor. + + Args: + value (Tensor): A tensor of ``.dim()`` at least ``dim``. + dim (int): The number of rightmost dims to sum out. + """ + if dim == 0: + return value + required_shape = value.shape[:-dim] + (-1,) + return value.reshape(required_shape).sum(-1) + + +def logits_to_probs(logits, is_binary=False): + r""" + Converts a tensor of logits into probabilities. Note that for the + binary case, each value denotes log odds, whereas for the + multi-dimensional case, the values along the last dimension denote + the log probabilities (possibly unnormalized) of the events. + """ + if is_binary: + return ops.sigmoid(logits) + return F.softmax(logits, dim=-1) + + +def clamp_probs(probs): + """Clamps the probabilities to be in the open interval `(0, 1)`. + + The probabilities would be clamped between `eps` and `1 - eps`, + and `eps` would be the smallest representable positive number for the input data type. + + Args: + probs (Tensor): A tensor of probabilities. + + Returns: + Tensor: The clamped probabilities. + + Examples: + >>> probs = mindspore.tensor([0.0, 0.5, 1.0]) + >>> clamp_probs(probs) + tensor([1.1921e-07, 5.0000e-01, 1.0000e+00]) + + >>> probs = mindspore.tensor([0.0, 0.5, 1.0], dtype=mindspore.float64) + >>> clamp_probs(probs) + tensor([2.2204e-16, 5.0000e-01, 1.0000e+00], dtype=mindspore.float64) + + """ + eps = ops.finfo(probs.dtype).eps + return probs.clamp(min=eps, max=1 - eps) + + +def probs_to_logits(probs, is_binary=False): + r""" + Converts a tensor of probabilities into logits. For the binary case, + this denotes the probability of occurrence of the event indexed by `1`. + For the multi-dimensional case, the values along the last dimension + denote the probabilities of occurrence of each of the events. + """ + ps_clamped = clamp_probs(probs) + if is_binary: + return ops.log(ps_clamped) - ops.log1p(-ps_clamped) + return ops.log(ps_clamped) + + +class lazy_property: + r""" + Used as a decorator for lazy loading of class attributes. This uses a + non-data descriptor that calls the wrapped method to compute the property on + first call; thereafter replacing the wrapped method into an instance + attribute. + """ + + def __init__(self, wrapped): + self.wrapped = wrapped + update_wrapper(self, wrapped) # type:ignore[arg-type] + + def __get__(self, instance, obj_type=None): + if instance is None: + return _lazy_property_and_property(self.wrapped) + with enable_grad(): + value = self.wrapped(instance) + setattr(instance, self.wrapped.__name__, value) + return value + + +class _lazy_property_and_property(lazy_property, property): + """We want lazy properties to look like multiple things. + + * property when Sphinx autodoc looks + * lazy_property when Distribution validate_args looks + """ + + +def tril_matrix_to_vec(mat: mindspore.Tensor, diag: int = 0) -> mindspore.Tensor: + r""" + Convert a `D x D` matrix or a batch of matrices into a (batched) vector + which comprises of lower triangular elements from the matrix in row order. + """ + n = mat.shape[-1] + # if not core._C._get_tracing_state() and (diag < -n or diag >= n): + # raise ValueError(f"diag ({diag}) provided is outside [{-n}, {n-1}].") + arange = ops.arange(n) + tril_mask = arange < arange.view(-1, 1) + (diag + 1) + vec = mat[..., tril_mask] + return vec + + +def vec_to_tril_matrix(vec: mindspore.Tensor, diag: int = 0) -> mindspore.Tensor: + r""" + Convert a vector or a batch of vectors into a batched `D x D` + lower triangular matrix containing elements from the vector in row order. + """ + # +ve root of D**2 + (1+2*diag)*D - |diag| * (diag+1) - 2*vec.shape[-1] = 0 + n = ( + -(1 + 2 * diag) + + ((1 + 2 * diag) ** 2 + 8 * vec.shape[-1] + 4 * abs(diag) * (diag + 1)) ** 0.5 + ) / 2 + eps = ops.finfo(vec.dtype).eps + # if not core._C._get_tracing_state() and (round(n) - n > eps): + # raise ValueError( + # f"The size of last dimension is {vec.shape[-1]} which cannot be expressed as " + # + "the lower triangular part of a square D x D matrix." + # ) + n = round(n.item()) if isinstance(n, mindspore.Tensor) else round(n) + mat = vec.new_zeros(vec.shape[:-1] + (n, n)) + arange = ops.arange(n) + tril_mask = arange < arange.view(-1, 1) + (diag + 1) + mat[..., tril_mask] = vec + return mat diff --git a/mindnlp/core/executor.py b/mindnlp/core/executor.py new file mode 100644 index 000000000..c353c69fd --- /dev/null +++ b/mindnlp/core/executor.py @@ -0,0 +1,41 @@ +import mindspore +from mindspore._c_expression import TensorNode, SequenceNode, NoneTypeNode, AnyTypeNode, Tensor as MSTensor +import mindspore.common._stub_tensor +from mindspore.common.api import _pynative_executor +from mindspore.common._stub_tensor import _convert_python_data + +from mindnlp import core +from ._tensor import Tensor +from .dispatcher import dispatcher + +def _convert_stub(stub, device): + "convert stub to StubNode or Value" + if isinstance(stub, (MSTensor, TensorNode)): + return Tensor(stub, device=device) + if isinstance(stub, tuple): + return tuple(_convert_stub(e, device) for e in stub) + if isinstance(stub, SequenceNode): + elements = stub.get_elements() + return tuple(_convert_stub(e, device) for e in elements) + if isinstance(stub, NoneTypeNode): + val = stub.get_real_value() + return _convert_python_data(val) + if isinstance(stub, AnyTypeNode): + val = stub.get_real_node() + return _convert_stub(val, device) + return _convert_python_data(stub) + + +def execute(func_name, *args, **kwargs): + requires_grad = kwargs.pop('requires_grad', False) + user_created = kwargs.pop('user_created', False) + out, device = dispatcher.dispatch(func_name, *args, **kwargs) + out_tensor = _convert_stub(out, device=device) + if requires_grad: + out_tensor._requires_grad = True + if user_created: + out_tensor._user_created = True + out_tensor.attach_grad() + + return out_tensor + diff --git a/mindnlp/core/export/__init__.py b/mindnlp/core/export/__init__.py new file mode 100644 index 000000000..223cbcf8b --- /dev/null +++ b/mindnlp/core/export/__init__.py @@ -0,0 +1 @@ +ExportedProgram = None diff --git a/mindnlp/core/fx/__init__.py b/mindnlp/core/fx/__init__.py new file mode 100644 index 000000000..ce027f4c0 --- /dev/null +++ b/mindnlp/core/fx/__init__.py @@ -0,0 +1,11 @@ +from ._symbolic_trace import ( # noqa: F401 + # PH, + # ProxyableClassMeta, + # symbolic_trace, + # Tracer, + wrap, +) + +from .proxy import Proxy + +from . import _pytree diff --git a/mindnlp/core/fx/_pytree.py b/mindnlp/core/fx/_pytree.py new file mode 100644 index 000000000..383bbed12 --- /dev/null +++ b/mindnlp/core/fx/_pytree.py @@ -0,0 +1,113 @@ +from collections import namedtuple +from typing import Any, Callable, Optional, TypeVar +from typing_extensions import NamedTuple + +from mindnlp import core +from mindnlp.core.utils._pytree import PyTree, tree_flatten, TreeSpec + + +FlattenFuncSpec = Callable[[PyTree, TreeSpec], list] +FlattenFuncExactMatchSpec = Callable[[PyTree, TreeSpec], bool] + +SUPPORTED_NODES: dict[type[Any], FlattenFuncSpec] = {} +SUPPORTED_NODES_EXACT_MATCH: dict[type[Any], Optional[FlattenFuncExactMatchSpec]] = {} + +_T = TypeVar("_T") +_K = TypeVar("_K") +_V = TypeVar("_V") + + +def register_pytree_flatten_spec( + cls: type[Any], + flatten_fn_spec: FlattenFuncSpec, + flatten_fn_exact_match_spec: Optional[FlattenFuncExactMatchSpec] = None, +) -> None: + SUPPORTED_NODES[cls] = flatten_fn_spec + SUPPORTED_NODES_EXACT_MATCH[cls] = flatten_fn_exact_match_spec + + +def _deregister_pytree_flatten_spec( + cls: type[Any], +) -> None: + del SUPPORTED_NODES[cls] + del SUPPORTED_NODES_EXACT_MATCH[cls] + + +def tree_flatten_spec( + pytree: PyTree, + spec: TreeSpec, +) -> list[Any]: + if spec.is_leaf(): + return [pytree] + # I guess these exist for BC, FC reasons. + # In general, we should be able to directly + # use pytree tree flattener to flatten them, + # as export serializes the pytree seperately. + # Will remove it in follow up PR. + if spec.type in SUPPORTED_NODES: + flatten_fn_spec = SUPPORTED_NODES[spec.type] + child_pytrees = flatten_fn_spec(pytree, spec) + result = [] + for child, child_spec in zip(child_pytrees, spec.children_specs): + flat = tree_flatten_spec(child, child_spec) + result += flat + return result + flat_result, real_spec = tree_flatten(pytree) + if spec != real_spec: + raise RuntimeError( + f"Real spec {real_spec} of object {pytree} is different from expected spec {spec}. " + f"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml" + ) + return flat_result + + +def _dict_flatten_spec(d: dict[_K, _V], spec: TreeSpec) -> list[_V]: + return [d[k] for k in spec.context] + + +def _list_flatten_spec(d: list[_T], spec: TreeSpec) -> list[_T]: + return [d[i] for i in range(spec.num_children)] + + +def _tuple_flatten_spec(d: tuple[_T, ...], spec: TreeSpec) -> list[_T]: + return [d[i] for i in range(spec.num_children)] + + +def _namedtuple_flatten_spec(d: NamedTuple, spec: TreeSpec) -> list[Any]: + return [d[i] for i in range(spec.num_children)] + + +def _dict_flatten_spec_exact_match(d: dict[_K, _V], spec: TreeSpec) -> bool: + return len(d) == spec.num_children + + +def _list_flatten_spec_exact_match(d: list[_T], spec: TreeSpec) -> bool: + return len(d) == spec.num_children + + +def _tuple_flatten_spec_exact_match(d: tuple[_T, ...], spec: TreeSpec) -> bool: + return len(d) == spec.num_children + + +def _namedtuple_flatten_spec_exact_match(d: NamedTuple, spec: TreeSpec) -> bool: + return len(d) == spec.num_children + + +register_pytree_flatten_spec(dict, _dict_flatten_spec, _dict_flatten_spec_exact_match) +register_pytree_flatten_spec(list, _list_flatten_spec, _list_flatten_spec_exact_match) +register_pytree_flatten_spec( + tuple, + _tuple_flatten_spec, + _tuple_flatten_spec_exact_match, +) +for return_type in core.return_types.all_return_types: + register_pytree_flatten_spec( + return_type, + _tuple_flatten_spec, + _tuple_flatten_spec_exact_match, + ) +register_pytree_flatten_spec( + namedtuple, # type: ignore[arg-type] + _namedtuple_flatten_spec, + _namedtuple_flatten_spec_exact_match, +) \ No newline at end of file diff --git a/mindnlp/core/fx/_symbolic_trace.py b/mindnlp/core/fx/_symbolic_trace.py new file mode 100644 index 000000000..2e9d28668 --- /dev/null +++ b/mindnlp/core/fx/_symbolic_trace.py @@ -0,0 +1,76 @@ +import builtins +import collections +import contextlib +import copy +import functools +import inspect +import math +import os +import warnings +from itertools import chain +from types import CodeType, FunctionType, ModuleType +from typing import Any, Callable, NamedTuple, Optional, Union + +_wrapped_fns_to_patch: dict[tuple[int, str], dict] = {} + +def wrap(fn_or_name: Union[str, Callable]): + """ + This function can be called at module-level scope to register fn_or_name as a "leaf function". + A "leaf function" will be preserved as a CallFunction node in the FX trace instead of being + traced through:: + + # foo/bar/baz.py + def my_custom_function(x, y): + return x * x + y * y + + + torch.fx.wrap("my_custom_function") + + + def fn_to_be_traced(x, y): + # When symbolic tracing, the below call to my_custom_function will be inserted into + # the graph rather than tracing it. + return my_custom_function(x, y) + + This function can also equivalently be used as a decorator:: + + # foo/bar/baz.py + @torch.fx.wrap + def my_custom_function(x, y): + return x * x + y * y + + A wrapped function can be thought of a "leaf function", analogous to the concept of + "leaf modules", that is, they are functions that are left as calls in the FX trace + rather than traced through. + + Args: + + fn_or_name (Union[str, Callable]): The function or name of the global function to insert into the + graph when it's called + """ + if not callable(fn_or_name) and not isinstance(fn_or_name, str): + raise RuntimeError( + "Unsupported type for global function! Must be either a callable or " + "string name" + ) + + if callable(fn_or_name): + assert not isinstance(fn_or_name, str) # to make mypy happy + fn_name = fn_or_name.__name__ + else: + assert isinstance( + fn_or_name, str + ), "fn_or_name must be a global function or string name" + fn_name = fn_or_name + + currentframe = inspect.currentframe() + assert currentframe is not None + f = currentframe.f_back + assert f is not None + if f.f_code.co_name != "": + raise NotImplementedError("wrap must be called at the top level of a module") + + # consider implementing Callable version of this via _autowrap_function_ids / _autowrap_search + # semantics would be slightly different, but would add support `from x import wrapped_function` + _wrapped_fns_to_patch[(id(f.f_globals), fn_name)] = f.f_globals + return fn_or_name diff --git a/mindnlp/core/fx/proxy.py b/mindnlp/core/fx/proxy.py new file mode 100644 index 000000000..46b350813 --- /dev/null +++ b/mindnlp/core/fx/proxy.py @@ -0,0 +1,2 @@ +class Proxy: + pass diff --git a/mindnlp/core/hub.py b/mindnlp/core/hub.py new file mode 100644 index 000000000..743d8f5e7 --- /dev/null +++ b/mindnlp/core/hub.py @@ -0,0 +1,880 @@ +"""torch""" +# mypy: allow-untyped-defs +import contextlib +import errno +import hashlib +import json +import os +import re +import shutil +import sys +import tempfile +import uuid +import warnings +import zipfile +from pathlib import Path +from typing import Any, Dict, Optional +from urllib.error import HTTPError, URLError +from urllib.parse import urlparse # noqa: F401 +from urllib.request import Request, urlopen +from typing_extensions import deprecated +from .serialization import load + +class _Faketqdm: # type: ignore[no-redef] + def __init__(self, total=None, disable=False, **kwargs): + self.total = total + self.disable = disable + self.n = 0 + # Ignore all extra *args and **kwargs lest you want to reinvent tqdm + + def update(self, n): + if self.disable: + return + + self.n += n + if self.total is None: + sys.stderr.write(f"\r{self.n:.1f} bytes") + else: + sys.stderr.write(f"\r{100 * self.n / float(self.total):.1f}%") + sys.stderr.flush() + + # Don't bother implementing; use real tqdm if you want + def set_description(self, *args, **kwargs): + pass + + def write(self, s): + sys.stderr.write(f"{s}\n") + + def close(self): + self.disable = True + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.disable: + return + + sys.stderr.write("\n") + + +try: + from tqdm import tqdm # If tqdm is installed use it, otherwise use the fake wrapper +except ImportError: + tqdm = _Faketqdm + +__all__ = [ + "download_url_to_file", + "get_dir", + "help", + "list", + "load", + "load_state_dict_from_url", + "set_dir", +] + +# matches bfd8deac from resnet18-bfd8deac.pth +HASH_REGEX = re.compile(r"-([a-f0-9]*)\.") + +_TRUSTED_REPO_OWNERS = ( + "facebookresearch", + "facebookincubator", + "pytorch", + "fairinternal", +) +ENV_GITHUB_TOKEN = "GITHUB_TOKEN" +ENV_TORCH_HOME = "TORCH_HOME" +ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME" +DEFAULT_CACHE_DIR = "~/.cache" +VAR_DEPENDENCY = "dependencies" +MODULE_HUBCONF = "hubconf.py" +READ_DATA_CHUNK = 128 * 1024 +_hub_dir: Optional[str] = None + + +@contextlib.contextmanager +def _add_to_sys_path(path): + sys.path.insert(0, path) + try: + yield + finally: + sys.path.remove(path) + + +# Copied from tools/shared/module_loader to be included in torch package +def _import_module(name, path): + import importlib.util + from importlib.abc import Loader + + spec = importlib.util.spec_from_file_location(name, path) + assert spec is not None + module = importlib.util.module_from_spec(spec) + assert isinstance(spec.loader, Loader) + spec.loader.exec_module(module) + return module + + +def _remove_if_exists(path): + if os.path.exists(path): + if os.path.isfile(path): + os.remove(path) + else: + shutil.rmtree(path) + + +def _git_archive_link(repo_owner, repo_name, ref): + # See https://docs.github.com/en/rest/reference/repos#download-a-repository-archive-zip + return f"https://github.com/{repo_owner}/{repo_name}/zipball/{ref}" + + +def _load_attr_from_module(module, func_name): + # Check if callable is defined in the module + if func_name not in dir(module): + return None + return getattr(module, func_name) + + +def _get_torch_home(): + torch_home = os.path.expanduser( + os.getenv( + ENV_TORCH_HOME, + os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "torch"), + ) + ) + return torch_home + + +def _parse_repo_info(github): + if ":" in github: + repo_info, ref = github.split(":") + else: + repo_info, ref = github, None + repo_owner, repo_name = repo_info.split("/") + + if ref is None: + # The ref wasn't specified by the user, so we need to figure out the + # default branch: main or master. Our assumption is that if main exists + # then it's the default branch, otherwise it's master. + try: + with urlopen(f"https://github.com/{repo_owner}/{repo_name}/tree/main/"): + ref = "main" + except HTTPError as e: + if e.code == 404: + ref = "master" + else: + raise + except URLError as e: + # No internet connection, need to check for cache as last resort + for possible_ref in ("main", "master"): + if os.path.exists( + f"{get_dir()}/{repo_owner}_{repo_name}_{possible_ref}" + ): + ref = possible_ref + break + if ref is None: + raise RuntimeError( + "It looks like there is no internet connection and the " + f"repo could not be found in the cache ({get_dir()})" + ) from e + return repo_owner, repo_name, ref + + +def _read_url(url): + with urlopen(url) as r: + return r.read().decode(r.headers.get_content_charset("utf-8")) + + +def _validate_not_a_forked_repo(repo_owner, repo_name, ref): + # Use urlopen to avoid depending on local git. + headers = {"Accept": "application/vnd.github.v3+json"} + token = os.environ.get(ENV_GITHUB_TOKEN) + if token is not None: + headers["Authorization"] = f"token {token}" + for url_prefix in ( + f"https://api.github.com/repos/{repo_owner}/{repo_name}/branches", + f"https://api.github.com/repos/{repo_owner}/{repo_name}/tags", + ): + page = 0 + while True: + page += 1 + url = f"{url_prefix}?per_page=100&page={page}" + response = json.loads(_read_url(Request(url, headers=headers))) + # Empty response means no more data to process + if not response: + break + for br in response: + if br["name"] == ref or br["commit"]["sha"].startswith(ref): + return + + raise ValueError( + f"Cannot find {ref} in https://github.com/{repo_owner}/{repo_name}. " + "If it's a commit from a forked repo, please call hub.load() with forked repo directly." + ) + + +def _get_cache_or_reload( + github, + force_reload, + trust_repo, + calling_fn, + verbose=True, + skip_validation=False, +): + # Setup hub_dir to save downloaded files + hub_dir = get_dir() + os.makedirs(hub_dir, exist_ok=True) + # Parse github repo information + repo_owner, repo_name, ref = _parse_repo_info(github) + # Github allows branch name with slash '/', + # this causes confusion with path on both Linux and Windows. + # Backslash is not allowed in Github branch name so no need to + # to worry about it. + normalized_br = ref.replace("/", "_") + # Github renames folder repo-v1.x.x to repo-1.x.x + # We don't know the repo name before downloading the zip file + # and inspect name from it. + # To check if cached repo exists, we need to normalize folder names. + owner_name_branch = "_".join([repo_owner, repo_name, normalized_br]) + repo_dir = os.path.join(hub_dir, owner_name_branch) + # Check that the repo is in the trusted list + _check_repo_is_trusted( + repo_owner, + repo_name, + owner_name_branch, + trust_repo=trust_repo, + calling_fn=calling_fn, + ) + + use_cache = (not force_reload) and os.path.exists(repo_dir) + + if use_cache: + if verbose: + sys.stderr.write(f"Using cache found in {repo_dir}\n") + else: + # Validate the tag/branch is from the original repo instead of a forked repo + if not skip_validation: + _validate_not_a_forked_repo(repo_owner, repo_name, ref) + + cached_file = os.path.join(hub_dir, normalized_br + ".zip") + _remove_if_exists(cached_file) + + try: + url = _git_archive_link(repo_owner, repo_name, ref) + sys.stderr.write(f'Downloading: "{url}" to {cached_file}\n') + download_url_to_file(url, cached_file, progress=False) + except HTTPError as err: + if err.code == 300: + # Getting a 300 Multiple Choices error likely means that the ref is both a tag and a branch + # in the repo. This can be disambiguated by explicitely using refs/heads/ or refs/tags + # See https://git-scm.com/book/en/v2/Git-Internals-Git-References + # Here, we do the same as git: we throw a warning, and assume the user wanted the branch + warnings.warn( + f"The ref {ref} is ambiguous. Perhaps it is both a tag and a branch in the repo? " + "Torchhub will now assume that it's a branch. " + "You can disambiguate tags and branches by explicitly passing refs/heads/branch_name or " + "refs/tags/tag_name as the ref. That might require using skip_validation=True." + ) + disambiguated_branch_ref = f"refs/heads/{ref}" + url = _git_archive_link( + repo_owner, repo_name, ref=disambiguated_branch_ref + ) + download_url_to_file(url, cached_file, progress=False) + else: + raise + + with zipfile.ZipFile(cached_file) as cached_zipfile: + extraced_repo_name = cached_zipfile.infolist()[0].filename + extracted_repo = os.path.join(hub_dir, extraced_repo_name) + _remove_if_exists(extracted_repo) + # Unzip the code and rename the base folder + cached_zipfile.extractall(hub_dir) + + _remove_if_exists(cached_file) + _remove_if_exists(repo_dir) + shutil.move(extracted_repo, repo_dir) # rename the repo + + return repo_dir + + +def _check_repo_is_trusted( + repo_owner, + repo_name, + owner_name_branch, + trust_repo, + calling_fn="load", +): + hub_dir = get_dir() + filepath = os.path.join(hub_dir, "trusted_list") + + if not os.path.exists(filepath): + Path(filepath).touch() + with open(filepath) as file: + trusted_repos = tuple(line.strip() for line in file) + + # To minimize friction of introducing the new trust_repo mechanism, we consider that + # if a repo was already downloaded by torchhub, then it is already trusted (even if it's not in the allowlist) + trusted_repos_legacy = next(os.walk(hub_dir))[1] + + owner_name = "_".join([repo_owner, repo_name]) + is_trusted = ( + owner_name in trusted_repos + or owner_name_branch in trusted_repos_legacy + or repo_owner in _TRUSTED_REPO_OWNERS + ) + + # TODO: Remove `None` option in 2.0 and change the default to "check" + if trust_repo is None: + if not is_trusted: + warnings.warn( + "You are about to download and run code from an untrusted repository. In a future release, this won't " + "be allowed. To add the repository to your trusted list, change the command to {calling_fn}(..., " + "trust_repo=False) and a command prompt will appear asking for an explicit confirmation of trust, " + f"or {calling_fn}(..., trust_repo=True), which will assume that the prompt is to be answered with " + f"'yes'. You can also use {calling_fn}(..., trust_repo='check') which will only prompt for " + f"confirmation if the repo is not already trusted. This will eventually be the default behaviour" + ) + return + + if (trust_repo is False) or (trust_repo == "check" and not is_trusted): + response = input( + f"The repository {owner_name} does not belong to the list of trusted repositories and as such cannot be downloaded. " + "Do you trust this repository and wish to add it to the trusted list of repositories (y/N)?" + ) + if response.lower() in ("y", "yes"): + if is_trusted: + print("The repository is already trusted.") + elif response.lower() in ("n", "no", ""): + raise Exception("Untrusted repository.") # noqa: TRY002 + else: + raise ValueError(f"Unrecognized response {response}.") + + # At this point we're sure that the user trusts the repo (or wants to trust it) + if not is_trusted: + with open(filepath, "a") as file: + file.write(owner_name + "\n") + + +def _check_module_exists(name): + import importlib.util + + return importlib.util.find_spec(name) is not None + + +def _check_dependencies(m): + dependencies = _load_attr_from_module(m, VAR_DEPENDENCY) + + if dependencies is not None: + missing_deps = [pkg for pkg in dependencies if not _check_module_exists(pkg)] + if len(missing_deps): + raise RuntimeError(f"Missing dependencies: {', '.join(missing_deps)}") + + +def _load_entry_from_hubconf(m, model): + if not isinstance(model, str): + raise ValueError("Invalid input: model should be a string of function name") + + # Note that if a missing dependency is imported at top level of hubconf, it will + # throw before this function. It's a chicken and egg situation where we have to + # load hubconf to know what're the dependencies, but to import hubconf it requires + # a missing package. This is fine, Python will throw proper error message for users. + _check_dependencies(m) + + func = _load_attr_from_module(m, model) + + if func is None or not callable(func): + raise RuntimeError(f"Cannot find callable {model} in hubconf") + + return func + + + +def get_dir(): + r""" + Get the Torch Hub cache directory used for storing downloaded models & weights. + + If :func:`~core.hub.set_dir` is not called, default path is ``$TORCH_HOME/hub`` where + environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``. + ``$XDG_CACHE_HOME`` follows the X Design Group specification of the Linux + filesystem layout, with a default value ``~/.cache`` if the environment + variable is not set. + """ + # Issue warning to move data if old env is set + if os.getenv("TORCH_HUB"): + warnings.warn("TORCH_HUB is deprecated, please use env TORCH_HOME instead") + + if _hub_dir is not None: + return _hub_dir + return os.path.join(_get_torch_home(), "hub") + + + + +def set_dir(d): + r""" + Optionally set the Torch Hub directory used to save downloaded models & weights. + + Args: + d (str): path to a local folder to save downloaded models & weights. + """ + global _hub_dir + _hub_dir = os.path.expanduser(d) + + + + +def list( + github, + force_reload=False, + skip_validation=False, + trust_repo=None, + verbose=True, +): + r""" + List all callable entrypoints available in the repo specified by ``github``. + + Args: + github (str): a string with format "repo_owner/repo_name[:ref]" with an optional + ref (tag or branch). If ``ref`` is not specified, the default branch is assumed to be ``main`` if + it exists, and otherwise ``master``. + Example: 'pytorch/vision:0.10' + force_reload (bool, optional): whether to discard the existing cache and force a fresh download. + Default is ``False``. + skip_validation (bool, optional): if ``False``, torchhub will check that the branch or commit + specified by the ``github`` argument properly belongs to the repo owner. This will make + requests to the GitHub API; you can specify a non-default GitHub token by setting the + ``GITHUB_TOKEN`` environment variable. Default is ``False``. + trust_repo (bool, str or None): ``"check"``, ``True``, ``False`` or ``None``. + This parameter was introduced in v1.12 and helps ensuring that users + only run code from repos that they trust. + + - If ``False``, a prompt will ask the user whether the repo should + be trusted. + - If ``True``, the repo will be added to the trusted list and loaded + without requiring explicit confirmation. + - If ``"check"``, the repo will be checked against the list of + trusted repos in the cache. If it is not present in that list, the + behaviour will fall back onto the ``trust_repo=False`` option. + - If ``None``: this will raise a warning, inviting the user to set + ``trust_repo`` to either ``False``, ``True`` or ``"check"``. This + is only present for backward compatibility and will be removed in + v2.0. + + Default is ``None`` and will eventually change to ``"check"`` in v2.0. + verbose (bool, optional): If ``False``, mute messages about hitting + local caches. Note that the message about first download cannot be + muted. Default is ``True``. + + Returns: + list: The available callables entrypoint + + Example: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB) + >>> entrypoints = core.hub.list("pytorch/vision", force_reload=True) + """ + repo_dir = _get_cache_or_reload( + github, + force_reload, + trust_repo, + "list", + verbose=verbose, + skip_validation=skip_validation, + ) + + with _add_to_sys_path(repo_dir): + hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF) + hub_module = _import_module(MODULE_HUBCONF, hubconf_path) + + # We take functions starts with '_' as internal helper functions + entrypoints = [ + f + for f in dir(hub_module) + if callable(getattr(hub_module, f)) and not f.startswith("_") + ] + + return entrypoints + + + + +def help(github, model, force_reload=False, skip_validation=False, trust_repo=None): + r""" + Show the docstring of entrypoint ``model``. + + Args: + github (str): a string with format with an optional + ref (a tag or a branch). If ``ref`` is not specified, the default branch is assumed + to be ``main`` if it exists, and otherwise ``master``. + Example: 'pytorch/vision:0.10' + model (str): a string of entrypoint name defined in repo's ``hubconf.py`` + force_reload (bool, optional): whether to discard the existing cache and force a fresh download. + Default is ``False``. + skip_validation (bool, optional): if ``False``, torchhub will check that the ref + specified by the ``github`` argument properly belongs to the repo owner. This will make + requests to the GitHub API; you can specify a non-default GitHub token by setting the + ``GITHUB_TOKEN`` environment variable. Default is ``False``. + trust_repo (bool, str or None): ``"check"``, ``True``, ``False`` or ``None``. + This parameter was introduced in v1.12 and helps ensuring that users + only run code from repos that they trust. + + - If ``False``, a prompt will ask the user whether the repo should + be trusted. + - If ``True``, the repo will be added to the trusted list and loaded + without requiring explicit confirmation. + - If ``"check"``, the repo will be checked against the list of + trusted repos in the cache. If it is not present in that list, the + behaviour will fall back onto the ``trust_repo=False`` option. + - If ``None``: this will raise a warning, inviting the user to set + ``trust_repo`` to either ``False``, ``True`` or ``"check"``. This + is only present for backward compatibility and will be removed in + v2.0. + + Default is ``None`` and will eventually change to ``"check"`` in v2.0. + Example: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB) + >>> print(core.hub.help("pytorch/vision", "resnet18", force_reload=True)) + """ + repo_dir = _get_cache_or_reload( + github, + force_reload, + trust_repo, + "help", + verbose=True, + skip_validation=skip_validation, + ) + + with _add_to_sys_path(repo_dir): + hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF) + hub_module = _import_module(MODULE_HUBCONF, hubconf_path) + + entry = _load_entry_from_hubconf(hub_module, model) + + return entry.__doc__ + + + + +def load( + repo_or_dir, + model, + *args, + source="github", + trust_repo=None, + force_reload=False, + verbose=True, + skip_validation=False, + **kwargs, +): + r""" + Load a model from a github repo or a local directory. + + Note: Loading a model is the typical use case, but this can also be used to + for loading other objects such as tokenizers, loss functions, etc. + + If ``source`` is 'github', ``repo_or_dir`` is expected to be + of the form ``repo_owner/repo_name[:ref]`` with an optional + ref (a tag or a branch). + + If ``source`` is 'local', ``repo_or_dir`` is expected to be a + path to a local directory. + + Args: + repo_or_dir (str): If ``source`` is 'github', + this should correspond to a github repo with format ``repo_owner/repo_name[:ref]`` with + an optional ref (tag or branch), for example 'pytorch/vision:0.10'. If ``ref`` is not specified, + the default branch is assumed to be ``main`` if it exists, and otherwise ``master``. + If ``source`` is 'local' then it should be a path to a local directory. + model (str): the name of a callable (entrypoint) defined in the + repo/dir's ``hubconf.py``. + *args (optional): the corresponding args for callable ``model``. + source (str, optional): 'github' or 'local'. Specifies how + ``repo_or_dir`` is to be interpreted. Default is 'github'. + trust_repo (bool, str or None): ``"check"``, ``True``, ``False`` or ``None``. + This parameter was introduced in v1.12 and helps ensuring that users + only run code from repos that they trust. + + - If ``False``, a prompt will ask the user whether the repo should + be trusted. + - If ``True``, the repo will be added to the trusted list and loaded + without requiring explicit confirmation. + - If ``"check"``, the repo will be checked against the list of + trusted repos in the cache. If it is not present in that list, the + behaviour will fall back onto the ``trust_repo=False`` option. + - If ``None``: this will raise a warning, inviting the user to set + ``trust_repo`` to either ``False``, ``True`` or ``"check"``. This + is only present for backward compatibility and will be removed in + v2.0. + + Default is ``None`` and will eventually change to ``"check"`` in v2.0. + force_reload (bool, optional): whether to force a fresh download of + the github repo unconditionally. Does not have any effect if + ``source = 'local'``. Default is ``False``. + verbose (bool, optional): If ``False``, mute messages about hitting + local caches. Note that the message about first download cannot be + muted. Does not have any effect if ``source = 'local'``. + Default is ``True``. + skip_validation (bool, optional): if ``False``, torchhub will check that the branch or commit + specified by the ``github`` argument properly belongs to the repo owner. This will make + requests to the GitHub API; you can specify a non-default GitHub token by setting the + ``GITHUB_TOKEN`` environment variable. Default is ``False``. + **kwargs (optional): the corresponding kwargs for callable ``model``. + + Returns: + The output of the ``model`` callable when called with the given + ``*args`` and ``**kwargs``. + + Example: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB) + >>> # from a github repo + >>> repo = "pytorch/vision" + >>> model = core.hub.load( + ... repo, "resnet50", weights="ResNet50_Weights.IMAGENET1K_V1" + ... ) + >>> # from a local directory + >>> path = "/some/local/path/pytorch/vision" + >>> # xdoctest: +SKIP + >>> model = core.hub.load(path, "resnet50", weights="ResNet50_Weights.DEFAULT") + """ + source = source.lower() + + if source not in ("github", "local"): + raise ValueError( + f'Unknown source: "{source}". Allowed values: "github" | "local".' + ) + + if source == "github": + repo_or_dir = _get_cache_or_reload( + repo_or_dir, + force_reload, + trust_repo, + "load", + verbose=verbose, + skip_validation=skip_validation, + ) + + model = _load_local(repo_or_dir, model, *args, **kwargs) + return model + + + +def _load_local(hubconf_dir, model, *args, **kwargs): + r""" + Load a model from a local directory with a ``hubconf.py``. + + Args: + hubconf_dir (str): path to a local directory that contains a + ``hubconf.py``. + model (str): name of an entrypoint defined in the directory's + ``hubconf.py``. + *args (optional): the corresponding args for callable ``model``. + **kwargs (optional): the corresponding kwargs for callable ``model``. + + Returns: + a single model with corresponding pretrained weights. + + Example: + >>> # xdoctest: +SKIP("stub local path") + >>> path = "/some/local/path/pytorch/vision" + >>> model = _load_local(path, "resnet50", weights="ResNet50_Weights.IMAGENET1K_V1") + """ + with _add_to_sys_path(hubconf_dir): + hubconf_path = os.path.join(hubconf_dir, MODULE_HUBCONF) + hub_module = _import_module(MODULE_HUBCONF, hubconf_path) + + entry = _load_entry_from_hubconf(hub_module, model) + model = entry(*args, **kwargs) + + return model + + + +def download_url_to_file( + url: str, + dst: str, + hash_prefix: Optional[str] = None, + progress: bool = True, +) -> None: + r"""Download object at the given URL to a local path. + + Args: + url (str): URL of the object to download + dst (str): Full path where object will be saved, e.g. ``/tmp/temporary_file`` + hash_prefix (str, optional): If not None, the SHA256 downloaded file should start with ``hash_prefix``. + Default: None + progress (bool, optional): whether or not to display a progress bar to stderr + Default: True + + Example: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB) + >>> # xdoctest: +REQUIRES(POSIX) + >>> core.hub.download_url_to_file( + ... "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth", + ... "/tmp/temporary_file", + ... ) + + """ + file_size = None + req = Request(url, headers={"User-Agent": "core.hub"}) + u = urlopen(req) + meta = u.info() + if hasattr(meta, "getheaders"): + content_length = meta.getheaders("Content-Length") + else: + content_length = meta.get_all("Content-Length") + if content_length is not None and len(content_length) > 0: + file_size = int(content_length[0]) + + # We deliberately save it in a temp file and move it after + # download is complete. This prevents a local working checkpoint + # being overridden by a broken download. + # We deliberately do not use NamedTemporaryFile to avoid restrictive + # file permissions being applied to the downloaded file. + dst = os.path.expanduser(dst) + for seq in range(tempfile.TMP_MAX): + tmp_dst = dst + "." + uuid.uuid4().hex + ".partial" + try: + f = open(tmp_dst, "w+b") + except FileExistsError: + continue + break + else: + raise FileExistsError(errno.EEXIST, "No usable temporary file name found") + + try: + if hash_prefix is not None: + sha256 = hashlib.sha256() + with tqdm( + total=file_size, + disable=not progress, + unit="B", + unit_scale=True, + unit_divisor=1024, + ) as pbar: + while True: + buffer = u.read(READ_DATA_CHUNK) + if len(buffer) == 0: + break + f.write(buffer) # type: ignore[possibly-undefined] + if hash_prefix is not None: + sha256.update(buffer) # type: ignore[possibly-undefined] + pbar.update(len(buffer)) + + f.close() + if hash_prefix is not None: + digest = sha256.hexdigest() # type: ignore[possibly-undefined] + if digest[: len(hash_prefix)] != hash_prefix: + raise RuntimeError( + f'invalid hash value (expected "{hash_prefix}", got "{digest}")' + ) + shutil.move(f.name, dst) + finally: + f.close() + if os.path.exists(f.name): + os.remove(f.name) + + + +# Hub used to support automatically extracts from zipfile manually compressed by users. +# The legacy zip format expects only one file from core.save() < 1.6 in the zip. +# We should remove this support since zipfile is now default zipfile format for core.save(). +def _is_legacy_zip_format(filename: str) -> bool: + if zipfile.is_zipfile(filename): + infolist = zipfile.ZipFile(filename).infolist() + return len(infolist) == 1 and not infolist[0].is_dir() + return False + + +@deprecated( + "Falling back to the old format < 1.6. This support will be " + "deprecated in favor of default zipfile format introduced in 1.6. " + "Please redo core.save() to save it in the new zipfile format.", + category=FutureWarning, +) +def _legacy_zip_load( + filename: str, + model_dir: str, + weights_only: bool, +) -> Dict[str, Any]: + # Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand. + # We deliberately don't handle tarfile here since our legacy serialization format was in tar. + # E.g. resnet18-5c106cde.pth which is widely used. + with zipfile.ZipFile(filename) as f: + members = f.infolist() + if len(members) != 1: + raise RuntimeError("Only one file(not dir) is allowed in the zipfile") + f.extractall(model_dir) + extraced_name = members[0].filename + extracted_file = os.path.join(model_dir, extraced_name) + return load( + extracted_file, weights_only=weights_only + ) + + + +def load_state_dict_from_url( + url: str, + model_dir: Optional[str] = None, + progress: bool = True, + check_hash: bool = False, + file_name: Optional[str] = None, + weights_only: bool = False, +) -> Dict[str, Any]: + r"""Loads the Torch serialized object at the given URL. + + If downloaded file is a zip file, it will be automatically + decompressed. + + If the object is already present in `model_dir`, it's deserialized and + returned. + The default value of ``model_dir`` is ``/checkpoints`` where + ``hub_dir`` is the directory returned by :func:`~core.hub.get_dir`. + + Args: + url (str): URL of the object to download + model_dir (str, optional): directory in which to save the object + map_location (optional): a function or a dict specifying how to remap storage locations (see core.load) + progress (bool, optional): whether or not to display a progress bar to stderr. + Default: True + check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention + ``filename-.ext`` where ```` is the first eight or more + digits of the SHA256 hash of the contents of the file. The hash is used to + ensure unique names and to verify the contents of the file. + Default: False + file_name (str, optional): name for the downloaded file. Filename from ``url`` will be used if not set. + weights_only(bool, optional): If True, only weights will be loaded and no complex pickled objects. + Recommended for untrusted sources. See :func:`~core.load` for more details. + + Example: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB) + >>> state_dict = core.hub.load_state_dict_from_url( + ... "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth" + ... ) + + """ + # Issue warning to move data if old env is set + if os.getenv("TORCH_MODEL_ZOO"): + warnings.warn( + "TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead" + ) + + if model_dir is None: + hub_dir = get_dir() + model_dir = os.path.join(hub_dir, "checkpoints") + + os.makedirs(model_dir, exist_ok=True) + + parts = urlparse(url) + filename = os.path.basename(parts.path) + if file_name is not None: + filename = file_name + cached_file = os.path.join(model_dir, filename) + if not os.path.exists(cached_file): + sys.stderr.write(f'Downloading: "{url}" to {cached_file}\n') + hash_prefix = None + if check_hash: + r = HASH_REGEX.search(filename) # r is Optional[Match[str]] + hash_prefix = r.group(1) if r else None + download_url_to_file(url, cached_file, hash_prefix, progress=progress) + + if _is_legacy_zip_format(cached_file): + return _legacy_zip_load(cached_file, model_dir, weights_only) + return core.load(cached_file, weights_only=weights_only) diff --git a/mindnlp/core/jit/__init__.py b/mindnlp/core/jit/__init__.py new file mode 100644 index 000000000..9b66f51a3 --- /dev/null +++ b/mindnlp/core/jit/__init__.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from .._jit_internal import ( + # _Await, + # _drop, + # _IgnoreContextManager, + # _isinstance, + # _overload, + # _overload_method, + # export, + # Final, + # Future, + # ignore, + # is_scripting, + unused, +) + +from ._trace import ( + # _flatten, + # _get_trace_graph, + _script_if_tracing, + # _unique_state_dict, + # is_tracing, + # ONNXTracedModule, + # TopLevelTracedModule, + # trace, + # trace_module, + # TracedModule, + # TracerWarning, + # TracingCheckError, +) + +def is_tracing(): + return False + +def is_scripting(): + return False + +def script(obj, optimize=None, _frames_up=0, _rcb=None, example_inputs=None): + return obj + +def ignore(drop=False, **kwargs): + + if callable(drop): + return drop + + def decorator(fn): + return fn + + return decorator + +def _overload_method(func): + pass + +def interface(obj): + pass + +def script_if_tracing(fn): + """ + Compiles ``fn`` when it is first called during tracing. + + ``torch.jit.script`` has a non-negligible start up time when it is first called due to + lazy-initializations of many compiler builtins. Therefore you should not use + it in library code. However, you may want to have parts of your library work + in tracing even if they use control flow. In these cases, you should use + ``@torch.jit.script_if_tracing`` to substitute for + ``torch.jit.script``. + + Args: + fn: A function to compile. + + Returns: + If called during tracing, a :class:`ScriptFunction` created by `torch.jit.script` is returned. + Otherwise, the original function `fn` is returned. + """ + return _script_if_tracing(fn) \ No newline at end of file diff --git a/mindnlp/core/jit/_trace.py b/mindnlp/core/jit/_trace.py new file mode 100644 index 000000000..7457b3d3d --- /dev/null +++ b/mindnlp/core/jit/_trace.py @@ -0,0 +1,22 @@ +import functools + +from typing import Any, Callable, Optional, TypeVar +from typing_extensions import ParamSpec + +R = TypeVar("R", covariant=True) # return type (always covariant) +P = ParamSpec("P") + +def _script_if_tracing(fn: Callable[P, R]) -> Callable[P, R]: + @functools.wraps(fn) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + if not is_tracing(): + # Not tracing, don't do anything + return fn(*args, **kwargs) + + compiled_fn: Callable[P, R] = script(wrapper.__original_fn) # type: ignore[attr-defined] + return compiled_fn(*args, **kwargs) + + wrapper.__original_fn = fn # type: ignore[attr-defined] + wrapper.__script_if_tracing_wrapper = True # type: ignore[attr-defined] + + return wrapper diff --git a/mindnlp/core/jit/annotations.py b/mindnlp/core/jit/annotations.py new file mode 100644 index 000000000..b67c8a807 --- /dev/null +++ b/mindnlp/core/jit/annotations.py @@ -0,0 +1,22 @@ +from core._jit_internal import ( # type: ignore[attr-defined] + # _Await, + # _qualified_name, + # Any, + # BroadcastingList1, + BroadcastingList2, + # BroadcastingList3, + # Dict, + # Future, + # is_await, + # is_dict, + # is_future, + # is_ignored_fn, + # is_list, + # is_optional, + # is_tuple, + # is_union, + # List, + # Optional, + # Tuple, + # Union, +) diff --git a/mindnlp/core/library.py b/mindnlp/core/library.py new file mode 100644 index 000000000..3d9b81195 --- /dev/null +++ b/mindnlp/core/library.py @@ -0,0 +1,4 @@ +def register_fake(*args, **kwargs): + def register(func): + return func + return register \ No newline at end of file diff --git a/mindnlp/core/multiprocessing.py b/mindnlp/core/multiprocessing.py new file mode 100644 index 000000000..bd19b65a6 --- /dev/null +++ b/mindnlp/core/multiprocessing.py @@ -0,0 +1 @@ +from multiprocessing import * \ No newline at end of file diff --git a/mindnlp/core/nn/__init__.py b/mindnlp/core/nn/__init__.py new file mode 100644 index 000000000..c3eecf6d8 --- /dev/null +++ b/mindnlp/core/nn/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""mindnlp nn""" +from . import utils, functional, init +from .modules import * +from .parameter import Parameter diff --git a/mindnlp/core/nn/_reduction.py b/mindnlp/core/nn/_reduction.py new file mode 100644 index 000000000..42e92a5a8 --- /dev/null +++ b/mindnlp/core/nn/_reduction.py @@ -0,0 +1,61 @@ +"""reduction utils""" +import warnings +from typing import Optional + + +# NB: Keep this file in sync with enums in aten/src/ATen/core/Reduction.h + + +def get_enum(reduction: str) -> int: + if reduction == "none": + ret = 0 + elif reduction == "mean": + ret = 1 + elif reduction == "elementwise_mean": + warnings.warn( + "reduction='elementwise_mean' is deprecated. " + "Please use reduction='mean' instead." + ) + ret = 1 + elif reduction == "sum": + ret = 2 + else: + ret = -1 # TODO: remove once JIT exceptions support control flow + raise ValueError(f"{reduction} is not a valid value for reduction") + return ret + + +# In order to support previous versions, accept boolean size_average and reduce +# and convert them into the new constants for now + + +# We use these functions in torch/legacy as well, in which case we'll silence the warning +def legacy_get_string( + size_average: Optional[bool], + reduce: Optional[bool], + emit_warning: bool = True, +) -> str: + warning = "size_average and reduce args will be deprecated, please use reduction='{}' instead." + + if size_average is None: + size_average = True + if reduce is None: + reduce = True + + if size_average and reduce: + ret = "mean" + elif reduce: + ret = "sum" + else: + ret = "none" + if emit_warning: + warnings.warn(warning.format(ret)) + return ret + + +def legacy_get_enum( + size_average: Optional[bool], + reduce: Optional[bool], + emit_warning: bool = True, +) -> int: + return get_enum(legacy_get_string(size_average, reduce, emit_warning)) diff --git a/mindnlp/core/nn/attention/__init__.py b/mindnlp/core/nn/attention/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/mindnlp/core/nn/attention/flex_attention.py b/mindnlp/core/nn/attention/flex_attention.py new file mode 100644 index 000000000..f9205be20 --- /dev/null +++ b/mindnlp/core/nn/attention/flex_attention.py @@ -0,0 +1,3 @@ +BlockMask = None +flex_attention = None +create_block_mask = None diff --git a/mindnlp/core/nn/common_types.py b/mindnlp/core/nn/common_types.py new file mode 100644 index 000000000..b8ec85a9e --- /dev/null +++ b/mindnlp/core/nn/common_types.py @@ -0,0 +1,43 @@ +"""common types""" +from typing import TypeVar, Union, Tuple, Optional +from mindnlp.core import Tensor + +# Create some useful type aliases + +# Template for arguments which can be supplied as a tuple, or which can be a scalar which MindSpore will internally +# broadcast to a tuple. +# Comes in several variants: A tuple of unknown size, and a fixed-size tuple for 1d, 2d, or 3d operations. +T = TypeVar('T') +_scalar_or_tuple_any_t = Union[T, Tuple[T, ...]] +_scalar_or_tuple_1_t = Union[T, Tuple[T]] +_scalar_or_tuple_2_t = Union[T, Tuple[T, T]] +_scalar_or_tuple_3_t = Union[T, Tuple[T, T, T]] +_scalar_or_tuple_4_t = Union[T, Tuple[T, T, T, T]] +_scalar_or_tuple_5_t = Union[T, Tuple[T, T, T, T, T]] +_scalar_or_tuple_6_t = Union[T, Tuple[T, T, T, T, T, T]] + +# For arguments which represent size parameters (eg, kernel size, padding) +_size_any_t = _scalar_or_tuple_any_t[int] +_size_1_t = _scalar_or_tuple_1_t[int] +_size_2_t = _scalar_or_tuple_2_t[int] +_size_3_t = _scalar_or_tuple_3_t[int] +_size_4_t = _scalar_or_tuple_4_t[int] +_size_5_t = _scalar_or_tuple_5_t[int] +_size_6_t = _scalar_or_tuple_6_t[int] + +# For arguments which represent optional size parameters (eg, adaptive pool parameters) +_size_any_opt_t = _scalar_or_tuple_any_t[Optional[int]] +_size_2_opt_t = _scalar_or_tuple_2_t[Optional[int]] +_size_3_opt_t = _scalar_or_tuple_3_t[Optional[int]] + +# For arguments that represent a ratio to adjust each dimension of an input with (eg, upsampling parameters) +_ratio_2_t = _scalar_or_tuple_2_t[float] +_ratio_3_t = _scalar_or_tuple_3_t[float] +_ratio_any_t = _scalar_or_tuple_any_t[float] + +_tensor_list_t = _scalar_or_tuple_any_t[Tensor] + +# For the return value of max pooling operations that may or may not return indices. +# With the proposed 'Literal' feature to Python typing, it might be possible to +# eventually eliminate this. +_maybe_indices_t = _scalar_or_tuple_2_t[Tensor] diff --git a/mindnlp/core/nn/functional.py b/mindnlp/core/nn/functional.py new file mode 100644 index 000000000..a2d180c18 --- /dev/null +++ b/mindnlp/core/nn/functional.py @@ -0,0 +1,1224 @@ +"""nn functional""" +import math +import warnings +from typing import Optional, Tuple, List +import numpy as np +from mindspore import ops, mint +from mindspore.ops.auto_generate.gen_arg_handler import dtype_to_type_id +from mindspore.common.generator import default_generator +from mindspore.ops._primitive_cache import _get_cache_prim + +from mindnlp import core +from mindnlp.core.executor import execute +from ..configs import DEVICE_TARGET, ON_ORANGE_PI, use_pyboost + +generator_step_ = 12 + +def gelu(input, approximate='none'): + if approximate == 'tanh': + return execute('gelu', input) + return input * 0.5 * (1.0 + core.erf(input / math.sqrt(2.0))) + +def relu(input, inplace=False): + if inplace: + execute('inplace_relu', input) + return input + return execute('relu', input) + +def tanh(input, inplace=False): + if inplace: + execute('inplace_tanh', input) + return input + return execute('tanh', input) + + +def sigmoid(input): + return execute('sigmoid', input) + +def silu(input): + return execute('silu', input) + +def mish(input): + return ops.mish(input) + +def relu6(input): + return ops.relu6(input) + +def elu(input, alpha=1.0): + if use_pyboost(): + return mint.nn.functional.elu(input, alpha) + return ops.elu(input, alpha) + +def glu(input, dim=-1): + return ops.glu(input, dim) + +def softplus(input, beta=1, threshold=20): + if use_pyboost(): + return mint.nn.functional.softplus(input, beta, threshold) + return ops.softplus(input, beta, threshold) + +def logsigmoid(input): + return execute('logsigmoid', input) + +def leaky_relu(input, alpha=0.2): + if use_pyboost(): + return mint.nn.functional.leaky_relu(input, alpha) + return ops.leaky_relu(input, alpha) + +def prelu(input, weight): + return ops.prelu(input, weight) + +def celu(input, alpha=1., inplace=False): + return ops.celu(input, alpha) + +def selu(input): + return ops.selu(input) + +def hardsigmoid(input, inplace=False): + return ops.hardsigmoid(input) + +def hardswish(input: core.Tensor, inplace: bool = False) -> core.Tensor: + return ops.hardswish(input) + +def hardshrink(input, lambd=0.5): + return execute('hard_shrink', input, lambd) + +def avg_pool1d(input_array, pool_size, stride, padding=0, ceil_mode=False, count_include_pad=True): + """ + Perform 1D average pooling on the input array of shape (N, C, L) without using explicit for loops. + + Parameters: + - input_array (numpy array): The input array to be pooled, shape (N, C, L). + - pool_size (int): The size of the pooling window. + - stride (int): The stride of the pooling window. + - padding (int): The amount of zero-padding to add to both sides of the input array. + - ceil_mode (bool): If True, use ceil instead of floor to compute the output length. + - count_include_pad (bool): If True, include padding in the average calculation. + + Returns: + - numpy array: The result of the average pooling operation. + """ + N, C, L = input_array.shape + + # Add padding to the input array + if padding > 0: + input_array = ops.pad(input_array, ((0, 0), (0, 0), (padding, padding)), mode='constant', value=(0, 0)) + + # Calculate the output length + if ceil_mode: + output_length = int(np.ceil((L + 2 * padding - pool_size) / stride).astype(int) + 1) + else: + output_length = int(np.floor((L + 2 * padding - pool_size) / stride).astype(int) + 1) + + # Initialize the output array + output_array = ops.zeros((N, C, output_length)) + + # Generate the starting indices of the pooling windows + indices = ops.arange(output_length) * stride + indices = indices[:, None] + ops.arange(pool_size) + + # Ensure indices are within bounds + indices = ops.minimum(indices, input_array.shape[2] - 1) + + # Use advanced indexing to extract the pooling windows + windows = input_array[:, :, indices] + + # Calculate the mean along the pooling window dimension + if count_include_pad: + output_array = ops.mean(windows, axis=-1) + else: + valid_counts = ops.sum(windows != 0, dim=-1) + valid_counts = ops.maximum(valid_counts, 1) # Avoid division by zero + output_array = ops.sum(windows, dim=-1) / valid_counts + + return output_array + +def avg_pool2d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=0): + """ + Perform 2D average pooling on the input array. + + Parameters: + - input_array (numpy array): The input array to be pooled, shape (N, C, H, W). + - pool_size (tuple): The size of the pooling window (pool_height, pool_width). + - stride (tuple): The stride of the pooling window (stride_height, stride_width). + - padding (int or tuple): The amount of zero-padding to add to all sides of the input array. + - ceil_mode (bool): If True, use ceil instead of floor to compute the output length. + - count_include_pad (bool): If True, include padding in the average calculation. + + Returns: + - numpy array: The result of the average pooling operation. + """ + if use_pyboost(): + return mindspore.ops.function.nn_func.avg_pool2d_ext(input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) + + return ops.avg_pool2d(input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) + +def dropout(input, p=0.5, training=True, inplace=False): + if not training: + return input + seed, offset = default_generator._step(generator_step_) # pylint: disable=protected-access + out, _ = execute('dropout_ext', input, p, seed, offset) + if inplace: + input.copy_(out) + return input + return out + +def dropout2d(input, p=0.5, training=False): + return ops.dropout2d(input, p, training) + +def drop_and_mask(keep_prob, seed=None): + seed0, seed1 = _get_seed(seed, "dropout") + dropout_op = ops.Dropout(keep_prob=keep_prob, Seed0=seed0, Seed1=seed1) + dropout_op = _set_prim_op_user_data(dropout_op, "random_cache", False) + out, mask = dropout_op(input) + return out, mask + +dense_ = ops.Dense() +def linear(input, weight, bias=None): + if ON_ORANGE_PI: + input = input.to(core.float16) + weight = weight.to(core.float16) + if bias is not None: + bias = bias.to(core.float16) + return dense_(input, weight) + bias + return dense_(input, weight) + if use_pyboost(): + return mint.nn.functional.linear(input, weight, bias) + return dense_(input, weight, bias) + +def binary_cross_entropy_with_logits(input, target, weight=None, reduction='mean', pos_weight=None): + if input.shape != target.shape: + target = target.unsqueeze(1).expand_as(input).to(input.dtype) + if use_pyboost(): + return mint.nn.functional.binary_cross_entropy_with_logits(input, target, weight, reduction, pos_weight) + return ops.binary_cross_entropy_with_logits(input, target.astype(input.dtype), weight, pos_weight, reduction) + +def gumbel_softmax(logits: core.Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1) -> core.Tensor: + if eps != 1e-10: + warnings.warn("`eps` parameter is deprecated and has no effect.") + + uniform_samples = _get_cache_prim(ops.UniformReal)()(logits.shape) + gumbels = -ops.log(-ops.log(uniform_samples + eps) + eps) # ~Gumbel(0, 1) + gumbels = (logits + gumbels) / tau # ~Gumbel(logits,tau) + y_soft = softmax(gumbels, dim) + + if hard: + # Straight through. + index = y_soft.argmax(dim) + y_hard = one_hot(index, logits.shape[dim]) + ret = ops.stop_gradient(y_hard - y_soft) + y_soft + else: + # Reparametrization trick. + ret = y_soft + return ret + +def log_softmax(input, dim=-1, dtype=None): + if input.device.type == 'cpu': + return execute('log_softmax', input, dim) + return execute('log_softmax_ext', input, dim, + dtype if dtype is None else dtype_to_type_id('LogSoftmaxExt', 'dtype', dtype)) + +def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False): + if use_pyboost(): + return mint.nn.functional.embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq) + return ops.gather(weight, input, 0) + +def rms_norm(input, normalized_shape, weight, eps=1e-5): + return execute('rms_norm', input, weight, eps)[0] + +def fast_gelu(x): + return ops.fast_gelu(x) + +def swiglu(x, dim=-1): + return execute('swiglu', x, dim) + +def apply_rotary_pos_emb(query, key, cos, sin, position_ids, cos_format=0): + return mindspore.ops.auto_generate.gen_ops_def.apply_rotary_pos_emb_( + query, key, cos, sin, position_ids, cos_format + ) + +def _reflection_pad(input, pad): + """reflection pad""" + out = input + if len(pad) == 2: + out = execute('reflection_pad_1d', input, pad) + elif len(pad) == 4: + out = execute('reflection_pad_2d', input, pad) + else: + out = execute('reflection_pad_3d', input, pad) + return out + +def _replication_pad(input, pad): + """replication pad""" + out = input + if len(pad) == 2: + out = execute('replication_pad_1d', input, pad) + elif len(pad) == 4: + out = execute('replication_pad_2d', input, pad) + else: + out = execute('replication_pad_3d', input, pad) + return out + +def pad(input, pad, mode='constant', value=0.0): + out = input + if (isinstance(pad, tuple) and not pad): + return out + if mode == "constant": + value = 0 if value is None else value + out = execute('constant_pad_nd', input, pad, value) + else: + if value is not None and value != 0: + raise ValueError(f"Padding mode {mode} doesn\'t take in value argument.") + if mode == "circular": + out = _circular_pad(input, pad) + elif mode == "reflect": + out = _reflection_pad(input, pad) + elif mode == "replicate": + out = _replication_pad(input, pad) + else: + raise ValueError(f"Pad filling mode must be 'constant' 'circular' 'reflect' or 'replicate'.") + return out + +def nll_loss(input, target, weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0): + return _inner_nll_loss(input, target, weight, ignore_index, reduction, label_smoothing) + +def cross_entropy(input, target, weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0): + input = input.to(core.float32) + class_dim = 0 if input.ndim == 1 else 1 + if target.dtype in [core.float32, core.float16]: + return _cross_entropy(input, target, class_dim, weight, reduction, label_smoothing) + return nll_loss(log_softmax(input, class_dim), target, weight, ignore_index, reduction, label_smoothing) + + +def _cross_entropy(inputs, target, target_dim, weight=None, reduction='mean', label_smoothing=0.0): + """cross entropy inner function""" + class_dim = 0 if inputs.ndim == 1 else 1 + n_classes = inputs.shape[class_dim] + inputs = log_softmax(inputs, class_dim) + if label_smoothing > 0.0: + target = target * (1 - label_smoothing) + label_smoothing / n_classes + + if weight is None: + weight = core.ones_like(inputs) + elif inputs.ndim != 1: + broadcast_shape = [1 for _ in range(inputs.ndim)] + broadcast_shape[1] = weight.shape[0] + weight = weight.reshape(broadcast_shape) + + if reduction == 'mean': + return -(inputs * target * weight).sum() / (inputs.size / n_classes) + if reduction == 'sum': + return -(inputs * target * weight).sum() + return -(inputs * target * weight).sum(class_dim) + + +def _inner_nll_loss(inputs, target, weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0): + ndim = inputs.ndim + if ndim == 2: + ret = _nll_loss(inputs, target, -1, weight, ignore_index, reduction, label_smoothing) + elif ndim == 4: + ret = _nll_loss(inputs, target, 1, weight, ignore_index, reduction, label_smoothing) + elif ndim == 1: + ret = _nll_loss(inputs, target, 0, weight, ignore_index, reduction, label_smoothing) + else: + n = inputs.shape[0] + c = inputs.shape[1] + out_size = (n,) + inputs.shape[2:] + inputs = inputs.view((n, c, 1, -1)) + target = target.view((n, 1, -1)) + if reduction != 'none': + ret = _nll_loss(inputs, target, 1, weight, ignore_index, reduction, label_smoothing) + else: + ret = _nll_loss(inputs, target, 1, weight, ignore_index, label_smoothing=label_smoothing) + ret = ret.view(out_size) + return ret + + +def _nll_loss(inputs, target, target_dim=-1, weight=None, ignore_index=None, reduction='none', label_smoothing=0.0): + """nll loss inner function""" + if target.ndim == inputs.ndim - 1: + target = target.unsqueeze(target_dim) + if ignore_index is not None: + non_pad_mask = core.equal(target, ignore_index) + target = target.masked_fill(non_pad_mask, 0) + else: + non_pad_mask = target + if weight is not None: + loss_weights = core.gather(weight, 0, target) + orig_shape = inputs.shape + if inputs.ndim != 2: + inputs = inputs.view(orig_shape[:2] + (-1,)) + weight = weight.view(weight.shape + (1,)) + weighted_inputs = inputs * weight + weighted_inputs = weighted_inputs.view(orig_shape) + loss = core.neg(core.gather(weighted_inputs, target_dim, target)) + smooth_loss = core.neg(weighted_inputs.sum(dim=target_dim, keepdim=True)) + else: + loss = core.neg(core.gather(inputs, target_dim, target)) + smooth_loss = core.neg(inputs.sum(dim=target_dim, keepdim=True)) + loss_weights = core.ones_like(loss) + + if ignore_index is not None: + loss = loss.masked_fill(non_pad_mask, 0.) + loss_weights = loss_weights.masked_fill(non_pad_mask, 0.) + smooth_loss = smooth_loss.masked_fill(non_pad_mask, 0.) + + loss = loss.squeeze(target_dim) + smooth_loss = smooth_loss.squeeze(target_dim) + + if reduction == 'sum': + loss = loss.sum() + smooth_loss = smooth_loss.sum() + if reduction == 'mean': + loss = loss.sum() / loss_weights.sum() + smooth_loss = smooth_loss.sum() / loss_weights.sum() + + eps_i = label_smoothing / inputs.shape[target_dim] + if label_smoothing != 0: + loss = (1. - label_smoothing) * loss + eps_i * smooth_loss + + return loss + +def mse_loss(input, target, reduction='mean'): + return ops.mse_loss(input, target, reduction) + +def l1_loss(input, target, reduction='mean'): + return ops.l1_loss(input, target, reduction) + +def smooth_l1_loss(input, target, beta=1.0, reduction='none'): + input = input.to(mindspore.float32) + target = target.to(mindspore.float32) + return ops.smooth_l1_loss(input, target, beta, reduction) + +def kl_div(logits, labels, reduction='mean', log_target=False): + if log_target: + labels = ops.log(labels) + return ops.kl_div(logits, labels, reduction) + +def manual_softmax(x, dim=-1): + exp_x = ops.exp(x - ops.max(x, axis=dim, keepdims=True)[0]) + return exp_x / ops.sum(exp_x, dim=dim, keepdim=True) + +def softmax(input, dim=-1, *, dtype=None): + if use_pyboost(): + return mint.nn.functional.softmax(input, dim, dtype=dtype) + if dtype is not None: + input = input.to(dtype) + if dim is None: + dim = -1 + if ON_ORANGE_PI: + return manual_softmax(input, dim) + softmax_ = _get_cache_prim(ops.Softmax)(dim) + return softmax_(input) + +def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5): + if weight is None: + weight = ops.ones(normalized_shape, dtype=input.dtype) + if bias is None: + bias = ops.zeros(normalized_shape, dtype=input.dtype) + if use_pyboost(): + return mint.nn.functional.layer_norm(input, normalized_shape, weight, bias, eps) + if weight is not None: + begin_axis = input.ndim - weight.ndim + else: + begin_axis = -1 + _layer_norm = _get_cache_prim(ops.LayerNorm)(begin_axis, begin_axis, epsilon=eps) + return _layer_norm(input, weight, bias)[0] + +def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): + return ops.interpolate(input, size, scale_factor, mode, align_corners, recompute_scale_factor) + +def normalize(input, p=2.0, dim=1, eps=1e-6): + r""" + Normalize a tensor along a specified dimension. + + Args: + input (core.Tensor): The input tensor to be normalized. + p (float, optional): The power parameter for the normalization. Default is 2.0. + dim (int, optional): The dimension along which to normalize the tensor. Default is 1. + + Returns: + None + + Raises: + TypeError: If the input is not a tensor. + ValueError: If the specified dimension is out of range or if the power parameter is not a positive number. + + This function normalizes the input tensor along the specified dimension using the power parameter 'p'. + The normalization is performed by dividing each element of the tensor by the Lp norm of the tensor along the specified dimension. + The Lp norm is defined as the p-th root of the sum of the absolute values raised to the power of 'p'. + The resulting tensor will have the same shape as the input tensor. + """ + return input / ops.norm(input, ord=p, dim=dim, keepdim=True) + +def batch_norm(input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-05): + + if running_mean is None: + running_mean = ops.ones(input.shape[1]) + if running_var is None: + running_var = ops.zeros(input.shape[1]) + if weight is None: + weight = ops.ones(input.shape[1]) + if bias is None: + bias = ops.zeros(input.shape[1]) + + if use_pyboost(): + return mint.nn.functional.batch_norm( + input, + running_mean, + running_var, + weight, + bias, + training, + momentum, + eps + ) + return ops.batch_norm( + input, + running_mean, + running_var, + weight, + bias, + training, + momentum, + eps + ) + +def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + pad_mode = 'pad' + if not isinstance(padding, (int, tuple)): + pad_mode = padding + + return ops.conv2d(input, weight, bias=bias, stride=stride, pad_mode=pad_mode, padding=padding, dilation=dilation, groups=groups) + +def max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False): + if use_pyboost(): + return mint.nn.functional.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode=ceil_mode, return_indices=return_indices) + return ops.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode=ceil_mode, return_indices=return_indices) + +def max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False): + if stride is None: + stride = kernel_size + + kernel_size = (1, kernel_size) + stride = (1, stride) + padding = (0, padding) + dilation = (1, dilation) + + input_2d = input.unsqueeze(2) + + if return_indices: + output_2d, indices_2d = max_pool2d(input_2d, kernel_size, stride, padding, dilation, ceil_mode, return_indices) + output_1d = output_2d.squeeze(2) + indices_1d = indices_2d.squeeze(2) + return output_1d, indices_1d + else: + output_2d = max_pool2d(input_2d, kernel_size, stride, padding, dilation, ceil_mode) + output_1d = output_2d.squeeze(2) + return output_1d + + +def group_norm(input, num_groups, weight=None, bias=None, eps=1e-5): + if use_pyboost(): + return mint.nn.functional.group_norm(input, num_groups, weight, bias, eps) + + input_shape = input.shape + N = input_shape[0] + C = input_shape[1] + input_reshaped = input.view(1, N * num_groups, -1 if N!=0 else 1) + outputs = batch_norm(input_reshaped, None, None, None, None, True, 0., eps) + out = outputs.view(input_shape) + affine_param_shape = [1] * input.ndim + affine_param_shape[1] = C + affine_param_shape = tuple(affine_param_shape) + if weight is not None and bias is not None: + out = bias.view(affine_param_shape).addcmul(out, weight.view(affine_param_shape), 1) + elif weight is not None: + out = out.mul(weight.view(affine_param_shape)) + elif bias is not None: + out = out.add(bias.view(affine_param_shape)) + return out + + +def _in_projection( + q, + k, + v, + w_q, + w_k, + w_v, + b_q=None, + b_k=None, + b_v=None, +): + r""" + Performs the in-projection step of the attention operation. This is simply + a triple of linear projections, with shape constraints on the weights which + ensure embedding dimension uniformity in the projected outputs. + Output is a triple containing projection tensors for query, key and value. + Args: + q, k, v: query, key and value tensors to be projected. + w_q, w_k, w_v: weights for q, k and v, respectively. + b_q, b_k, b_v: optional biases for q, k and v, respectively. + Shape: + Inputs: + - q: :math:`(Qdims..., Eq)` where Eq is the query embedding dimension and Qdims are any + number of leading dimensions. + - k: :math:`(Kdims..., Ek)` where Ek is the key embedding dimension and Kdims are any + number of leading dimensions. + - v: :math:`(Vdims..., Ev)` where Ev is the value embedding dimension and Vdims are any + number of leading dimensions. + - w_q: :math:`(Eq, Eq)` + - w_k: :math:`(Eq, Ek)` + - w_v: :math:`(Eq, Ev)` + - b_q: :math:`(Eq)` + - b_k: :math:`(Eq)` + - b_v: :math:`(Eq)` + Output: in output triple :math:`(q', k', v')`, + - q': :math:`[Qdims..., Eq]` + - k': :math:`[Kdims..., Eq]` + - v': :math:`[Vdims..., Eq]` + """ + Eq, Ek, Ev = q.shape[-1], k.shape[-1], v.shape[-1] + assert w_q.shape == ( + Eq, Eq), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}" + assert w_k.shape == ( + Eq, Ek), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}" + assert w_v.shape == ( + Eq, Ev), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}" + assert b_q is None or b_q.shape == ( + Eq,), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}" + assert b_k is None or b_k.shape == ( + Eq,), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}" + assert b_v is None or b_v.shape == ( + Eq,), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}" + return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v) + + +def _in_projection_packed( + q: core.Tensor, + k: core.Tensor, + v: core.Tensor, + w: core.Tensor, + b: Optional[core.Tensor] = None, +) -> List[core.Tensor]: + r""" + Performs the in-projection step of the attention operation, using packed weights. + Output is a triple containing projection tensors for query, key and value. + + Args: + q, k, v: query, key and value tensors to be projected. For self-attention, + these are typically the same tensor; for encoder-decoder attention, + k and v are typically the same tensor. (We take advantage of these + identities for performance if they are present.) Regardless, q, k and v + must share a common embedding dimension; otherwise their shapes may vary. + w: projection weights for q, k and v, packed into a single tensor. Weights + are packed along dimension 0, in q, k, v order. + b: optional projection biases for q, k and v, packed into a single tensor + in q, k, v order. + + Shape: + Inputs: + - q: :math:`(..., E)` where E is the embedding dimension + - k: :math:`(..., E)` where E is the embedding dimension + - v: :math:`(..., E)` where E is the embedding dimension + - w: :math:`(E * 3, E)` where E is the embedding dimension + - b: :math:`E * 3` where E is the embedding dimension + + Output: + - in output list :math:`[q', k', v']`, each output tensor will have the + same shape as the corresponding input tensor. + """ + E = q.shape[-1] + if k is v: + if q is k: + # self-attention + # proj = linear(q, w, b) + # # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk() + # proj = proj.unflatten(-1, (3, E)).unsqueeze(0).swapaxes(0, -2).squeeze(-2) + # return proj[0], proj[1], proj[2] + return linear(q, w, b).chunk(3, axis=-1) + else: + # encoder-decoder attention + w_q, w_kv = w.split([E, E * 2]) + if b is None: + b_q = b_kv = None + else: + b_q, b_kv = b.split([E, E * 2]) + # q_proj = linear(q, w_q, b_q) + # kv_proj = linear(k, w_kv, b_kv) + # # reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk() + # kv_proj = kv_proj.unflatten(-1, (2, E)).unsqueeze(0).swapaxes(0, -2).squeeze(-2) + # return (q_proj, kv_proj[0], kv_proj[1]) + return (linear(q, w_q, b_q),) + linear(k, w_kv, b_kv).chunk(2, axis=-1) + else: + w_q, w_k, w_v = w.chunk(3) + if b is None: + b_q = b_k = b_v = None + else: + b_q, b_k, b_v = b.chunk(3) + return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v) + +def scaled_dot_product_attention(query, key, value, attn_mask, dropout_p, is_causal): + embed_size = query.shape[-1] + scaling_factor = ops.sqrt(ops.sqrt(core.Tensor(embed_size, dtype=query.dtype))) + query = query / scaling_factor + + if is_causal: + L = query.shape[-2], S = key.shape[-2] + attn_mask = ops.ones((L, S), mindspore.bool_).tril() + + attn = ops.matmul(query, key.swapaxes(-2, -1) / scaling_factor) + if attn_mask is not None: + attn = attn + attn_mask + attn = softmax(attn, -1) + if dropout_p > 0.: + attn = ops.dropout(attn, dropout_p) + output = ops.matmul(attn, value) + + return output + + +def _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads): + # Verifies the expected shape for `query, `key`, `value`, `key_padding_mask` and `attn_mask` + # and returns if the input is batched or not. + # Raises an error if `query` is not 2-D (unbatched) or 3-D (batched) tensor. + + # Shape check. + if query.ndim == 3: + # Batched Inputs + is_batched = True + assert key.ndim == 3 and value.ndim == 3, \ + ("For batched (3-D) `query`, expected `key` and `value` to be 3-D" + f" but found {key.ndim}-D and {value.ndim}-D tensors respectively") + if key_padding_mask is not None: + assert key_padding_mask.ndim == 2, \ + ("For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D" + f" but found {key_padding_mask.ndim}-D tensor instead") + if attn_mask is not None: + assert attn_mask.ndim in (2, 3), \ + ("For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D" + f" but found {attn_mask.ndim}-D tensor instead") + elif query.ndim == 2: + # Unbatched Inputs + is_batched = False + assert key.ndim == 2 and value.ndim == 2, \ + ("For unbatched (2-D) `query`, expected `key` and `value` to be 2-D" + f" but found {key.ndim}-D and {value.ndim}-D tensors respectively") + + if key_padding_mask is not None: + assert key_padding_mask.ndim == 1, \ + ("For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D" + f" but found {key_padding_mask.ndim}-D tensor instead") + + if attn_mask is not None: + assert attn_mask.ndim in (2, 3), \ + ("For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D" + f" but found {attn_mask.ndim}-D tensor instead") + if attn_mask.ndim == 3: + expected_shape = (num_heads, query.shape[0], key.shape[0]) + assert attn_mask.shape == expected_shape, \ + (f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}") + else: + raise AssertionError( + f"query should be unbatched 2D or batched 3D tensor but received {query.ndim}-D query tensor") + + return is_batched + + +def multi_head_attention_forward( + query: core.Tensor, + key: core.Tensor, + value: core.Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Optional[core.Tensor], + in_proj_bias: Optional[core.Tensor], + bias_k: Optional[core.Tensor], + bias_v: Optional[core.Tensor], + add_zero_attn: bool, + dropout_p: float, + out_proj_weight: core.Tensor, + out_proj_bias: Optional[core.Tensor], + training: bool = True, + key_padding_mask: Optional[core.Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[core.Tensor] = None, + use_separate_proj_weight: bool = False, + q_proj_weight: Optional[core.Tensor] = None, + k_proj_weight: Optional[core.Tensor] = None, + v_proj_weight: Optional[core.Tensor] = None, + static_k: Optional[core.Tensor] = None, + static_v: Optional[core.Tensor] = None, + average_attn_weights: bool = True, + is_causal: bool = False, +) -> Tuple[core.Tensor, Optional[core.Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + See "Attention Is All You Need" for more details. + embed_dim_to_check: total dimension of the model. + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + bias_k, bias_v: bias of the key and value sequences to be added at dim=0. + add_zero_attn: add a new batch of zeros to the key and + value sequences at dim=1. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + Default: `True` + Note: `needs_weight` defaults to `True`, but should be set to `False` + For best performance when attention weights are not needed. + *Setting needs_weights to `True` + leads to a significant performance degradation.* + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + is_causal: If specified, applies a causal mask as attention mask, and ignores + attn_mask for computing scaled dot product attention. + Default: ``False``. + .. warning:: + is_causal is provides a hint that the attn_mask is the + causal mask.Providing incorrect hints can result in + incorrect execution, including forward and backward + compatibility. + use_separate_proj_weight: the function accept the proj. weights for query, key, + and value in different forms. If false, in_proj_weight will be used, which is + a combination of q_proj_weight, k_proj_weight, v_proj_weight. + q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias. + static_k, static_v: static key and value used for attention operators. + average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across heads. + Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an effect + when ``need_weights=True.``. Default: True + + + Shape: + Inputs: + - query: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, E)` or :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a Floatcore.Tensor is provided, it will be directly added to the value. + If a Boolcore.Tensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a Boolcore.Tensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a Floatcore.Tensor + is provided, it will be added to the attention weight. + - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, + N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. + - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, + N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. + + Outputs: + - attn_output: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: Only returned when ``need_weights=True``. If ``average_attn_weights=True``, returns + attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or + :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and + :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per + head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`. + """ + is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads) + + # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input + # is batched, run the computation and before returning squeeze the + # batch dimension so that the output doesn't carry this temporary batch dimension. + if not is_batched: + # unsqueeze if the input is unbatched + query = query.unsqueeze(1) + key = key.unsqueeze(1) + value = value.unsqueeze(1) + if key_padding_mask is not None: + key_padding_mask = key_padding_mask.unsqueeze(0) + + # set up shape vars + tgt_len, bsz, embed_dim = query.shape + src_len, _, _ = key.shape + + key_padding_mask = _canonical_mask( + mask=key_padding_mask, + mask_name="key_padding_mask", + other_type=_none_or_dtype(attn_mask), + other_name="attn_mask", + target_type=query.dtype + ) + + if is_causal and attn_mask is None: + raise RuntimeError( + "Need attn_mask if specifying the is_causal hint. " + "You may use the Transformer module method " + "`generate_square_subsequent_mask` to create this mask." + ) + + if is_causal and key_padding_mask is None and not need_weights: + # when we have a kpm or need weights, we need attn_mask + # Otherwise, we use the is_causal hint go as is_causal + # indicator to SDPA. + attn_mask = None + else: + attn_mask = _canonical_mask( + mask=attn_mask, + mask_name="attn_mask", + other_type=None, + other_name="", + target_type=query.dtype, + check_other=False, + ) + + if key_padding_mask is not None: + # We have the attn_mask, and use that to merge kpm into it. + # Turn off use of is_causal hint, as the merged mask is no + # longer causal. + is_causal = False + + assert embed_dim == embed_dim_to_check, \ + f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}" + if isinstance(embed_dim, core.Tensor): + # embed_dim can be a tensor when JIT tracing + head_dim = embed_dim.div(num_heads, rounding_mode='trunc') + else: + head_dim = embed_dim // num_heads + assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}" + if use_separate_proj_weight: + # allow MHA to have different embedding dimensions when separate projection weights are used + assert key.shape[:2] == value.shape[:2], \ + f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}" + else: + assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}" + + # + # compute in-projection + # + if not use_separate_proj_weight: + assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None" + q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias) + else: + assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None" + assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None" + assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None" + if in_proj_bias is None: + b_q = b_k = b_v = None + else: + b_q, b_k, b_v = in_proj_bias.chunk(3) + q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v) + + # prep attention mask + + if attn_mask is not None: + # ensure attn_mask's dim is 3 + if attn_mask.ndim == 2: + correct_2d_size = (tgt_len, src_len) + if attn_mask.shape != correct_2d_size: + raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.") + attn_mask = attn_mask.unsqueeze(0) + elif attn_mask.ndim == 3: + correct_3d_size = (bsz * num_heads, tgt_len, src_len) + if attn_mask.shape != correct_3d_size: + raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.") + else: + raise RuntimeError(f"attn_mask's dimension {attn_mask.ndim} is not supported") + + # add bias along batch dimension (currently second) + if bias_k is not None and bias_v is not None: + assert static_k is None, "bias cannot be added to static key." + assert static_v is None, "bias cannot be added to static value." + k = ops.cat([k, bias_k.repeat(1, bsz, 1)]) + v = ops.cat([v, bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = pad(key_padding_mask, (0, 1)) + else: + assert bias_k is None + assert bias_v is None + + # + # reshape q, k, v for multihead attention and make em batch first + # + q = q.view(tgt_len, bsz * num_heads, head_dim).swapaxes(0, 1) + if static_k is None: + k = k.view(k.shape[0], bsz * num_heads, head_dim).swapaxes(0, 1) + else: + # TODO finish disentangling control flow so we don't do in-projections when statics are passed + assert static_k.shape[0] == bsz * num_heads, \ + f"expecting static_k.shape[0] of {bsz * num_heads}, but got {static_k.shape[0]}" + assert static_k.shape[2] == head_dim, \ + f"expecting static_k.shape[2] of {head_dim}, but got {static_k.shape[2]}" + k = static_k + if static_v is None: + v = v.view(v.shape[0], bsz * num_heads, head_dim).swapaxes(0, 1) + else: + # TODO finish disentangling control flow so we don't do in-projections when statics are passed + assert static_v.shape[0] == bsz * num_heads, \ + f"expecting static_v.shape[0] of {bsz * num_heads}, but got {static_v.shape[0]}" + assert static_v.shape[2] == head_dim, \ + f"expecting static_v.shape[2] of {head_dim}, but got {static_v.shape[2]}" + v = static_v + + # add zero attention along batch dimension (now first) + if add_zero_attn: + zero_attn_shape = (bsz * num_heads, 1, head_dim) + k = ops.cat([k, ops.zeros(zero_attn_shape, dtype=k.dtype)], axis=1) + v = ops.cat([v, ops.zeros(zero_attn_shape, dtype=v.dtype)], axis=1) + if attn_mask is not None: + attn_mask = pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = pad(key_padding_mask, (0, 1)) + + # update source sequence length after adjustments + src_len = k.shape[1] + + # merge key padding and attention masks + if key_padding_mask is not None: + assert key_padding_mask.shape == (bsz, src_len), \ + f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}" + key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \ + broadcast_to((-1, num_heads, -1, -1)).reshape(bsz * num_heads, 1, src_len) + if attn_mask is None: + attn_mask = key_padding_mask + else: + attn_mask = attn_mask + key_padding_mask + + # adjust dropout probability + if not training: + dropout_p = 0.0 + + # + # (deep breath) calculate attention and out projection + # + + if need_weights: + B, Nt, E = q.shape + q_scaled = q / math.sqrt(E) + + assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights" + + if attn_mask is not None: + attn_output_weights = ops.baddbmm(attn_mask, q_scaled, k.swapaxes(-2, -1)) + else: + attn_output_weights = ops.bmm(q_scaled, k.swapaxes(-2, -1)) + attn_output_weights = softmax(attn_output_weights, dim=-1) + if dropout_p > 0.0: + attn_output_weights = dropout(attn_output_weights, p=dropout_p) + + attn_output = ops.bmm(attn_output_weights, v) + + attn_output = attn_output.swapaxes(0, 1).view(tgt_len * bsz, embed_dim) + attn_output = linear(attn_output, out_proj_weight, out_proj_bias) + attn_output = attn_output.view(tgt_len, bsz, attn_output.shape[1]) + + # optionally average attention weights over heads + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) + if average_attn_weights: + attn_output_weights = attn_output_weights.mean(axis=1) + + if not is_batched: + # squeeze the output if input was unbatched + attn_output = attn_output.squeeze(1) + attn_output_weights = attn_output_weights.squeeze(0) + return attn_output, attn_output_weights + else: + # attn_mask can be either (L,S) or (N*num_heads, L, S) + # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S) + # in order to match the input for SDPA of (N, num_heads, L, S) + if attn_mask is not None: + if attn_mask.shape[0] == 1 and attn_mask.ndim == 3: + attn_mask = attn_mask.unsqueeze(0) + else: + attn_mask = attn_mask.view(bsz, num_heads, -1, src_len) + + q = q.view(bsz, num_heads, tgt_len, head_dim) + k = k.view(bsz, num_heads, src_len, head_dim) + v = v.view(bsz, num_heads, src_len, head_dim) + + attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal) + attn_output = attn_output.permute(2, 0, 1, 3).view(bsz * tgt_len, embed_dim) + + attn_output = linear(attn_output, out_proj_weight, out_proj_bias) + attn_output = attn_output.view(tgt_len, bsz, attn_output.shape[1]) + if not is_batched: + # squeeze the output if input was unbatched + attn_output = attn_output.squeeze(1) + return attn_output, None + +def _canonical_mask( + mask: Optional[core.Tensor], + mask_name: str, + other_type: Optional[int], + other_name: str, + target_type: int, + check_other: bool = True, +) -> Optional[core.Tensor]: + if mask is not None: + _mask_dtype = mask.dtype + _mask_is_float = ops.is_floating_point(mask) + if _mask_dtype != mindspore.bool_ and not _mask_is_float: + raise AssertionError( + f"only bool and floating types of {mask_name} are supported") + if check_other and other_type is not None: + if _mask_dtype != other_type: + warnings.warn( + f"Support for mismatched {mask_name} and {other_name} " + "is deprecated. Use same type for both instead." + ) + if not _mask_is_float: + zero_tensor = ops.zeros_like(mask, dtype=target_type) + mask = ops.where(mask, core.Tensor(float("-inf"), target_type), zero_tensor) + # mask = ( + # ops.zeros_like(mask, dtype=target_type) + # .masked_fill_(mask, float("-inf")) + # ) + return mask + +def _none_or_dtype(input: Optional[core.Tensor]) -> Optional[int]: + if input is None: + return None + elif isinstance(input, core.Tensor): + return input.dtype + raise RuntimeError("input to _none_or_dtype() must be None or core.Tensor") + +def unfold(input, kernel_size, dilation=1, padding=0, stride=1): + if use_pyboost(): + return mint.nn.functional.unfold(input, kernel_size, dilation, padding, stride) + return ops.unfold(input, kernel_size, dilation, padding, stride) + +def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1): + if use_pyboost(): + return mint.nn.functional.fold(input, output_size, kernel_size, dilation, padding, stride) + return ops.fold(input, output_size, kernel_size, dilation, padding, stride) + +def conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + pad_mode = 'pad' + pad = padding + if isinstance(padding, tuple): + pad = (0, 0, padding[0], padding[0]) + elif isinstance(padding, int): + pad = (0, 0) + (padding,) * 2 + if not isinstance(padding, (int, tuple)): + pad_mode = padding + pad = (0,) * 4 + + _conv2d = _get_cache_prim(ops.Conv2D)(out_channel=weight.shape[0] * groups, + kernel_size=(1, weight.shape[-1]), + mode=1, + pad_mode=pad_mode, + pad=pad, + stride=(1, stride), + dilation=(1, dilation), + group=groups) + + input = input.expand_dims(2) + output = _conv2d(input, weight.expand_dims(2)) + + if bias is not None: + output = ops.bias_add(output, bias) + + output = output.squeeze(2) + return output + +def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean', zero_infinity=False): + ctc_loss_op = _get_cache_prim(nn_ops.CTCLossV2)(blank=blank, reduction="none", zero_infinity=zero_infinity) + loss, _ = ctc_loss_op(log_probs, targets, input_lengths, target_lengths) + if zero_infinity: + loss = ops.where(ops.isinf(loss), 0., loss) + if reduction == 'sum': + loss = loss.sum() + if reduction == 'mean': + input_type = loss.dtype + target_length_t = target_lengths.clip(1., None) + loss = loss.astype("float32") + loss = loss / target_length_t + loss = loss.mean() + loss = loss.astype(input_type) + return loss + +def one_hot(tensor, num_classes=-1): + if use_pyboost(): + return mint.nn.functional.one_hot(tensor, num_classes) + return ops.one_hot(tensor, num_classes) + +def pixel_shuffle(input, upscale_factor): + return ops.pixel_shuffle(input, upscale_factor) + +def pixel_unshuffle(input, downscale_factor): + return ops.pixel_shuffle(input, downscale_factor) + +def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=False): + if use_pyboost(): + return mint.nn.functional.grid_sample(input, grid, mode, padding_mode, align_corners) + return ops.grid_sample(input, grid, mode, padding_mode, align_corners) + +def cosine_similarity(x1, x2, dim=1, eps=1e-8): + if DEVICE_TARGET == 'Ascend': + zero_norm_mask = ((x1.sum(dim) == 0).int() & (x2.sum(dim) == 0).int()).bool() + else: + zero_norm_mask = (x1.sum(dim) == 0) & (x2.sum(dim) == 0) + + cosine_sim = ops.cosine_similarity(x1, x2, dim, eps) + return ops.select(zero_norm_mask, ops.ones_like(cosine_sim), cosine_sim) + +# def pairwise_distance(): +# return ops.pairwise_distance + +def make_attention_mask( + query_input: core.Tensor, + key_input: core.Tensor, + dtype=core.float32, +): + """Mask-making helper for attention weights. + + In case of 1d inputs (i.e., `[batch..., len_q]`, `[batch..., len_kv]`, the + attention weights will be `[batch..., heads, len_q, len_kv]` and this + function will produce `[batch..., 1, len_q, len_kv]`. + + Args: + query_input: a batched, flat input of query_length size + key_input: a batched, flat input of key_length size + dtype: mask return dtype + + Returns: + A `[batch..., 1, len_q, len_kv]` shaped mask for 1d attention. + """ + mask = ops.greater_equal( + ops.expand_dims(query_input, axis=-1), ops.expand_dims(key_input, axis=-2) + ) + mask = ops.expand_dims(mask, axis=-3) + return mask.astype(dtype) + + +def make_causal_mask( + x: core.Tensor, dtype=core.float32 +) -> core.Tensor: + """Make a causal mask for self-attention. + + In case of 1d inputs (i.e., `[batch..., len]`, the self-attention weights + will be `[batch..., heads, len, len]` and this function will produce a + causal mask of shape `[batch..., 1, len, len]`. + + Args: + x: input array of shape `[batch..., len]` + extra_batch_dims: number of batch dims to add singleton axes for, none by + default + dtype: mask return dtype + + Returns: + A `[batch..., 1, len, len]` shaped causal mask for 1d attention. + """ + idxs = ops.broadcast_to(ops.arange(x.shape[-1], dtype=mindspore.int32), x.shape) + return make_attention_mask( + idxs, + idxs, + dtype=dtype, + ) + +def rotary_position_embedding(x, cos, sin, mode=0): + return execute('rotary_position_embedding', x, cos, sin, mode) diff --git a/mindnlp/core/nn/init.py b/mindnlp/core/nn/init.py new file mode 100644 index 000000000..8f2a764da --- /dev/null +++ b/mindnlp/core/nn/init.py @@ -0,0 +1,616 @@ +# mypy: allow-untyped-defs +"""This file contains utilities for initializing neural network parameters.""" +import math +import warnings +from typing import Optional as _Optional + +from mindnlp import core +from mindnlp.core import Tensor + +def _no_grad_uniform_(tensor, a, b, generator=None): + with core.no_grad(): + return tensor.uniform_(a, b, generator=generator) + +def _no_grad_normal_(tensor, mean, std, generator=None): + with core.no_grad(): + return tensor.normal_(mean, std, generator=generator) + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=None): + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + with core.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1, generator=generator) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def _no_grad_fill_(tensor, val): + with core.no_grad(): + return tensor.fill_(val) + + +def _no_grad_zero_(tensor): + with core.no_grad(): + return tensor.zero_() + +def calculate_gain(nonlinearity, param=None): + r"""Return the recommended gain value for the given nonlinearity function. + + The values are as follows: + + ================= ==================================================== + nonlinearity gain + ================= ==================================================== + Linear / Identity :math:`1` + Conv{1,2,3}D :math:`1` + Sigmoid :math:`1` + Tanh :math:`\frac{5}{3}` + ReLU :math:`\sqrt{2}` + Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}` + SELU :math:`\frac{3}{4}` + ================= ==================================================== + + .. warning:: + In order to implement `Self-Normalizing Neural Networks`_ , + you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``. + This gives the initial weights a variance of ``1 / N``, + which is necessary to induce a stable fixed point in the forward pass. + In contrast, the default gain for ``SELU`` sacrifices the normalization + effect for more stable gradient flow in rectangular layers. + + Args: + nonlinearity: the non-linear function (`nn.functional` name) + param: optional parameter for the non-linear function + + Examples: + >>> gain = nn.init.calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2 + + .. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html + """ + linear_fns = [ + "linear", + "conv1d", + "conv2d", + "conv3d", + "conv_transpose1d", + "conv_transpose2d", + "conv_transpose3d", + ] + if nonlinearity in linear_fns or nonlinearity == "sigmoid": + return 1 + elif nonlinearity == "tanh": + return 5.0 / 3 + elif nonlinearity == "relu": + return math.sqrt(2.0) + elif nonlinearity == "leaky_relu": + if param is None: + negative_slope = 0.01 + elif ( + not isinstance(param, bool) + and isinstance(param, int) + or isinstance(param, float) + ): + # True/False are instances of int, hence check above + negative_slope = param + else: + raise ValueError(f"negative_slope {param} not a valid number") + return math.sqrt(2.0 / (1 + negative_slope**2)) + elif nonlinearity == "selu": + return ( + 3.0 / 4 + ) # Value found empirically (https://github.com/pytorch/pytorch/pull/50664) + else: + raise ValueError(f"Unsupported nonlinearity {nonlinearity}") + + +def uniform_( + tensor: Tensor, + a: float = 0.0, + b: float = 1.0, + generator: _Optional[core.Generator] = None, +) -> Tensor: + r"""Fill the input Tensor with values drawn from the uniform distribution. + + :math:`\mathcal{U}(a, b)`. + + Args: + tensor: an n-dimensional `core.Tensor` + a: the lower bound of the uniform distribution + b: the upper bound of the uniform distribution + + Examples: + >>> w = core.empty(3, 5) + >>> nn.init.uniform_(w) + """ + return _no_grad_uniform_(tensor, a, b, generator) + + +def normal_( + tensor: Tensor, + mean: float = 0.0, + std: float = 1.0, + generator: _Optional[core.Generator] = None, +) -> Tensor: + r"""Fill the input Tensor with values drawn from the normal distribution. + + :math:`\mathcal{N}(\text{mean}, \text{std}^2)`. + + Args: + tensor: an n-dimensional `core.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + + Examples: + >>> w = core.empty(3, 5) + >>> nn.init.normal_(w) + """ + return _no_grad_normal_(tensor, mean, std, generator) + + +def trunc_normal_( + tensor: Tensor, + mean: float = 0.0, + std: float = 1.0, + a: float = -2.0, + b: float = 2.0, + generator: _Optional[core.Generator] = None, +) -> Tensor: + r"""Fill the input Tensor with values drawn from a truncated normal distribution. + + The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + + Args: + tensor: an n-dimensional `core.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + + Examples: + >>> w = core.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=generator) + + +def constant_(tensor: Tensor, val: float) -> Tensor: + r"""Fill the input Tensor with the value :math:`\text{val}`. + + Args: + tensor: an n-dimensional `core.Tensor` + val: the value to fill the tensor with + + Examples: + >>> w = core.empty(3, 5) + >>> nn.init.constant_(w, 0.3) + """ + return _no_grad_fill_(tensor, val) + + + +def ones_(tensor: Tensor) -> Tensor: + r"""Fill the input Tensor with the scalar value `1`. + + Args: + tensor: an n-dimensional `core.Tensor` + + Examples: + >>> w = core.empty(3, 5) + >>> nn.init.ones_(w) + """ + return _no_grad_fill_(tensor, 1.0) + + + +def zeros_(tensor: Tensor) -> Tensor: + r"""Fill the input Tensor with the scalar value `0`. + + Args: + tensor: an n-dimensional `core.Tensor` + + Examples: + >>> w = core.empty(3, 5) + >>> nn.init.zeros_(w) + """ + return _no_grad_zero_(tensor) + + + +def dirac_(tensor, groups=1): + r"""Fill the {3, 4, 5}-dimensional input `Tensor` with the Dirac delta function. + + Preserves the identity of the inputs in `Convolutional` + layers, where as many input channels are preserved as possible. In case + of groups>1, each group of channels preserves identity + + Args: + tensor: a {3, 4, 5}-dimensional `core.Tensor` + groups (int, optional): number of groups in the conv layer (default: 1) + Examples: + >>> w = core.empty(3, 16, 5, 5) + >>> nn.init.dirac_(w) + >>> w = core.empty(3, 24, 5, 5) + >>> nn.init.dirac_(w, 3) + """ + dimensions = tensor.ndimension() + if dimensions not in [3, 4, 5]: + raise ValueError("Only tensors with 3, 4, or 5 dimensions are supported") + + sizes = tensor.size() + + if sizes[0] % groups != 0: + raise ValueError("dim 0 must be divisible by groups") + + out_chans_per_grp = sizes[0] // groups + min_dim = min(out_chans_per_grp, sizes[1]) + + with core.no_grad(): + tensor.zero_() + + for g in range(groups): + for d in range(min_dim): + if dimensions == 3: # Temporal convolution + tensor[g * out_chans_per_grp + d, d, tensor.size(2) // 2] = 1 + elif dimensions == 4: # Spatial convolution + tensor[ + g * out_chans_per_grp + d, + d, + tensor.size(2) // 2, + tensor.size(3) // 2, + ] = 1 + else: # Volumetric convolution + tensor[ + g * out_chans_per_grp + d, + d, + tensor.size(2) // 2, + tensor.size(3) // 2, + tensor.size(4) // 2, + ] = 1 + return tensor + + +def _calculate_fan_in_and_fan_out(tensor): + dimensions = tensor.ndim + if dimensions < 2: + raise ValueError( + "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions" + ) + + num_input_fmaps = tensor.shape[1] + num_output_fmaps = tensor.shape[0] + receptive_field_size = 1 + if tensor.ndim > 2: + # math.prod is not always available, accumulate the product manually + # we could use functools.reduce but that is not supported by TorchScript + for s in tensor.shape[2:]: + receptive_field_size *= s + fan_in = num_input_fmaps * receptive_field_size + fan_out = num_output_fmaps * receptive_field_size + + return fan_in, fan_out + + +def xavier_uniform_( + tensor: Tensor, + gain: float = 1.0, +) -> Tensor: + r"""Fill the input `Tensor` with values using a Xavier uniform distribution. + + The method is described in `Understanding the difficulty of training + deep feedforward neural networks` - Glorot, X. & Bengio, Y. (2010). + The resulting tensor will have values sampled from + :math:`\mathcal{U}(-a, a)` where + + .. math:: + a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}} + + Also known as Glorot initialization. + + Args: + tensor: an n-dimensional `core.Tensor` + gain: an optional scaling factor + generator: the torch Generator to sample from (default: None) + + Examples: + >>> w = core.empty(3, 5) + >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu')) + + Note: + Be aware that ``fan_in`` and ``fan_out`` are calculated assuming + that the weight matrix is used in a transposed manner, + (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``). + This is important for correct initialization. + If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``, + pass in a transposed weight matrix, i.e. ``nn.init.xavier_uniform_(w.T, ...)``. + """ + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) + a = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation + + return uniform_(tensor, -a, a) + + +def xavier_normal_( + tensor: Tensor, + gain: float = 1.0, +) -> Tensor: + r"""Fill the input `Tensor` with values using a Xavier normal distribution. + + The method is described in `Understanding the difficulty of training deep feedforward + neural networks` - Glorot, X. & Bengio, Y. (2010). The resulting tensor + will have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where + + .. math:: + \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}} + + Also known as Glorot initialization. + + Args: + tensor: an n-dimensional `core.Tensor` + gain: an optional scaling factor + generator: the torch Generator to sample from (default: None) + + Examples: + >>> w = core.empty(3, 5) + >>> nn.init.xavier_normal_(w) + + Note: + Be aware that ``fan_in`` and ``fan_out`` are calculated assuming + that the weight matrix is used in a transposed manner, + (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``). + This is important for correct initialization. + If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``, + pass in a transposed weight matrix, i.e. ``nn.init.xavier_normal_(w.T, ...)``. + """ + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) + + return normal_(tensor, 0.0, std) + + +def _calculate_correct_fan(tensor, mode): + mode = mode.lower() + valid_modes = ["fan_in", "fan_out"] + if mode not in valid_modes: + raise ValueError(f"Mode {mode} not supported, please use one of {valid_modes}") + + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + return fan_in if mode == "fan_in" else fan_out + + +def kaiming_uniform_( + tensor: Tensor, + a: float = 0, + mode: str = "fan_in", + nonlinearity: str = "leaky_relu", +): + r"""Fill the input `Tensor` with values using a Kaiming uniform distribution. + + The method is described in `Delving deep into rectifiers: Surpassing + human-level performance on ImageNet classification` - He, K. et al. (2015). + The resulting tensor will have values sampled from + :math:`\mathcal{U}(-\text{bound}, \text{bound})` where + + .. math:: + \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}} + + Also known as He initialization. + + Args: + tensor: an n-dimensional `core.Tensor` + a: the negative slope of the rectifier used after this layer (only + used with ``'leaky_relu'``) + mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` + preserves the magnitude of the variance of the weights in the + forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the + backwards pass. + nonlinearity: the non-linear function (`nn.functional` name), + recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default). + generator: the torch Generator to sample from (default: None) + + Examples: + >>> w = core.empty(3, 5) + >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu') + + Note: + Be aware that ``fan_in`` and ``fan_out`` are calculated assuming + that the weight matrix is used in a transposed manner, + (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``). + This is important for correct initialization. + If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``, + pass in a transposed weight matrix, i.e. ``nn.init.kaiming_uniform_(w.T, ...)``. + """ + + if 0 in tensor.shape: + warnings.warn("Initializing zero-element tensors is a no-op") + return tensor + fan = _calculate_correct_fan(tensor, mode) + gain = calculate_gain(nonlinearity, a) + std = gain / math.sqrt(fan) + bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation + return uniform_(tensor, -bound, bound) + + +def kaiming_normal_( + tensor: Tensor, + a: float = 0, + mode: str = "fan_in", + nonlinearity: str = "leaky_relu", + generator: _Optional[core.Generator] = None, +): + r"""Fill the input `Tensor` with values using a Kaiming normal distribution. + + The method is described in `Delving deep into rectifiers: Surpassing + human-level performance on ImageNet classification` - He, K. et al. (2015). + The resulting tensor will have values sampled from + :math:`\mathcal{N}(0, \text{std}^2)` where + + .. math:: + \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}} + + Also known as He initialization. + + Args: + tensor: an n-dimensional `core.Tensor` + a: the negative slope of the rectifier used after this layer (only + used with ``'leaky_relu'``) + mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` + preserves the magnitude of the variance of the weights in the + forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the + backwards pass. + nonlinearity: the non-linear function (`nn.functional` name), + recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default). + generator: the torch Generator to sample from (default: None) + + Examples: + >>> w = core.empty(3, 5) + >>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu') + + Note: + Be aware that ``fan_in`` and ``fan_out`` are calculated assuming + that the weight matrix is used in a transposed manner, + (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``). + This is important for correct initialization. + If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``, + pass in a transposed weight matrix, i.e. ``nn.init.kaiming_normal_(w.T, ...)``. + """ + if 0 in tensor.shape: + warnings.warn("Initializing zero-element tensors is a no-op") + return tensor + fan = _calculate_correct_fan(tensor, mode) + gain = calculate_gain(nonlinearity, a) + std = gain / math.sqrt(fan) + with core.no_grad(): + return tensor.normal_(0, std, generator=generator) + +def orthogonal_( + tensor, + gain=1, + generator: _Optional[core.Generator] = None, +): + r"""Fill the input `Tensor` with a (semi) orthogonal matrix. + + Described in `Exact solutions to the nonlinear dynamics of learning in deep + linear neural networks` - Saxe, A. et al. (2013). The input tensor must have + at least 2 dimensions, and for tensors with more than 2 dimensions the + trailing dimensions are flattened. + + Args: + tensor: an n-dimensional `core.Tensor`, where :math:`n \geq 2` + gain: optional scaling factor + + Examples: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK) + >>> w = core.empty(3, 5) + >>> nn.init.orthogonal_(w) + """ + if tensor.ndimension() < 2: + raise ValueError("Only tensors with 2 or more dimensions are supported") + + if tensor.numel() == 0: + # no-op + return tensor + rows = tensor.size(0) + cols = tensor.numel() // rows + flattened = tensor.new_empty((rows, cols)).normal_(0, 1, generator=generator) + + if rows < cols: + flattened.t_() + + # Compute the qr factorization + q, r = core.linalg.qr(flattened) + # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf + d = core.diag(r, 0) + ph = d.sign() + q *= ph + + if rows < cols: + q.t_() + + with core.no_grad(): + tensor.view_as(q).copy_(q) + tensor.mul_(gain) + return tensor + +def sparse_( + tensor, + sparsity, + std=0.01, + generator: _Optional[core.Generator] = None, +): + r"""Fill the 2D input `Tensor` as a sparse matrix. + + The non-zero elements will be drawn from the normal distribution + :math:`\mathcal{N}(0, 0.01)`, as described in `Deep learning via + Hessian-free optimization` - Martens, J. (2010). + + Args: + tensor: an n-dimensional `core.Tensor` + sparsity: The fraction of elements in each column to be set to zero + std: the standard deviation of the normal distribution used to generate + the non-zero values + generator: the torch Generator to sample from (default: None) + + Examples: + >>> w = core.empty(3, 5) + >>> nn.init.sparse_(w, sparsity=0.1) + """ + if tensor.ndimension() != 2: + raise ValueError("Only tensors with 2 dimensions are supported") + + rows, cols = tensor.shape + num_zeros = int(math.ceil(sparsity * rows)) + + with core.no_grad(): + tensor.normal_(0, std, generator=generator) + for col_idx in range(cols): + row_indices = core.randperm(rows) + zero_indices = row_indices[:num_zeros] + tensor[zero_indices, col_idx] = 0 + return tensor + + + +uniform = uniform_ +normal = normal_ +constant = constant_ +dirac = dirac_ +xavier_uniform = xavier_uniform_ +xavier_normal = xavier_normal_ +kaiming_uniform = kaiming_uniform_ +kaiming_normal = kaiming_normal_ +orthogonal = orthogonal_ +sparse = sparse_ diff --git a/mindnlp/core/nn/modules/__init__.py b/mindnlp/core/nn/modules/__init__.py new file mode 100644 index 000000000..4611cbbcf --- /dev/null +++ b/mindnlp/core/nn/modules/__init__.py @@ -0,0 +1,22 @@ +"""new nn modules""" +from .module import Module +from .container import ModuleList, ParameterList, Sequential, ParameterDict, ModuleDict +from .linear import Linear, Identity +from .sparse import Embedding +from .normalization import LayerNorm, GroupNorm +from .dropout import Dropout, Dropout2d +from .activation import * +from .conv import Conv3d, Conv2d, Conv1d, ConvTranspose2d, ConvTranspose1d +from .padding import ZeroPad2d, ConstantPad2d, ConstantPad1d, ConstantPad3d +from .batchnorm import BatchNorm2d, BatchNorm1d +from .pooling import AdaptiveAvgPool2d, AvgPool1d, MaxPool2d, MaxPool1d, AdaptiveAvgPool1d, AvgPool2d +from .flatten import Unflatten, Flatten +from .rnn_cell import RNNCell, GRUCell, LSTMCell +from .rnn import RNN, LSTM, GRU +from .fold import Unfold, Fold +from .pixelshuffle import PixelUnshuffle, PixelShuffle +from .upsampling import Upsample, UpsamplingBilinear2d, UpsamplingNearest2d +from .loss import * +from .distance import * +from .adaptive import AdaptiveLogSoftmaxWithLoss +from .batchnorm import BatchNorm1d, BatchNorm2d, BatchNorm3d diff --git a/mindnlp/core/nn/modules/_utils.py b/mindnlp/core/nn/modules/_utils.py new file mode 100644 index 000000000..5826b9688 --- /dev/null +++ b/mindnlp/core/nn/modules/_utils.py @@ -0,0 +1,24 @@ +"""utils""" +import collections +from itertools import repeat + + +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + return parse + +_single = _ntuple(1) +_pair = _ntuple(2) +_triple = _ntuple(3) +_quadruple = _ntuple(4) + +def _reverse_repeat_tuple(t, n): + r"""Reverse the order of `t` and repeat each element for `n` times. + + This can be used to translate padding arg used by Conv and Pooling modules + to the ones used by `F.pad`. + """ + return tuple(x for x in reversed(t) for _ in range(n)) diff --git a/mindnlp/core/nn/modules/activation.py b/mindnlp/core/nn/modules/activation.py new file mode 100644 index 000000000..df06ff789 --- /dev/null +++ b/mindnlp/core/nn/modules/activation.py @@ -0,0 +1,925 @@ +"""activation""" +from typing import Optional, Tuple +from mindnlp import core +from mindnlp.core import Tensor +from ..parameter import Parameter + +from .module import Module +from .linear import Linear +from .. import functional as F +from .. import init +from ... import ops + +class GELU(Module): + r"""Applies the Gaussian Error Linear Units function: + + .. math:: \text{GELU}(x) = x * \Phi(x) + + where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution. + + When the approximate argument is 'tanh', Gelu is estimated with: + + .. math:: \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt{2 / \pi} * (x + 0.044715 * x^3))) + + Args: + approximate (str, optional): the gelu approximation algorithm to use: + ``'none'`` | ``'tanh'``. Default: ``'none'`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/GELU.png + + Examples:: + + >>> m = nn.GELU() + >>> input = core.randn(2) + >>> output = m(input) + """ + __constants__ = ['approximate'] + approximate: str + + def __init__(self, approximate: str = 'none') -> None: + super().__init__() + self.approximate = approximate + + def forward(self, input: Tensor) -> Tensor: + return F.gelu(input, approximate=self.approximate) + + def extra_repr(self) -> str: + return f'approximate={repr(self.approximate)}' + +class ReLU(Module): + r"""Applies the rectified linear unit function element-wise: + + :math:`\text{ReLU}(x) = (x)^+ = \max(0, x)` + + Args: + inplace: can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/ReLU.png + + Examples:: + + >>> m = nn.ReLU() + >>> input = core.randn(2) + >>> output = m(input) + + + An implementation of CReLU - https://arxiv.org/abs/1603.05201 + + >>> m = nn.ReLU() + >>> input = core.randn(2).unsqueeze(0) + >>> output = core.cat((m(input), m(-input))) + """ + __constants__ = ["inplace"] + inplace: bool + + def __init__(self, inplace: bool = False): + super().__init__() + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + return F.relu(input, inplace=self.inplace) + + def extra_repr(self) -> str: + inplace_str = "inplace=True" if self.inplace else "" + return inplace_str + + +class LeakyReLU(Module): + r"""Applies the LeakyReLU function element-wise. + + .. math:: + \text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x) + + + or + + .. math:: + \text{LeakyReLU}(x) = + \begin{cases} + x, & \text{ if } x \geq 0 \\ + \text{negative\_slope} \times x, & \text{ otherwise } + \end{cases} + + Args: + negative_slope: Controls the angle of the negative slope (which is used for + negative input values). Default: 1e-2 + inplace: can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(*)` where `*` means, any number of additional + dimensions + - Output: :math:`(*)`, same shape as the input + + .. image:: ../scripts/activation_images/LeakyReLU.png + + Examples:: + + >>> m = nn.LeakyReLU(0.1) + >>> input = core.randn(2) + >>> output = m(input) + """ + + __constants__ = ['inplace', 'negative_slope'] + inplace: bool + negative_slope: float + + def __init__(self, negative_slope: float = 1e-2, inplace: bool = False) -> None: + super().__init__() + self.negative_slope = negative_slope + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + return F.leaky_relu(input, self.negative_slope) + + def extra_repr(self) -> str: + inplace_str = ', inplace=True' if self.inplace else '' + return f'negative_slope={self.negative_slope}{inplace_str}' + + + +class Tanh(Module): + def forward(self, input: Tensor) -> Tensor: + return F.tanh(input) + +class Softmax(Module): + r"""Applies the Softmax function to an n-dimensional input Tensor. + + Rescales them so that the elements of the n-dimensional output Tensor + lie in the range [0,1] and sum to 1. + + Softmax is defined as: + + .. math:: + \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)} + + When the input Tensor is a sparse tensor then the unspecified + values are treated as ``-inf``. + + Shape: + - Input: :math:`(*)` where `*` means, any number of additional + dimensions + - Output: :math:`(*)`, same shape as the input + + Returns: + a Tensor of the same dimension and shape as the input with + values in the range [0, 1] + + Args: + dim (int): A dimension along which Softmax will be computed (so every slice + along dim will sum to 1). + + .. note:: + This module doesn't work directly with NLLLoss, + which expects the Log to be computed between the Softmax and itself. + Use `LogSoftmax` instead (it's faster and has better numerical properties). + + Examples:: + + >>> m = nn.Softmax(dim=1) + >>> input = core.randn(2, 3) + >>> output = m(input) + + """ + + __constants__ = ['dim'] + dim: Optional[int] + + def __init__(self, dim: Optional[int] = None) -> None: + super().__init__() + self.dim = dim + + def __setstate__(self, state): + super().__setstate__(state) + if not hasattr(self, 'dim'): + self.dim = None + + def forward(self, input: Tensor) -> Tensor: + return F.softmax(input, self.dim) + + def extra_repr(self) -> str: + return f'dim={self.dim}' + + +class LogSoftmax(Module): + r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional input Tensor. + + The LogSoftmax formulation can be simplified as: + + .. math:: + \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right) + + Shape: + - Input: :math:`(*)` where `*` means, any number of additional + dimensions + - Output: :math:`(*)`, same shape as the input + + Args: + dim (int): A dimension along which LogSoftmax will be computed. + + Returns: + a Tensor of the same dimension and shape as the input with + values in the range [-inf, 0) + + Examples:: + + >>> m = nn.LogSoftmax(dim=1) + >>> input = core.randn(2, 3) + >>> output = m(input) + """ + + __constants__ = ['dim'] + dim: Optional[int] + + def __init__(self, dim: Optional[int] = None) -> None: + super().__init__() + self.dim = dim + + def __setstate__(self, state): + super().__setstate__(state) + if not hasattr(self, 'dim'): + self.dim = None + + def forward(self, input: Tensor) -> Tensor: + return F.log_softmax(input, self.dim) + + def extra_repr(self): + return f'dim={self.dim}' + +class Sigmoid(Module): + r"""Applies the Sigmoid function element-wise. + + .. math:: + \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)} + + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/Sigmoid.png + + Examples:: + + >>> m = nn.Sigmoid() + >>> input = core.randn(2) + >>> output = m(input) + """ + + def forward(self, input: Tensor) -> Tensor: + return F.sigmoid(input) + + +class SiLU(Module): + def forward(self, input): + return F.silu(input) + + +class Mish(Module): + def forward(self, input): + return F.mish(input) + + +class ReLU6(Module): + def forward(self, input): + return F.relu6(input) + +class ELU(Module): + def forward(self, input): + return F.elu(input) + +class GLU(Module): + r"""Applies the gated linear unit function. + + :math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half + of the input matrices and :math:`b` is the second half. + + Args: + dim (int): the dimension on which to split the input. Default: -1 + + Shape: + - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional + dimensions + - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2` + + Examples:: + + >>> m = nn.GLU() + >>> input = core.randn(4, 2) + >>> output = m(input) + """ + + __constants__ = ['dim'] + dim: int + + def __init__(self, dim: int = -1) -> None: + super().__init__() + self.dim = dim + + def forward(self, input: Tensor) -> Tensor: + return F.glu(input, self.dim) + + def extra_repr(self) -> str: + return f'dim={self.dim}' + + +class Softplus(Module): + r"""Applies the Softplus function element-wise. + + .. math:: + \text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x)) + + SoftPlus is a smooth approximation to the ReLU function and can be used + to constrain the output of a machine to always be positive. + + For numerical stability the implementation reverts to the linear function + when :math:`input \times \beta > threshold`. + + Args: + beta: the :math:`\beta` value for the Softplus formulation. Default: 1 + threshold: values above this revert to a linear function. Default: 20 + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/Softplus.png + + Examples:: + + >>> m = nn.Softplus() + >>> input = core.randn(2) + >>> output = m(input) + """ + + __constants__ = ['beta', 'threshold'] + beta: float + threshold: float + + def __init__(self, beta: float = 1.0, threshold: float = 20.0) -> None: + super().__init__() + self.beta = beta + self.threshold = threshold + + def forward(self, input: Tensor) -> Tensor: + return F.softplus(input, self.beta, self.threshold) + + def extra_repr(self) -> str: + return f'beta={self.beta}, threshold={self.threshold}' + +class MultiheadAttention(Module): + r"""Allows the model to jointly attend to information + from different representation subspaces as described in the paper: + `Attention Is All You Need `_. + + Multi-Head Attention is defined as: + + .. math:: + \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O + + where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`. + + ``nn.MultiHeadAttention`` will use the optimized implementations of + ``scaled_dot_product_attention()`` when possible. + + In addition to support for the new ``scaled_dot_product_attention()`` + function, for speeding up Inference, MHA will use + fastpath inference with support for Nested Tensors, iff: + + - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor). + - inputs are batched (3D) with ``batch_first==True`` + - Either autograd is disabled (using ``core.inference_mode`` or ``core.no_grad``) or no tensor argument ``requires_grad`` + - training is disabled (using ``.eval()``) + - ``add_bias_kv`` is ``False`` + - ``add_zero_attn`` is ``False`` + - ``batch_first`` is ``True`` and the input is batched + - ``kdim`` and ``vdim`` are equal to ``embed_dim`` + - if a `NestedTensor `_ is passed, neither ``key_padding_mask`` + nor ``attn_mask`` is passed + - autocast is disabled + + If the optimized inference fastpath implementation is in use, a + `NestedTensor `_ can be passed for + ``query``/``key``/``value`` to represent padding more efficiently than using a + padding mask. In this case, a `NestedTensor `_ + will be returned, and an additional speedup proportional to the fraction of the input + that is padding can be expected. + + Args: + embed_dim: Total dimension of the model. + num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split + across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``). + dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout). + bias: If specified, adds bias to input / output projection layers. Default: ``True``. + add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``. + add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1. + Default: ``False``. + kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``). + vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``). + batch_first: If ``True``, then the input and output tensors are provided + as (batch, seq, feature). Default: ``False`` (seq, batch, feature). + + Examples:: + + >>> # xdoctest: +SKIP + >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value) + + .. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`: + https://arxiv.org/abs/2205.14135 + + """ + + __constants__ = ['batch_first'] + bias_k: Optional[core.Tensor] + bias_v: Optional[core.Tensor] + + def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, + kdim=None, vdim=None, batch_first=False, dtype=None) -> None: + factory_kwargs = {'dtype': dtype} + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout = dropout + self.batch_first = batch_first + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + + + if not self._qkv_same_embed_dim: + self.q_proj_weight = Parameter(ops.empty((embed_dim, embed_dim), **factory_kwargs)) + self.k_proj_weight = Parameter(ops.empty((embed_dim, self.kdim), **factory_kwargs)) + self.v_proj_weight = Parameter(ops.empty((embed_dim, self.vdim), **factory_kwargs)) + self.register_parameter('in_proj_weight', None) + else: + self.in_proj_weight = Parameter(ops.empty((3 * embed_dim, embed_dim), **factory_kwargs)) + self.register_parameter('q_proj_weight', None) + self.register_parameter('k_proj_weight', None) + self.register_parameter('v_proj_weight', None) + + if bias: + self.in_proj_bias = Parameter(ops.empty(3 * embed_dim, **factory_kwargs)) + else: + self.register_parameter('in_proj_bias', None) + self.out_proj = Linear(embed_dim, embed_dim, bias=bias) + + if add_bias_kv: + self.bias_k = Parameter(ops.empty((1, 1, embed_dim), **factory_kwargs)) + self.bias_v = Parameter(ops.empty((1, 1, embed_dim), **factory_kwargs)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self._reset_parameters() + + def _reset_parameters(self): + if self._qkv_same_embed_dim: + init.xavier_uniform_(self.in_proj_weight) + else: + init.xavier_uniform_(self.q_proj_weight) + init.xavier_uniform_(self.k_proj_weight) + init.xavier_uniform_(self.v_proj_weight) + + if self.in_proj_bias is not None: + init.constant_(self.in_proj_bias, 0.) + init.constant_(self.out_proj.bias, 0.) + if self.bias_k is not None: + init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + init.xavier_normal_(self.bias_v) + + def __setstate__(self, state): + # Support loading old MultiheadAttention checkpoints generated by v1.1.0 + if '_qkv_same_embed_dim' not in state: + state['_qkv_same_embed_dim'] = True + + super().__setstate__(state) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + average_attn_weights: bool = True, + is_causal : bool = False) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False`` + or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length, + :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``. + Queries are compared against key-value pairs to produce the output. + See "Attention Is All You Need" for more details. + key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False`` + or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length, + :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``. + See "Attention Is All You Need" for more details. + value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when + ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source + sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``. + See "Attention Is All You Need" for more details. + key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` + to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. + Binary and float masks are supported. + For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for + the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value. + need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``. + Set ``need_weights=False`` to use the optimized ``scaled_dot_product_attention`` + and achieve the best performance for MHA. + Default: ``True``. + attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape + :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size, + :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be + broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. + Binary and float masks are supported. For a binary mask, a ``True`` value indicates that the + corresponding position is not allowed to attend. For a float mask, the mask values will be added to + the attention weight. + If both attn_mask and key_padding_mask are supplied, their types should match. + average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across + heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an + effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads) + is_causal: If specified, applies a causal mask as attention mask. + Default: ``False``. + Warning: + ``is_causal`` provides a hint that ``attn_mask`` is the + causal mask. Providing incorrect hints can result in + incorrect execution, including forward and backward + compatibility. + + Outputs: + - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched, + :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``, + where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the + embedding dimension ``embed_dim``. + - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``, + returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or + :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and + :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per + head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`. + + .. note:: + `batch_first` argument is ignored for unbatched inputs. + """ + + is_batched = query.ndim == 3 + + key_padding_mask = F._canonical_mask( + mask=key_padding_mask, + mask_name="key_padding_mask", + other_type=F._none_or_dtype(attn_mask), + other_name="attn_mask", + target_type=query.dtype + ) + + attn_mask = F._canonical_mask( + mask=attn_mask, + mask_name="attn_mask", + other_type=None, + other_name="", + target_type=query.dtype, + check_other=False, + ) + + if self.batch_first and is_batched: + # make sure that the transpose op does not affect the "is" property + if key is value: + if query is key: + query = key = value = ops.transpose(query, 1, 0) + else: + query, key = (ops.transpose(x, 1, 0) for x in (query, key)) + value = key + else: + query, key, value = (ops.transpose(x, 1, 0) for x in (query, key, value)) + + if not self._qkv_same_embed_dim: + attn_output, attn_output_weights = F.multi_head_attention_forward( + query, key, value, self.embed_dim, self.num_heads, + self.in_proj_weight, self.in_proj_bias, + self.bias_k, self.bias_v, self.add_zero_attn, + self.dropout, self.out_proj.weight, self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, need_weights=need_weights, + attn_mask=attn_mask, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, + v_proj_weight=self.v_proj_weight, + average_attn_weights=average_attn_weights, + is_causal=is_causal) + else: + attn_output, attn_output_weights = F.multi_head_attention_forward( + query, key, value, self.embed_dim, self.num_heads, + self.in_proj_weight, self.in_proj_bias, + self.bias_k, self.bias_v, self.add_zero_attn, + self.dropout, self.out_proj.weight, self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + average_attn_weights=average_attn_weights, + is_causal=is_causal) + if self.batch_first and is_batched: + return ops.transpose(attn_output, 1, 0), attn_output_weights + else: + return attn_output, attn_output_weights + + def merge_masks(self, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], + query: Tensor) -> Tuple[Optional[Tensor], Optional[int]]: + r""" + Determine mask type and combine masks if necessary. If only one mask is provided, that mask + and the corresponding mask type will be returned. If both masks are provided, they will be both + expanded to shape ``(batch_size, num_heads, seq_len, seq_len)``, combined with logical ``or`` + and mask type 2 will be returned + Args: + attn_mask: attention mask of shape ``(seq_len, seq_len)``, mask type 0 + key_padding_mask: padding mask of shape ``(batch_size, seq_len)``, mask type 1 + query: query embeddings of shape ``(batch_size, seq_len, embed_dim)`` + Returns: + merged_mask: merged mask + mask_type: merged mask type (0, 1, or 2) + """ + mask_type: Optional[int] = None + merged_mask: Optional[Tensor] = None + + if key_padding_mask is not None: + mask_type = 1 + merged_mask = key_padding_mask + + if attn_mask is not None: + # In this branch query can't be a nested tensor, so it has a shape + batch_size, seq_len, _ = query.shape + mask_type = 2 + + # Always expands attn_mask to 4D + if attn_mask.ndim == 3: + attn_mask_expanded = attn_mask.view(batch_size, -1, seq_len, seq_len) + else: # attn_mask.ndim == 2: + attn_mask_expanded = attn_mask.view(1, 1, seq_len, seq_len).expand(batch_size, self.num_heads, -1, -1) + merged_mask = attn_mask_expanded + + if key_padding_mask is not None: + key_padding_mask_expanded = key_padding_mask.view(batch_size, 1, 1, seq_len).expand(-1, self.num_heads, -1, -1) + merged_mask = attn_mask_expanded + key_padding_mask_expanded + + # no attn_mask and no key_padding_mask, returns None, None + return merged_mask, mask_type + +class PReLU(Module): + r"""Applies the element-wise PReLU function. + + .. math:: + \text{PReLU}(x) = \max(0,x) + a * \min(0,x) + + or + + .. math:: + \text{PReLU}(x) = + \begin{cases} + x, & \text{ if } x \ge 0 \\ + ax, & \text{ otherwise } + \end{cases} + + Here :math:`a` is a learnable parameter. When called without arguments, `nn.PReLU()` uses a single + parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`, + a separate :math:`a` is used for each input channel. + + + .. note:: + weight decay should not be used when learning :math:`a` for good performance. + + .. note:: + Channel dim is the 2nd dim of input. When input has dims < 2, then there is + no channel dim and the number of channels = 1. + + Args: + num_parameters (int): number of :math:`a` to learn. + Although it takes an int as input, there is only two values are legitimate: + 1, or the number of channels at input. Default: 1 + init (float): the initial value of :math:`a`. Default: 0.25 + + Shape: + - Input: :math:`( *)` where `*` means, any number of additional + dimensions. + - Output: :math:`(*)`, same shape as the input. + + Attributes: + weight (Tensor): the learnable weights of shape (:attr:`num_parameters`). + + .. image:: ../scripts/activation_images/PReLU.png + + Examples:: + + >>> m = nn.PReLU() + >>> input = core.randn(2) + >>> output = m(input) + """ + + __constants__ = ["num_parameters"] + num_parameters: int + + def __init__( + self, num_parameters: int = 1, init: float = 0.25, dtype=None + ) -> None: + factory_kwargs = {"dtype": dtype} + self.num_parameters = num_parameters + super().__init__() + self.init = init + self.weight = Parameter(ops.empty(num_parameters, **factory_kwargs)) + self.reset_parameters() + + def reset_parameters(self): + init.constant_(self.weight, self.init) + + def forward(self, input: Tensor) -> Tensor: + return F.prelu(input, self.weight) + + def extra_repr(self) -> str: + return f"num_parameters={self.num_parameters}" + +class CELU(Module): + r"""Applies the CELU function element-wise. + + .. math:: + \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1)) + + More details can be found in the paper `Continuously Differentiable Exponential Linear Units`_ . + + Args: + alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0 + inplace: can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/CELU.png + + Examples:: + + >>> m = nn.CELU() + >>> input = core.randn(2) + >>> output = m(input) + + .. _`Continuously Differentiable Exponential Linear Units`: + https://arxiv.org/abs/1704.07483 + """ + + __constants__ = ["alpha", "inplace"] + alpha: float + inplace: bool + + def __init__(self, alpha: float = 1.0, inplace: bool = False) -> None: + super().__init__() + self.alpha = alpha + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + return F.celu(input, self.alpha, self.inplace) + + def extra_repr(self) -> str: + inplace_str = ", inplace=True" if self.inplace else "" + return f"alpha={self.alpha}{inplace_str}" + +class SELU(Module): + r"""Applies the SELU function element-wise. + + .. math:: + \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1))) + + with :math:`\alpha = 1.6732632423543772848170429916717` and + :math:`\text{scale} = 1.0507009873554804934193349852946`. + + .. warning:: + When using ``kaiming_normal`` or ``kaiming_normal_`` for initialisation, + ``nonlinearity='linear'`` should be used instead of ``nonlinearity='selu'`` + in order to get `Self-Normalizing Neural Networks`_. + See :func:`core.nn.init.calculate_gain` for more information. + + More details can be found in the paper `Self-Normalizing Neural Networks`_ . + + Args: + inplace (bool, optional): can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/SELU.png + + Examples:: + + >>> m = nn.SELU() + >>> input = core.randn(2) + >>> output = m(input) + + .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515 + """ + + __constants__ = ["inplace"] + inplace: bool + + def __init__(self, inplace: bool = False) -> None: + super().__init__() + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + return F.selu(input) + + def extra_repr(self) -> str: + inplace_str = "inplace=True" if self.inplace else "" + return inplace_str + + +class Hardsigmoid(Module): + r"""Applies the Hardsigmoid function element-wise. + + Hardsigmoid is defined as: + + .. math:: + \text{Hardsigmoid}(x) = \begin{cases} + 0 & \text{if~} x \le -3, \\ + 1 & \text{if~} x \ge +3, \\ + x / 6 + 1 / 2 & \text{otherwise} + \end{cases} + + Args: + inplace: can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/Hardsigmoid.png + + Examples:: + + >>> m = nn.Hardsigmoid() + >>> input = core.randn(2) + >>> output = m(input) + """ + + __constants__ = ["inplace"] + + inplace: bool + + def __init__(self, inplace: bool = False) -> None: + super().__init__() + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + return F.hardsigmoid(input, self.inplace) + + +class Hardswish(Module): + r"""Applies the Hardswish function, element-wise. + + Method described in the paper: `Searching for MobileNetV3 `_. + + Hardswish is defined as: + + .. math:: + \text{Hardswish}(x) = \begin{cases} + 0 & \text{if~} x \le -3, \\ + x & \text{if~} x \ge +3, \\ + x \cdot (x + 3) /6 & \text{otherwise} + \end{cases} + + Args: + inplace: can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/Hardswish.png + + Examples:: + + >>> m = nn.Hardswish() + >>> input = core.randn(2) + >>> output = m(input) + """ + + __constants__ = ["inplace"] + + inplace: bool + + def __init__(self, inplace: bool = False) -> None: + super().__init__() + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + return F.hardswish(input, self.inplace) diff --git a/mindnlp/core/nn/modules/adaptive.py b/mindnlp/core/nn/modules/adaptive.py new file mode 100644 index 000000000..d89e4a56d --- /dev/null +++ b/mindnlp/core/nn/modules/adaptive.py @@ -0,0 +1,316 @@ +"""adaptive""" +# mypy: allow-untyped-defs + +from collections import namedtuple +from typing import List, Sequence + +from mindnlp.core import Tensor + +from . import Sequential, ModuleList, Linear +from .module import Module +from ..functional import log_softmax +from ... import ops + +__all__ = ['AdaptiveLogSoftmaxWithLoss'] + +_ASMoutput = namedtuple('_ASMoutput', ['output', 'loss']) + + + +class AdaptiveLogSoftmaxWithLoss(Module): + r"""Efficient softmax approximation. + + As described in + `Efficient softmax approximation for GPUs by Edouard Grave, Armand Joulin, + Moustapha Ciss\u00e9, David Grangier, and Herv\u00e9 J\u00e9gou + `__. + + Adaptive softmax is an approximate strategy for training models with large + output spaces. It is most effective when the label distribution is highly + imbalanced, for example in natural language modelling, where the word + frequency distribution approximately follows the `Zipf's law`_. + + Adaptive softmax partitions the labels into several clusters, according to + their frequency. These clusters may contain different number of targets + each. + Additionally, clusters containing less frequent labels assign lower + dimensional embeddings to those labels, which speeds up the computation. + For each minibatch, only clusters for which at least one target is + present are evaluated. + + The idea is that the clusters which are accessed frequently + (like the first one, containing most frequent labels), should also be cheap + to compute -- that is, contain a small number of assigned labels. + + We highly recommend taking a look at the original paper for more details. + + * :attr:`cutoffs` should be an ordered Sequence of integers sorted + in the increasing order. + It controls number of clusters and the partitioning of targets into + clusters. For example setting ``cutoffs = [10, 100, 1000]`` + means that first `10` targets will be assigned + to the 'head' of the adaptive softmax, targets `11, 12, ..., 100` will be + assigned to the first cluster, and targets `101, 102, ..., 1000` will be + assigned to the second cluster, while targets + `1001, 1002, ..., n_classes - 1` will be assigned + to the last, third cluster. + + * :attr:`div_value` is used to compute the size of each additional cluster, + which is given as + :math:`\left\lfloor\frac{\texttt{in\_features}}{\texttt{div\_value}^{idx}}\right\rfloor`, + where :math:`idx` is the cluster index (with clusters + for less frequent words having larger indices, + and indices starting from :math:`1`). + + * :attr:`head_bias` if set to True, adds a bias term to the 'head' of the + adaptive softmax. See paper for details. Set to False in the official + implementation. + + .. warning:: + Labels passed as inputs to this module should be sorted according to + their frequency. This means that the most frequent label should be + represented by the index `0`, and the least frequent + label should be represented by the index `n_classes - 1`. + + .. note:: + This module returns a ``NamedTuple`` with ``output`` + and ``loss`` fields. See further documentation for details. + + .. note:: + To compute log-probabilities for all classes, the ``log_prob`` + method can be used. + + Args: + in_features (int): Number of features in the input tensor + n_classes (int): Number of classes in the dataset + cutoffs (Sequence): Cutoffs used to assign targets to their buckets + div_value (float, optional): value used as an exponent to compute sizes + of the clusters. Default: 4.0 + head_bias (bool, optional): If ``True``, adds a bias term to the 'head' of the + adaptive softmax. Default: ``False`` + + Returns: + ``NamedTuple`` with ``output`` and ``loss`` fields: + * **output** is a Tensor of size ``N`` containing computed target + log probabilities for each example + * **loss** is a Scalar representing the computed negative + log likelihood loss + + Shape: + - input: :math:`(N, \texttt{in\_features})` or :math:`(\texttt{in\_features})` + - target: :math:`(N)` or :math:`()` where each value satisfies :math:`0 <= \texttt{target[i]} <= \texttt{n\_classes}` + - output1: :math:`(N)` or :math:`()` + - output2: ``Scalar`` + + .. _Zipf's law: https://en.wikipedia.org/wiki/Zipf%27s_law + """ + + in_features: int + n_classes: int + cutoffs: List[int] + div_value: float + head_bias: bool + head: Linear + tail: ModuleList + + def __init__( + self, + in_features: int, + n_classes: int, + cutoffs: Sequence[int], + div_value: float = 4., + head_bias: bool = False, + dtype=None + ) -> None: + factory_kwargs = {'dtype': dtype} + super().__init__() + + cutoffs = list(cutoffs) + + if (len(cutoffs) == 0): + raise ValueError("cutoffs should be a sequence of length larger than 0") + + if (cutoffs != sorted(cutoffs)) \ + or (min(cutoffs) <= 0) \ + or (max(cutoffs) > (n_classes - 1)) \ + or (len(set(cutoffs)) != len(cutoffs)) \ + or any(int(c) != c for c in cutoffs): + + raise ValueError("cutoffs should be a sequence of unique, positive " + "integers sorted in an increasing order, where " + "each value is between 1 and n_classes-1") + + self.in_features = in_features + self.n_classes = n_classes + self.cutoffs = cutoffs + [n_classes] + self.div_value = div_value + self.head_bias = head_bias + + self.shortlist_size = self.cutoffs[0] + self.n_clusters = len(self.cutoffs) - 1 + self.head_size = self.shortlist_size + self.n_clusters + + self.head = Linear(self.in_features, self.head_size, bias=self.head_bias, + **factory_kwargs) + self.tail = ModuleList() + + for i in range(self.n_clusters): + + hsz = int(self.in_features // (self.div_value ** (i + 1))) + osz = self.cutoffs[i + 1] - self.cutoffs[i] + + projection = Sequential( + Linear(self.in_features, hsz, bias=False, **factory_kwargs), + Linear(hsz, osz, bias=False, **factory_kwargs), + ) + + self.tail.append(projection) + + def reset_parameters(self) -> None: + self.head.reset_parameters() + for i2h, h2o in self.tail: + i2h.reset_parameters() + h2o.reset_parameters() + + def forward(self, input_: Tensor, target_: Tensor) -> _ASMoutput: + targ_dim = target_.ndim + + if targ_dim == 1: + if input_.shape[0] != target_.shape[0]: + raise RuntimeError('Input and target should have the same size ' + 'in the batch dimension.') + if input_.ndim != 2: + raise RuntimeError('1D target tensor expects 2D input tensors, ' + 'but found inputs with size', input_.shape) + elif targ_dim == 0: + if input_.ndim != 1: + raise RuntimeError('0D target tensor expects 1D input tensors, ' + 'but found inputs with size', input_.shape) + else: + raise RuntimeError('0D or 1D target tensor expected, ' + 'multi-target not supported') + + is_batched = targ_dim > 0 + input = input_ if is_batched else input_.unsqueeze(0) + target = target_ if is_batched else target_.unsqueeze(0) + + used_rows = 0 + batch_size = target.shape[0] + + output = ops.zeros(batch_size, dtype=input.dtype) + gather_inds = ops.zeros(batch_size, dtype=target.dtype) + + cutoff_values = [0] + self.cutoffs + for i in range(len(cutoff_values) - 1): + + low_idx = cutoff_values[i] + high_idx = cutoff_values[i + 1] + + target_mask = (target >= low_idx) & (target < high_idx) + row_indices = ops.nonzero(target_mask).squeeze() + + if row_indices.numel() == 0: + continue + + if i == 0: + gather_inds = ops.index_add(gather_inds, 0, row_indices, target[target_mask]) + + else: + relative_target = target[target_mask] - low_idx + input_subset = input.index_select(0, row_indices) + + cluster_output = self.tail[i - 1](input_subset) + cluster_index = self.shortlist_size + i - 1 + + gather_inds = ops.index_fill(gather_inds, 0, row_indices, cluster_index) + cluster_logprob = log_softmax(cluster_output, dim=1) + local_logprob = cluster_logprob.gather(1, relative_target.unsqueeze(1)) + output = ops.index_add(output, 0, row_indices, local_logprob.squeeze(1)) + + used_rows += row_indices.numel() + + if used_rows != batch_size: + raise RuntimeError(f"Target values should be in [0, {self.n_classes - 1}], " + f"but values in range [{target.min().item()}, {target.max().item()}] " + "were found. ") + + head_output = self.head(input) + head_logprob = log_softmax(head_output, dim=1) + output += ops.gather(head_logprob, 1, gather_inds.unsqueeze(1)).squeeze() + loss = (-output).mean() + + if not is_batched: + output = output.squeeze(0) + + return _ASMoutput(output, loss) + + def _get_full_log_prob(self, input, head_output): + """Given input tensor, and output of ``self.head``, compute the log of the full distribution.""" + out = ops.zeros((head_output.shape[0], self.n_classes), dtype=input.dtype) + head_logprob = log_softmax(head_output, dim=1) + + out[:, :self.shortlist_size] = head_logprob[:, :self.shortlist_size] + + for i, (start_idx, stop_idx) in enumerate(zip(self.cutoffs, self.cutoffs[1:])): + cluster_output = self.tail[i](input) + cluster_logprob = log_softmax(cluster_output, dim=1) + output_logprob = cluster_logprob + head_logprob[:, self.shortlist_size + i].unsqueeze(1) + + out[:, start_idx:stop_idx] = output_logprob + + return out + + + def log_prob(self, input: Tensor) -> Tensor: + r"""Compute log probabilities for all :math:`\texttt{n\_classes}`. + + Args: + input (Tensor): a minibatch of examples + + Returns: + log-probabilities of for each class :math:`c` + in range :math:`0 <= c <= \texttt{n\_classes}`, where :math:`\texttt{n\_classes}` is a + parameter passed to ``AdaptiveLogSoftmaxWithLoss`` constructor. + + Shape: + - Input: :math:`(N, \texttt{in\_features})` + - Output: :math:`(N, \texttt{n\_classes})` + + """ + head_output = self.head(input) + return self._get_full_log_prob(input, head_output) + + + + def predict(self, input: Tensor) -> Tensor: + r"""Return the class with the highest probability for each example in the input minibatch. + + This is equivalent to ``self.log_prob(input).argmax(dim=1)``, but is more efficient in some cases. + + Args: + input (Tensor): a minibatch of examples + + Returns: + output (Tensor): a class with the highest probability for each example + + Shape: + - Input: :math:`(N, \texttt{in\_features})` + - Output: :math:`(N)` + """ + head_output = self.head(input) + output = ops.argmax(head_output, dim=1) + not_in_shortlist = (output >= self.shortlist_size) + all_in_shortlist = not (not_in_shortlist.any()) + + if all_in_shortlist: + return output + + elif not_in_shortlist.all(): + log_prob = self._get_full_log_prob(input, head_output) + return ops.argmax(log_prob, dim=1) + + else: + log_prob = self._get_full_log_prob(input[not_in_shortlist], + head_output[not_in_shortlist]) + output[not_in_shortlist] = ops.argmax(log_prob, dim=1) + return output diff --git a/mindnlp/core/nn/modules/batchnorm.py b/mindnlp/core/nn/modules/batchnorm.py new file mode 100644 index 000000000..66ea0639b --- /dev/null +++ b/mindnlp/core/nn/modules/batchnorm.py @@ -0,0 +1,372 @@ +"""batch norm""" +from typing import Optional +from mindnlp.core import Tensor +from ..parameter import Parameter + +from .module import Module +from .. import init +from ... import ops +from .. import functional as F + +class _NormBase(Module): + """Common base of _InstanceNorm and _BatchNorm.""" + + _version = 2 + __constants__ = ["track_running_stats", "momentum", "eps", "num_features", "affine"] + num_features: int + eps: float + momentum: float + affine: bool + track_running_stats: bool + # WARNING: weight and bias purposely not defined here. + # See https://github.com/pytorch/pytorch/issues/39670 + + def __init__( + self, + num_features: int, + eps: float = 1e-5, + momentum: float = 0.1, + affine: bool = True, + track_running_stats: bool = True, + dtype=None + ) -> None: + factory_kwargs = {'dtype': dtype} + super().__init__() + self.num_features = num_features + self.eps = eps + self.momentum = momentum + self.affine = affine + self.track_running_stats = track_running_stats + self.weight = Parameter(ops.empty(num_features, **factory_kwargs), affine) + self.bias = Parameter(ops.empty(num_features, **factory_kwargs), affine) + if self.track_running_stats: + self.register_buffer('running_mean', ops.zeros(num_features,)) + self.register_buffer('running_var', ops.ones(num_features,)) + self.running_mean: Optional[Tensor] + self.running_var: Optional[Tensor] + self.register_buffer('num_batches_tracked', + Tensor(0, dtype=core.int64)) + self.num_batches_tracked: Optional[Tensor] + else: + self.register_buffer("running_mean", None) + self.register_buffer("running_var", None) + self.register_buffer("num_batches_tracked", None) + + def reset_running_stats(self) -> None: + if self.track_running_stats: + # running_mean/running_var/num_batches... are registered at runtime depending + # if self.track_running_stats is on + init.zeros_(self.running_mean) # type: ignore[union-attr] + init.ones_(self.running_var) # type: ignore[union-attr] + init.zeros_(self.num_batches_tracked) # type: ignore[union-attr,operator] + + def reset_parameters(self) -> None: + self.reset_running_stats() + if self.affine: + init.ones_(self.weight) + init.zeros_(self.bias) + + def _check_input_dim(self, input): + raise NotImplementedError + + def extra_repr(self): + return ( + "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, " + "track_running_stats={track_running_stats}".format(**self.__dict__) + ) + + +class _BatchNorm(_NormBase): + def __init__( + self, + num_features: int, + eps: float = 1e-5, + momentum: float = 0.1, + affine: bool = True, + track_running_stats: bool = True, + dtype=None + ) -> None: + factory_kwargs = {'dtype': dtype} + super().__init__( + num_features, eps, momentum, affine, track_running_stats, **factory_kwargs + ) + + def forward(self, input: Tensor) -> Tensor: + self._check_input_dim(input) + + # exponential_average_factor is set to self.momentum + # (when it is available) only so that it gets updated + # in ONNX graph when this node is exported to ONNX. + if self.momentum is None: + exponential_average_factor = 0.0 + else: + exponential_average_factor = self.momentum + + if self.training and self.track_running_stats: + # TODO: if statement only here to tell the jit to skip emitting this when it is None + if self.num_batches_tracked is not None: # type: ignore[has-type] + self.num_batches_tracked = self.num_batches_tracked + 1 # type: ignore[has-type] + if self.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / float(self.num_batches_tracked) + else: # use exponential moving average + exponential_average_factor = self.momentum + + r""" + Decide whether the mini-batch stats should be used for normalization rather than the buffers. + Mini-batch stats are used in training mode, and in eval mode when buffers are None. + """ + if self.training: + bn_training = True + else: + bn_training = (self.running_mean is None) and (self.running_var is None) + + r""" + Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be + passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are + used for normalization (i.e. in eval mode when buffers are not None). + """ + return F.batch_norm( + input, + # If buffers are not to be tracked, ensure that they won't be updated + self.running_mean + if not self.training or self.track_running_stats + else None, + self.running_var if not self.training or self.track_running_stats else None, + self.weight, + self.bias, + bn_training, + exponential_average_factor, + self.eps, + ) + + +class BatchNorm1d(_BatchNorm): + r"""Applies Batch Normalization over a 2D or 3D input. + + Method described in the paper + `Batch Normalization: Accelerating Deep Network Training by Reducing + Internal Covariate Shift `__ . + + .. math:: + + y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors + of size `C` (where `C` is the number of features or channels of the input). By default, the + elements of :math:`\gamma` are set to 1 and the elements of :math:`\beta` are set to 0. + At train time in the forward pass, the standard-deviation is calculated via the biased estimator, + equivalent to ``core.var(input, unbiased=False)``. However, the value stored in the + moving average of the standard-deviation is calculated via the unbiased estimator, equivalent to + ``core.var(input, unbiased=True)``. + + Also by default, during training this layer keeps running estimates of its + computed mean and variance, which are then used for normalization during + evaluation. The running estimates are kept with a default :attr:`momentum` + of 0.1. + + If :attr:`track_running_stats` is set to ``False``, this layer then does not + keep running estimates, and batch statistics are instead used during + evaluation time as well. + + .. note:: + This :attr:`momentum` argument is different from one used in optimizer + classes and the conventional notion of momentum. Mathematically, the + update rule for running statistics here is + :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, + where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the + new observed value. + + Because the Batch Normalization is done over the `C` dimension, computing statistics + on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization. + + Args: + num_features: number of features or channels :math:`C` of the input + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Can be set to ``None`` for cumulative moving average + (i.e. simple average). Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters. Default: ``True`` + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics, and initializes statistics + buffers :attr:`running_mean` and :attr:`running_var` as ``None``. + When these buffers are ``None``, this module always uses batch statistics. + in both training and eval modes. Default: ``True`` + + Shape: + - Input: :math:`(N, C)` or :math:`(N, C, L)`, where :math:`N` is the batch size, + :math:`C` is the number of features or channels, and :math:`L` is the sequence length + - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) + + Examples:: + + >>> # With Learnable Parameters + >>> m = nn.BatchNorm1d(100) + >>> # Without Learnable Parameters + >>> m = nn.BatchNorm1d(100, affine=False) + >>> input = core.randn(20, 100) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 2 and input.dim() != 3: + raise ValueError( + f"expected 2D or 3D input (got {input.dim()}D input)" + ) + + +class BatchNorm2d(_BatchNorm): + r"""Applies Batch Normalization over a 4D input. + + 4D is a mini-batch of 2D inputs + with additional channel dimension. Method described in the paper + `Batch Normalization: Accelerating Deep Network Training by Reducing + Internal Covariate Shift `__ . + + .. math:: + + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors + of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set + to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the + standard-deviation is calculated via the biased estimator, equivalent to + ``core.var(input, unbiased=False)``. However, the value stored in the moving average of the + standard-deviation is calculated via the unbiased estimator, equivalent to + ``core.var(input, unbiased=True)``. + + Also by default, during training this layer keeps running estimates of its + computed mean and variance, which are then used for normalization during + evaluation. The running estimates are kept with a default :attr:`momentum` + of 0.1. + + If :attr:`track_running_stats` is set to ``False``, this layer then does not + keep running estimates, and batch statistics are instead used during + evaluation time as well. + + .. note:: + This :attr:`momentum` argument is different from one used in optimizer + classes and the conventional notion of momentum. Mathematically, the + update rule for running statistics here is + :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, + where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the + new observed value. + + Because the Batch Normalization is done over the `C` dimension, computing statistics + on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization. + + Args: + num_features: :math:`C` from an expected input of size + :math:`(N, C, H, W)` + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Can be set to ``None`` for cumulative moving average + (i.e. simple average). Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters. Default: ``True`` + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics, and initializes statistics + buffers :attr:`running_mean` and :attr:`running_var` as ``None``. + When these buffers are ``None``, this module always uses batch statistics. + in both training and eval modes. Default: ``True`` + + Shape: + - Input: :math:`(N, C, H, W)` + - Output: :math:`(N, C, H, W)` (same shape as input) + + Examples:: + + >>> # With Learnable Parameters + >>> m = nn.BatchNorm2d(100) + >>> # Without Learnable Parameters + >>> m = nn.BatchNorm2d(100, affine=False) + >>> input = core.randn(20, 100, 35, 45) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 4: + raise ValueError(f"expected 4D input (got {input.dim()}D input)") + + +class BatchNorm3d(_BatchNorm): + r"""Applies Batch Normalization over a 5D input. + + 5D is a mini-batch of 3D inputs with additional channel dimension as described in the paper + `Batch Normalization: Accelerating Deep Network Training by Reducing + Internal Covariate Shift `__ . + + .. math:: + + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors + of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set + to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the + standard-deviation is calculated via the biased estimator, equivalent to + ``core.var(input, unbiased=False)``. However, the value stored in the moving average of the + standard-deviation is calculated via the unbiased estimator, equivalent to + ``core.var(input, unbiased=True)``. + + Also by default, during training this layer keeps running estimates of its + computed mean and variance, which are then used for normalization during + evaluation. The running estimates are kept with a default :attr:`momentum` + of 0.1. + + If :attr:`track_running_stats` is set to ``False``, this layer then does not + keep running estimates, and batch statistics are instead used during + evaluation time as well. + + .. note:: + This :attr:`momentum` argument is different from one used in optimizer + classes and the conventional notion of momentum. Mathematically, the + update rule for running statistics here is + :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, + where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the + new observed value. + + Because the Batch Normalization is done over the `C` dimension, computing statistics + on `(N, D, H, W)` slices, it's common terminology to call this Volumetric Batch Normalization + or Spatio-temporal Batch Normalization. + + Args: + num_features: :math:`C` from an expected input of size + :math:`(N, C, D, H, W)` + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Can be set to ``None`` for cumulative moving average + (i.e. simple average). Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters. Default: ``True`` + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics, and initializes statistics + buffers :attr:`running_mean` and :attr:`running_var` as ``None``. + When these buffers are ``None``, this module always uses batch statistics. + in both training and eval modes. Default: ``True`` + + Shape: + - Input: :math:`(N, C, D, H, W)` + - Output: :math:`(N, C, D, H, W)` (same shape as input) + + Examples:: + + >>> # With Learnable Parameters + >>> m = nn.BatchNorm3d(100) + >>> # Without Learnable Parameters + >>> m = nn.BatchNorm3d(100, affine=False) + >>> input = core.randn(20, 100, 35, 45, 10) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 5: + raise ValueError(f"expected 5D input (got {input.dim()}D input)") diff --git a/mindnlp/core/nn/modules/container.py b/mindnlp/core/nn/modules/container.py new file mode 100644 index 000000000..fc99187b8 --- /dev/null +++ b/mindnlp/core/nn/modules/container.py @@ -0,0 +1,826 @@ +"""Container""" +import operator +from itertools import chain, islice +from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, Tuple, Union, TypeVar +from collections import OrderedDict, abc as container_abcs +from typing_extensions import Self + +from mindnlp import core +from ..parameter import Parameter + +from .module import Module + +T = TypeVar('T', bound=Module) + +def _addindent(s_, numSpaces): + s = s_.split('\n') + # don't do anything for single-line stuff + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(numSpaces * ' ') + line for line in s] + s = '\n'.join(s) + s = first + '\n' + s + return s + +class Sequential(Module): + r"""A sequential container. + + Modules will be added to it in the order they are passed in the + constructor. Alternatively, an ``OrderedDict`` of modules can be + passed in. The ``forward()`` method of ``Sequential`` accepts any + input and forwards it to the first module it contains. It then + "chains" outputs to inputs sequentially for each subsequent module, + finally returning the output of the last module. + + The value a ``Sequential`` provides over manually calling a sequence + of modules is that it allows treating the whole container as a + single module, such that performing a transformation on the + ``Sequential`` applies to each of the modules it stores (which are + each a registered submodule of the ``Sequential``). + + What's the difference between a ``Sequential`` and a + :class:`core.nn.ModuleList`? A ``ModuleList`` is exactly what it + sounds like--a list for storing ``Module`` s! On the other hand, + the layers in a ``Sequential`` are connected in a cascading way. + + Example:: + + # Using Sequential to create a small model. When `model` is run, + # input will first be passed to `Conv2d(1,20,5)`. The output of + # `Conv2d(1,20,5)` will be used as the input to the first + # `ReLU`; the output of the first `ReLU` will become the input + # for `Conv2d(20,64,5)`. Finally, the output of + # `Conv2d(20,64,5)` will be used as input to the second `ReLU` + model = nn.Sequential( + nn.Conv2d(1,20,5), + nn.ReLU(), + nn.Conv2d(20,64,5), + nn.ReLU() + ) + + # Using Sequential with OrderedDict. This is functionally the + # same as the above code + model = nn.Sequential(OrderedDict([ + ('conv1', nn.Conv2d(1,20,5)), + ('relu1', nn.ReLU()), + ('conv2', nn.Conv2d(20,64,5)), + ('relu2', nn.ReLU()) + ])) + """ + + _modules: Dict[str, Module] # type: ignore[assignment] + + def __init__(self, *args): + super().__init__() + if len(args) == 1 and isinstance(args[0], OrderedDict): + for key, module in args[0].items(): + self.add_module(key, module) + else: + for idx, module in enumerate(args): + self.add_module(str(idx), module) + + def _get_item_by_idx(self, iterator, idx) -> T: # type: ignore[misc, type-var] + """Get the idx-th item of the iterator.""" + size = len(self) + idx = operator.index(idx) + if not -size <= idx < size: + raise IndexError(f'index {idx} is out of range') + idx %= size + return next(islice(iterator, idx, None)) + + def __getitem__(self, idx: Union[slice, int]) -> Union['Sequential', T]: + if isinstance(idx, slice): + return self.__class__(OrderedDict(list(self._modules.items())[idx])) + else: + return self._get_item_by_idx(self._modules.values(), idx) + + def __setitem__(self, idx: int, module: Module) -> None: + key: str = self._get_item_by_idx(self._modules.keys(), idx) + return setattr(self, key, module) + + def __delitem__(self, idx: Union[slice, int]) -> None: + if isinstance(idx, slice): + for key in list(self._modules.keys())[idx]: + delattr(self, key) + else: + key = self._get_item_by_idx(self._modules.keys(), idx) + delattr(self, key) + # To preserve numbering + str_indices = [str(i) for i in range(len(self._modules))] + self._modules = OrderedDict(list(zip(str_indices, self._modules.values()))) + + def __len__(self) -> int: + return len(self._modules) + + def __add__(self, other) -> 'Sequential': + if isinstance(other, Sequential): + ret = Sequential() + for layer in self: + ret.append(layer) + for layer in other: + ret.append(layer) + return ret + else: + raise ValueError('add operator supports only objects ' + f'of Sequential class, but {str(type(other))} is given.') + + def pop(self, key: Union[int, slice]) -> Module: + v = self[key] + del self[key] + return v + + def __iadd__(self, other) -> Self: + if isinstance(other, Sequential): + offset = len(self) + for i, module in enumerate(other): + self.add_module(str(i + offset), module) + return self + else: + raise ValueError('add operator supports only objects ' + f'of Sequential class, but {str(type(other))} is given.') + + def __mul__(self, other: int) -> 'Sequential': + if not isinstance(other, int): + raise TypeError(f"unsupported operand type(s) for *: {type(self)} and {type(other)}") + elif (other <= 0): + raise ValueError(f"Non-positive multiplication factor {other} for {type(self)}") + else: + combined = Sequential() + offset = 0 + for _ in range(other): + for module in self: + combined.add_module(str(offset), module) + offset += 1 + return combined + + def __rmul__(self, other: int) -> 'Sequential': + return self.__mul__(other) + + def __imul__(self, other: int) -> Self: + if not isinstance(other, int): + raise TypeError(f"unsupported operand type(s) for *: {type(self)} and {type(other)}") + elif (other <= 0): + raise ValueError(f"Non-positive multiplication factor {other} for {type(self)}") + else: + len_original = len(self) + offset = len(self) + for _ in range(other - 1): + for i in range(len_original): + self.add_module(str(i + offset), self._modules[str(i)]) + offset += len_original + return self + + def __dir__(self): + keys = super().__dir__() + keys = [key for key in keys if not key.isdigit()] + return keys + + def __iter__(self) -> Iterator[Module]: + return iter(self._modules.values()) + + # NB: We can't really type check this function as the type of input + # may change dynamically (as is tested in + # TestScript.test_sequential_intermediary_types). Cannot annotate + # with Any as TorchScript expects a more precise type + def forward(self, input): + if self.__ms_class__: + return self.jit_forward(input) + return self.slow_forward(input) + + def slow_forward(self, input): + for module in self: + input = module(input) + return input + + def jit_forward(self, input): + for module in self._modules.values(): + input = module(input) + return input + + def append(self, module: Module) -> 'Sequential': + r"""Append a given module to the end. + + Args: + module (nn.Module): module to append + """ + self.add_module(str(len(self)), module) + return self + + + def insert(self, index: int, module: Module) -> 'Sequential': + if not isinstance(module, Module): + raise AssertionError( + f'module should be of type: {Module}') + n = len(self._modules) + if not (-n <= index <= n): + raise IndexError( + f'Index out of range: {index}') + if index < 0: + index += n + for i in range(n, index, -1): + self._modules[str(i)] = self._modules[str(i - 1)] + self._modules[str(index)] = module + return self + + def extend(self, sequential) -> 'Sequential': + for layer in sequential: + self.append(layer) + return self + + +class ModuleList(Module): + r"""Holds submodules in a list. + + :class:`~core.nn.ModuleList` can be indexed like a regular Python list, but + modules it contains are properly registered, and will be visible by all + :class:`~core.nn.Module` methods. + + Args: + modules (iterable, optional): an iterable of modules to add + + Example:: + + class MyModule(nn.Module): + def __init__(self): + super().__init__() + self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)]) + + def forward(self, x): + # ModuleList can act as an iterable, or be indexed using ints + for i, l in enumerate(self.linears): + x = self.linears[i // 2](x) + l(x) + return x + """ + + _modules: Dict[str, Module] # type: ignore[assignment] + + def __init__(self, modules: Optional[Iterable[Module]] = None) -> None: + super().__init__() + if modules is not None: + self += modules + + def _get_abs_string_index(self, idx): + """Get the absolute index for the list of modules.""" + idx = operator.index(idx) + if not (-len(self) <= idx < len(self)): + raise IndexError(f'index {idx} is out of range') + if idx < 0: + idx += len(self) + return str(idx) + + def __getitem__(self, idx: Union[int, slice]) -> Union[Module, 'ModuleList']: + if isinstance(idx, slice): + return self.__class__(list(self._modules.values())[idx]) + else: + return self._modules[self._get_abs_string_index(idx)] + + def __setitem__(self, idx: int, module: Module) -> None: + idx = self._get_abs_string_index(idx) + return setattr(self, str(idx), module) + + def __delitem__(self, idx: Union[int, slice]) -> None: + if isinstance(idx, slice): + for k in range(len(self._modules))[idx]: + delattr(self, str(k)) + else: + delattr(self, self._get_abs_string_index(idx)) + # To preserve numbering, self._modules is being reconstructed with modules after deletion + str_indices = [str(i) for i in range(len(self._modules))] + self._modules = OrderedDict(list(zip(str_indices, self._modules.values()))) + + def __len__(self) -> int: + return len(self._modules) + + def __iter__(self) -> Iterator[Module]: + return iter(self._modules.values()) + + def __iadd__(self, modules: Iterable[Module]) -> Self: + return self.extend(modules) + + def __add__(self, other: Iterable[Module]) -> 'ModuleList': + combined = ModuleList() + for i, module in enumerate(chain(self, other)): + combined.add_module(str(i), module) + return combined + + def __repr__(self): + """Return a custom repr for ModuleList that compresses repeated module representations.""" + list_of_reprs = [repr(item) for item in self] + if len(list_of_reprs) == 0: + return self._get_name() + '()' + + start_end_indices = [[0, 0]] + repeated_blocks = [list_of_reprs[0]] + for i, r in enumerate(list_of_reprs[1:], 1): + if r == repeated_blocks[-1]: + start_end_indices[-1][1] += 1 + continue + + start_end_indices.append([i, i]) + repeated_blocks.append(r) + + lines = [] + main_str = self._get_name() + '(' + for (start_id, end_id), b in zip(start_end_indices, repeated_blocks): + local_repr = f"({start_id}): {b}" # default repr + + if start_id != end_id: + n = end_id - start_id + 1 + local_repr = f"({start_id}-{end_id}): {n} x {b}" + + local_repr = _addindent(local_repr, 2) + lines.append(local_repr) + + main_str += '\n ' + '\n '.join(lines) + '\n' + main_str += ')' + return main_str + + def __dir__(self): + keys = super().__dir__() + keys = [key for key in keys if not key.isdigit()] + return keys + + def insert(self, index: int, module: Module) -> None: + r"""Insert a given module before a given index in the list. + + Args: + index (int): index to insert. + module (nn.Module): module to insert + """ + for i in range(len(self._modules), index, -1): + self._modules[str(i)] = self._modules[str(i - 1)] + self._modules[str(index)] = module + + + def append(self, module: Module) -> 'ModuleList': + r"""Append a given module to the end of the list. + + Args: + module (nn.Module): module to append + """ + self.add_module(str(len(self)), module) + return self + + + def pop(self, key: Union[int, slice]) -> Module: + v = self[key] + del self[key] + return v + + def extend(self, modules: Iterable[Module]) -> Self: + r"""Append modules from a Python iterable to the end of the list. + + Args: + modules (iterable): iterable of modules to append + """ + if not isinstance(modules, container_abcs.Iterable): + raise TypeError("ModuleList.extend should be called with an " + "iterable, but got " + type(modules).__name__) + offset = len(self) + for i, module in enumerate(modules): + self.add_module(str(offset + i), module) + return self + + +class ModuleDict(Module): + r"""Holds submodules in a dictionary. + + :class:`~core.nn.ModuleDict` can be indexed like a regular Python dictionary, + but modules it contains are properly registered, and will be visible by all + :class:`~core.nn.Module` methods. + + :class:`~core.nn.ModuleDict` is an **ordered** dictionary that respects + + * the order of insertion, and + + * in :meth:`~core.nn.ModuleDict.update`, the order of the merged + ``OrderedDict``, ``dict`` (started from Python 3.6) or another + :class:`~core.nn.ModuleDict` (the argument to + :meth:`~core.nn.ModuleDict.update`). + + Note that :meth:`~core.nn.ModuleDict.update` with other unordered mapping + types (e.g., Python's plain ``dict`` before Python version 3.6) does not + preserve the order of the merged mapping. + + Args: + modules (iterable, optional): a mapping (dictionary) of (string: module) + or an iterable of key-value pairs of type (string, module) + + Example:: + + class MyModule(nn.Module): + def __init__(self): + super().__init__() + self.choices = nn.ModuleDict({ + 'conv': nn.Conv2d(10, 10, 3), + 'pool': nn.MaxPool2d(3) + }) + self.activations = nn.ModuleDict([ + ['lrelu', nn.LeakyReLU()], + ['prelu', nn.PReLU()] + ]) + + def forward(self, x, choice, act): + x = self.choices[choice](x) + x = self.activations[act](x) + return x + """ + + _modules: Dict[str, Module] # type: ignore[assignment] + + def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None: + super().__init__() + if modules is not None: + self.update(modules) + + def __getitem__(self, key: str) -> Module: + return self._modules[key] + + def __setitem__(self, key: str, module: Module) -> None: + self.add_module(key, module) + + def __delitem__(self, key: str) -> None: + del self._modules[key] + + def __len__(self) -> int: + return len(self._modules) + + def __iter__(self) -> Iterator[str]: + return iter(self._modules) + + def __contains__(self, key: str) -> bool: + return key in self._modules + + def clear(self) -> None: + """Remove all items from the ModuleDict.""" + self._modules.clear() + + + def pop(self, key: str) -> Module: + r"""Remove key from the ModuleDict and return its module. + + Args: + key (str): key to pop from the ModuleDict + """ + v = self[key] + del self[key] + return v + + + def keys(self) -> Iterable[str]: + r"""Return an iterable of the ModuleDict keys.""" + return self._modules.keys() + + def items(self) -> Iterable[Tuple[str, Module]]: + r"""Return an iterable of the ModuleDict key/value pairs.""" + return self._modules.items() + + def values(self) -> Iterable[Module]: + r"""Return an iterable of the ModuleDict values.""" + return self._modules.values() + + def update(self, modules: Mapping[str, Module]) -> None: + r"""Update the :class:`~core.nn.ModuleDict` with key-value pairs from a mapping, overwriting existing keys. + + .. note:: + If :attr:`modules` is an ``OrderedDict``, a :class:`~core.nn.ModuleDict`, or + an iterable of key-value pairs, the order of new elements in it is preserved. + + Args: + modules (iterable): a mapping (dictionary) from string to :class:`~core.nn.Module`, + or an iterable of key-value pairs of type (string, :class:`~core.nn.Module`) + """ + if not isinstance(modules, container_abcs.Iterable): + raise TypeError("ModuleDict.update should be called with an " + "iterable of key/value pairs, but got " + + type(modules).__name__) + + if isinstance(modules, (OrderedDict, ModuleDict, container_abcs.Mapping)): + for key, module in modules.items(): + self[key] = module + else: + # modules here can be a list with two items + for j, m in enumerate(modules): + if not isinstance(m, container_abcs.Iterable): + raise TypeError("ModuleDict update sequence element " + "#" + str(j) + " should be Iterable; is" + + type(m).__name__) + if not len(m) == 2: + raise ValueError("ModuleDict update sequence element " + "#" + str(j) + " has length " + str(len(m)) + + "; 2 is required") + # modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)] + # that's too cumbersome to type correctly with overloads, so we add an ignore here + self[m[0]] = m[1] # type: ignore[assignment] + + + # remove forward alltogether to fallback on Module's _forward_unimplemented + +class ParameterList(Module): + r"""Holds parameters in a list. + + ParameterList can be indexed like a regular Python list, but parameters it + contains are properly registered, and will be visible by all Module methods. + + Arguments: + modules (list, optional): a list of :class:`~core.nn.Parameter`` to add + + Example:: + + class MyModule(nn.Module): + def __init__(self): + super(MyModule, self).__init__() + self.params = nn.ParameterList([nn.Parameter(core.randn(10, 10)) for i in range(10)]) + + def forward(self, x): + # ModuleList can act as an iterable, or be indexed using ints + for i, p in enumerate(self.params): + x = self.params[i // 2].mm(x) + p.mm(x) + return x + """ + + def __init__(self, parameters=None): + super(ParameterList, self).__init__() + if parameters is not None: + self += parameters + + def __getitem__(self, idx): + if not (-len(self) <= idx < len(self)): + raise IndexError('index {} is out of range'.format(idx)) + if idx < 0: + idx += len(self) + return self._parameters[str(idx)] + + def __setitem__(self, idx, param): + return self.register_parameter(str(idx), param) + + def __len__(self): + return len(self._parameters) + + def __iter__(self): + return iter(self._parameters.values()) + + def __iadd__(self, parameters): + return self.extend(parameters) + + def append(self, parameter): + """Appends a given parameter at the end of the list. + + Arguments: + parameter (nn.Parameter): parameter to append + """ + self.register_parameter(str(len(self)), parameter) + return self + + def extend(self, parameters): + """Appends parameters from a Python list at the end. + + Arguments: + parameters (list): list of parameters to append + """ + if not isinstance(parameters, list): + raise TypeError("ParameterList.extend should be called with a " + "list, but got " + type(parameters).__name__) + offset = len(self) + for i, param in enumerate(parameters): + self.register_parameter(str(offset + i), param) + return self + +class ParameterDict(Module): + r"""Holds parameters in a dictionary. + + ParameterDict can be indexed like a regular Python dictionary, but Parameters it + contains are properly registered, and will be visible by all Module methods. + Other objects are treated as would be done by a regular Python dictionary + + :class:`~core.nn.ParameterDict` is an **ordered** dictionary. + :meth:`~core.nn.ParameterDict.update` with other unordered mapping + types (e.g., Python's plain ``dict``) does not preserve the order of the + merged mapping. On the other hand, ``OrderedDict`` or another :class:`~core.nn.ParameterDict` + will preserve their ordering. + + Note that the constructor, assigning an element of the dictionary and the + :meth:`~core.nn.ParameterDict.update` method will convert any :class:`~core.Tensor` into + :class:`~core.nn.Parameter`. + + Args: + values (iterable, optional): a mapping (dictionary) of + (string : Any) or an iterable of key-value pairs + of type (string, Any) + + Example:: + + class MyModule(nn.Module): + def __init__(self): + super().__init__() + self.params = nn.ParameterDict({ + 'left': nn.Parameter(core.randn(5, 10)), + 'right': nn.Parameter(core.randn(5, 10)) + }) + + def forward(self, x, choice): + x = self.params[choice].mm(x) + return x + """ + + def __init__(self, parameters: Any = None) -> None: + super().__init__() + self._keys: Dict[str, None] = {} + if parameters is not None: + self.update(parameters) + + def _key_to_attr(self, key: str) -> str: + if not isinstance(key, str): + raise TypeError("Index given to ParameterDict cannot be used as a key as it is " + f"not a string (type is '{type(key).__name__}'). Open an issue on " + "github if you need non-string keys.") + else: + # Use the key as-is so that `.named_parameters()` returns the right thing + return key + + def __getitem__(self, key: str) -> Any: + attr = self._key_to_attr(key) + return getattr(self, attr) + + def __setitem__(self, key: str, value: Any) -> None: + # Note that all other function that add an entry to the dictionary part of + # the ParameterDict end up here. So this is the only place where we need + # to wrap things into Parameter if needed. + # Objects added via setattr() are not in the dictionary part and thus won't + # call into this function. + self._keys[key] = None + attr = self._key_to_attr(key) + if isinstance(value, core.Tensor) and not isinstance(value, Parameter): + value = Parameter(value) + setattr(self, attr, value) + + def __delitem__(self, key: str) -> None: + del self._keys[key] + attr = self._key_to_attr(key) + delattr(self, attr) + + def __len__(self) -> int: + return len(self._keys) + + def __iter__(self) -> Iterator[str]: + return iter(self._keys) + + def __reversed__(self) -> Iterator[str]: + return reversed(list(self._keys)) + + def copy(self) -> 'ParameterDict': + """Return a copy of this :class:`~core.nn.ParameterDict` instance.""" + # We have to use an OrderedDict because the ParameterDict constructor + # behaves differently on plain dict vs OrderedDict + return ParameterDict(OrderedDict((k, self[k]) for k in self._keys)) + + + def __contains__(self, key: str) -> bool: + return key in self._keys + + def setdefault(self, key: str, default: Optional[Any] = None) -> Any: + """Set the default for a key in the Parameterdict. + + If key is in the ParameterDict, return its value. + If not, insert `key` with a parameter `default` and return `default`. + `default` defaults to `None`. + + Args: + key (str): key to set default for + default (Any): the parameter set to the key + """ + if key not in self: + self[key] = default + return self[key] + + + def clear(self) -> None: + """Remove all items from the ParameterDict.""" + for k in self._keys.copy(): + del self[k] + + + def pop(self, key: str) -> Any: + r"""Remove key from the ParameterDict and return its parameter. + + Args: + key (str): key to pop from the ParameterDict + """ + v = self[key] + del self[key] + return v + + + def popitem(self) -> Tuple[str, Any]: + """Remove and return the last inserted `(key, parameter)` pair from the ParameterDict.""" + k, _ = self._keys.popitem() + # We need the key in the _keys to be able to access/del + self._keys[k] = None + val = self[k] + del self[k] + return k, val + + + def get(self, key: str, default: Optional[Any] = None) -> Any: + r"""Return the parameter associated with key if present. Otherwise return default if provided, None if not. + + Args: + key (str): key to get from the ParameterDict + default (Parameter, optional): value to return if key not present + """ + return self[key] if key in self else default + + + def fromkeys(self, keys: Iterable[str], default: Optional[Any] = None) -> 'ParameterDict': + r"""Return a new ParameterDict with the keys provided. + + Args: + keys (iterable, string): keys to make the new ParameterDict from + default (Parameter, optional): value to set for all keys + """ + return ParameterDict((k, default) for k in keys) + + + def keys(self) -> Iterable[str]: + r"""Return an iterable of the ParameterDict keys.""" + return self._keys.keys() + + + def items(self) -> Iterable[Tuple[str, Any]]: + r"""Return an iterable of the ParameterDict key/value pairs.""" + return ((k, self[k]) for k in self._keys) + + + def values(self) -> Iterable[Any]: + r"""Return an iterable of the ParameterDict values.""" + return (self[k] for k in self._keys) + + + def update(self, parameters: Union[Mapping[str, Any], 'ParameterDict']) -> None: + r"""Update the :class:`~core.nn.ParameterDict` with key-value pairs from ``parameters``, overwriting existing keys. + + .. note:: + If :attr:`parameters` is an ``OrderedDict``, a :class:`~core.nn.ParameterDict`, or + an iterable of key-value pairs, the order of new elements in it is preserved. + + Args: + parameters (iterable): a mapping (dictionary) from string to + :class:`~core.nn.Parameter`, or an iterable of + key-value pairs of type (string, :class:`~core.nn.Parameter`) + """ + if not isinstance(parameters, container_abcs.Iterable): + raise TypeError("ParametersDict.update should be called with an " + "iterable of key/value pairs, but got " + + type(parameters).__name__) + + if isinstance(parameters, (OrderedDict, ParameterDict)): + for key, parameter in parameters.items(): + self[key] = parameter + elif isinstance(parameters, container_abcs.Mapping): + for key, parameter in sorted(parameters.items()): + self[key] = parameter + else: + for j, p in enumerate(parameters): + if not isinstance(p, container_abcs.Iterable): + raise TypeError("ParameterDict update sequence element " + "#" + str(j) + " should be Iterable; is" + + type(p).__name__) + if not len(p) == 2: + raise ValueError("ParameterDict update sequence element " + "#" + str(j) + " has length " + str(len(p)) + + "; 2 is required") + # parameters as length-2 list too cumbersome to type, see ModuleDict.update comment + self[p[0]] = p[1] # type: ignore[assignment] + + + def extra_repr(self) -> str: + child_lines = [] + for k, p in self.items(): + if isinstance(p, core.Tensor): + size_str = 'x'.join(str(size) for size in p.size()) + parastr = '{} containing: [{} of size {}]'.format( + "Parameter" if isinstance(p, Parameter) else "Tensor", + type(p), size_str) + child_lines.append(' (' + str(k) + '): ' + parastr) + else: + child_lines.append(' (' + str(k) + '): Object of type: ' + type(p).__name__) + tmpstr = '\n'.join(child_lines) + return tmpstr + + def __call__(self, input): + raise RuntimeError('ParameterDict should not be called.') + + def __or__(self, other: 'ParameterDict') -> 'ParameterDict': + copy = self.copy() + copy.update(other) + return copy + + def __ror__(self, other: 'ParameterDict') -> 'ParameterDict': + copy = other.copy() + copy.update(self) + return copy + + def __ior__(self, other : 'ParameterDict') -> Self: + self.update(other) + return self diff --git a/mindnlp/core/nn/modules/conv.py b/mindnlp/core/nn/modules/conv.py new file mode 100644 index 000000000..4a26cb5ad --- /dev/null +++ b/mindnlp/core/nn/modules/conv.py @@ -0,0 +1,851 @@ +# coding=utf-8 +"""conv""" +import math +from typing import Optional, Tuple, Union, List + +from mindnlp.core import Tensor +from ..parameter import Parameter +from .module import Module +from ..common_types import _size_2_t, _size_1_t +from ._utils import _single, _pair, _reverse_repeat_tuple +from .. import init +from .. import functional as F +from ... import ops + + +class _ConvNd(Module): + + __constants__ = ['stride', 'padding', 'dilation', 'groups', + 'padding_mode', 'output_padding', 'in_channels', + 'out_channels', 'kernel_size'] + __annotations__ = {'bias': Optional[Tensor]} + + def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor: # type: ignore[empty-body] + ... + + in_channels: int + _reversed_padding_repeated_twice: List[int] + out_channels: int + kernel_size: Tuple[int, ...] + stride: Tuple[int, ...] + padding: Union[str, Tuple[int, ...]] + dilation: Tuple[int, ...] + transposed: bool + output_padding: Tuple[int, ...] + groups: int + padding_mode: str + weight: Tensor + bias: Optional[Tensor] + + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: Tuple[int, ...], + stride: Tuple[int, ...], + padding: Tuple[int, ...], + dilation: Tuple[int, ...], + transposed: bool, + output_padding: Tuple[int, ...], + groups: int, + bias: bool, + padding_mode: str, + dtype=None) -> None: + factory_kwargs = {'dtype': dtype} + super().__init__() + if groups <= 0: + raise ValueError('groups must be a positive integer') + if in_channels % groups != 0: + raise ValueError('in_channels must be divisible by groups') + if out_channels % groups != 0: + raise ValueError('out_channels must be divisible by groups') + valid_padding_strings = {'same', 'valid'} + if isinstance(padding, str): + if padding not in valid_padding_strings: + raise ValueError( + f"Invalid padding string {padding!r}, should be one of {valid_padding_strings}") + if padding == 'same' and any(s != 1 for s in stride): + raise ValueError("padding='same' is not supported for strided convolutions") + + valid_padding_modes = {'zeros', 'reflect', 'replicate', 'circular'} + if padding_mode not in valid_padding_modes: + raise ValueError(f"padding_mode must be one of {valid_padding_modes}, but got padding_mode='{padding_mode}'") + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.transposed = transposed + self.output_padding = output_padding + self.groups = groups + self.padding_mode = padding_mode + # `_reversed_padding_repeated_twice` is the padding to be passed to + # `F.pad` if needed (e.g., for non-zero padding types that are + # implemented as two ops: padding + conv). `F.pad` accepts paddings in + # reverse order than the dimension. + if isinstance(self.padding, str): + self._reversed_padding_repeated_twice = [0, 0] * len(kernel_size) + if padding == 'same': + for d, k, i in zip(dilation, kernel_size, + range(len(kernel_size) - 1, -1, -1)): + total_padding = d * (k - 1) + left_pad = total_padding // 2 + self._reversed_padding_repeated_twice[2 * i] = left_pad + self._reversed_padding_repeated_twice[2 * i + 1] = ( + total_padding - left_pad) + else: + self._reversed_padding_repeated_twice = _reverse_repeat_tuple(self.padding, 2) + + if transposed: + self.weight = Parameter(ops.empty( + (in_channels, out_channels // groups, *kernel_size), **factory_kwargs)) + else: + self.weight = Parameter(ops.empty( + (out_channels, in_channels // groups, *kernel_size), **factory_kwargs)) + if bias: + self.bias = Parameter(ops.empty(out_channels, **factory_kwargs)) + else: + self.register_parameter('bias', None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with + # uniform(-1/sqrt(k), 1/sqrt(k)), where k = weight.size(1) * prod(*kernel_size) + # For more details see: https://github.com/pytorch/pytorch/issues/15314#issuecomment-477448573 + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) + if fan_in != 0: + bound = 1 / math.sqrt(fan_in) + init.uniform_(self.bias, -bound, bound) + + def extra_repr(self): + s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' + ', stride={stride}') + if self.padding != (0,) * len(self.padding): + s += ', padding={padding}' + if self.dilation != (1,) * len(self.dilation): + s += ', dilation={dilation}' + if self.output_padding != (0,) * len(self.output_padding): + s += ', output_padding={output_padding}' + if self.groups != 1: + s += ', groups={groups}' + if self.bias is None: + s += ', bias=False' + if self.padding_mode != 'zeros': + s += ', padding_mode={padding_mode}' + return s.format(**self.__dict__) + + def __setstate__(self, state): + super().__setstate__(state) + if not hasattr(self, 'padding_mode'): + self.padding_mode = 'zeros' + + +class Conv1d(_ConvNd): + r"""Applies a 1D convolution over an input signal composed of several input + planes. + + In the simplest case, the output value of the layer with input size + :math:`(N, C_{\text{in}}, L)` and output :math:`(N, C_{\text{out}}, L_{\text{out}})` can be + precisely described as: + + .. math:: + \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) + + \sum_{k = 0}^{C_{in} - 1} \text{weight}(C_{\text{out}_j}, k) + \star \text{input}(N_i, k) + + where :math:`\star` is the valid `cross-correlation`_ operator, + :math:`N` is a batch size, :math:`C` denotes a number of channels, + :math:`L` is a length of signal sequence. + """ + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: Union[str, _size_1_t] = 0, + dilation: _size_1_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', # TODO: refine this type + dtype=None + ) -> None: + factory_kwargs = {'dtype': dtype} + # we create new variables below to make mypy happy since kernel_size has + # type Union[int, Tuple[int]] and kernel_size_ has type Tuple[int] + kernel_size_ = _single(kernel_size) + stride_ = _single(stride) + padding_ = padding if isinstance(padding, str) else _single(padding) + dilation_ = _single(dilation) + super().__init__( + in_channels, out_channels, kernel_size_, stride_, padding_, dilation_, + False, _single(0), groups, bias, padding_mode, **factory_kwargs) + + pad_mode = 'valid' + pad = padding + if isinstance(padding, tuple): + if padding[0] != 0: + pad_mode = 'pad' + pad = (0, 0, padding[0], padding[0]) + elif isinstance(padding, int): + if padding != 0: + pad_mode = 'pad' + pad = (0, 0) + (padding,) * 2 + if not isinstance(padding, (int, tuple)): + pad_mode = padding + pad = (0,) * 4 + + if self.padding_mode != 'zeros': + pad_mode = 'valid' + pad = (0,) * 4 + self.conv2d = mops.Conv2D(out_channel=self.out_channels, + kernel_size=(1,) + self.kernel_size, + mode=1, + pad_mode=pad_mode, + pad=pad, + stride=(1,) + self.stride, + dilation=(1,) + self.dilation, + group=self.groups) + + def forward(self, input): + if self.padding_mode != 'zeros': + input = F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode) + input = input.expand_dims(2) + output = self.conv2d(input, self.weight.expand_dims(2)) + + if self.bias is not None: + output = mops.bias_add(output, self.bias) + + output = output.squeeze(2) + return output + + +class Conv2d(_ConvNd): + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: Union[str, _size_2_t] = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', + dtype=None + ) -> None: + factory_kwargs = {'dtype': dtype} + kernel_size_ = _pair(kernel_size) + stride_ = _pair(stride) + padding_ = padding if isinstance(padding, str) else _pair(padding) + dilation_ = _pair(dilation) + super().__init__( + in_channels, out_channels, kernel_size_, stride_, padding_, dilation_, + False, _pair(0), groups, bias, padding_mode, **factory_kwargs) + + pad_mode = 'pad' + pad = padding + if isinstance(padding, tuple): + pad = (padding[0], padding[0], padding[1], padding[1]) + elif isinstance(padding, int): + pad = (padding,) * 4 + if not isinstance(padding, (int, tuple)): + pad_mode = padding + pad = (0,) * 4 + + self.conv2d = mops.Conv2D(out_channel=self.out_channels, + kernel_size=self.kernel_size, + mode=1, + pad_mode=pad_mode, + pad=pad, + stride=self.stride, + dilation=self.dilation, + group=self.groups) + def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]): + if self.padding_mode != 'zeros': + input = ops.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode) + output = self.conv2d(input, weight) + if bias is not None: + output = mops.bias_add(output, bias) + return output + + def forward(self, input): + return self._conv_forward(input, self.weight, self.bias) + + + +class Conv3d(_ConvNd): + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: Union[str, _size_2_t] = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', + dtype=None + ) -> None: + factory_kwargs = {'dtype': dtype} + kernel_size_ = _pair(kernel_size) + stride_ = _pair(stride) + padding_ = padding if isinstance(padding, str) else _pair(padding) + dilation_ = dilation + super().__init__( + in_channels, out_channels, kernel_size_, stride_, padding_, dilation_, + False, _pair(0), groups, bias, padding_mode, **factory_kwargs) + + pad_mode = 'pad' + pad = padding + if isinstance(padding, tuple): + pad = (padding[0], padding[0], padding[1], padding[1]) + elif isinstance(padding, int): + pad = (padding,) * 6 + if not isinstance(padding, (int, tuple)): + pad_mode = padding + pad = (0,) * 6 + + self.conv3d = mops.Conv3D(out_channel=self.out_channels, + kernel_size=self.kernel_size, + mode=1, + pad_mode=pad_mode, + pad=pad, + stride=self.stride, + dilation=self.dilation, + group=self.groups) + + def forward(self, input): + if self.padding_mode != 'zeros': + input = ops.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode) + output = self.conv3d(input, self.weight) + if self.bias is not None: + output = mops.bias_add(output, self.bias) + return output + +# class Conv3d(_ConvNd): +# r"""Applies a 3D convolution over an input signal composed of several input +# planes. + +# In the simplest case, the output value of the layer with input size :math:`(N, C_{in}, D, H, W)` +# and output :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` can be precisely described as: + +# .. math:: + +# \begin{array}{ll} +# out(N_i, C_{out_j}) = bias(C_{out_j}) +# + \sum_{{k}=0}^{C_{in}-1} weight(C_{out_j}, k) \star input(N_i, k) +# \end{array} + +# where :math:`\star` is the valid 3D `cross-correlation`_ operator + +# | :attr:`stride` controls the stride for the cross-correlation. +# | If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides +# for :attr:`padding` number of points. +# | :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm. +# It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. +# | :attr:`groups` controls the connections between inputs and outputs. `in_channels` and `out_channels` +# must both be divisible by `groups`. +# | At groups=1, all inputs are convolved to all outputs. +# | At groups=2, the operation becomes equivalent to having two conv layers +# side by side, each seeing half the input channels, +# and producing half the output channels, and both subsequently concatenated. +# At groups=`in_channels`, each input channel is convolved with its own set of filters +# (of size `out_channels // in_channels`). + +# The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: + +# - a single ``int`` -- in which case the same value is used for the depth, height and width dimension +# - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension, +# the second `int` for the height dimension and the third `int` for the width dimension + +# .. note:: + +# Depending of the size of your kernel, several (of the last) +# columns of the input might be lost, because it is a valid `cross-correlation`_, +# and not a full `cross-correlation`_. +# It is up to the user to add proper padding. + +# Args: +# in_channels (int): Number of channels in the input image +# out_channels (int): Number of channels produced by the convolution +# kernel_size (int or tuple): Size of the convolving kernel +# stride (int or tuple, optional): Stride of the convolution +# padding (int or tuple, optional): Zero-padding added to both sides of the input +# dilation (int or tuple, optional): Spacing between kernel elements +# groups (int, optional): Number of blocked connections from input channels to output channels +# bias (bool, optional): If True, adds a learnable bias to the output + +# Shape: +# - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` +# - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` where +# :math:`D_{out} = floor((D_{in} + 2 * padding[0] - dilation[0] * (kernel\_size[0] - 1) - 1) / stride[0] + 1)` +# :math:`H_{out} = floor((H_{in} + 2 * padding[1] - dilation[1] * (kernel\_size[1] - 1) - 1) / stride[1] + 1)` +# :math:`W_{out} = floor((W_{in} + 2 * padding[2] - dilation[2] * (kernel\_size[2] - 1) - 1) / stride[2] + 1)` + +# Attributes: +# weight (Tensor): the learnable weights of the module of shape +# (out_channels, in_channels, kernel_size[0], kernel_size[1], kernel_size[2]) +# bias (Tensor): the learnable bias of the module of shape (out_channels) + +# Examples:: + +# >>> # With square kernels and equal stride +# >>> m = nn.Conv3d(16, 33, 3, stride=2) +# >>> # non-square kernels and unequal stride and with padding +# >>> m = nn.Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0)) +# >>> input = autograd.Variable(core.randn(20, 16, 10, 50, 100)) +# >>> output = m(input) + +# .. _cross-correlation: +# https://en.wikipedia.org/wiki/Cross-correlation + +# .. _link: +# https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md +# """ + +# def __init__(self, in_channels, out_channels, kernel_size, stride=1, +# padding=0, dilation=1, groups=1, bias=True): +# kernel_size = _triple(kernel_size) +# stride = _triple(stride) +# padding = _triple(padding) +# dilation = _triple(dilation) +# super(Conv3d, self).__init__( +# in_channels, out_channels, kernel_size, stride, padding, dilation, +# False, _triple(0), groups, bias) + +# def forward(self, input): +# return ops.conv3d(input, self.weight, self.bias, self.stride, +# self.padding, self.dilation, self.groups) + + +class _ConvTransposeNd(_ConvNd): + def __init__(self, in_channels, out_channels, kernel_size, stride, + padding, dilation, transposed, output_padding, + groups, bias, padding_mode, dtype=None) -> None: + if padding_mode != 'zeros': + raise ValueError(f'Only "zeros" padding mode is supported for {self.__class__.__name__}') + + factory_kwargs = {'dtype': dtype} + super().__init__( + in_channels, out_channels, kernel_size, stride, + padding, dilation, transposed, output_padding, + groups, bias, padding_mode, **factory_kwargs) + + # dilation being an optional parameter is for backwards + # compatibility + def _output_padding(self, input: Tensor, output_size: Optional[List[int]], + stride: List[int], padding: List[int], kernel_size: List[int], + num_spatial_dims: int, dilation: Optional[List[int]] = None) -> List[int]: + if output_size is None: + ret = _single(self.output_padding) # converting to list if was not already + else: + has_batch_dim = input.dim() == num_spatial_dims + 2 + num_non_spatial_dims = 2 if has_batch_dim else 1 + if len(output_size) == num_non_spatial_dims + num_spatial_dims: + output_size = output_size[num_non_spatial_dims:] + if len(output_size) != num_spatial_dims: + raise ValueError( + f"ConvTranspose{num_spatial_dims}D: for {input.dim()}D input, output_size must have {num_spatial_dims} or {num_non_spatial_dims + num_spatial_dims} elements (got {len(output_size)})") + + min_sizes = [] + max_sizes = [] + for d in range(num_spatial_dims): + dim_size = ((input.size(d + num_non_spatial_dims) - 1) * stride[d] - + 2 * padding[d] + + (dilation[d] if dilation is not None else 1) * (kernel_size[d] - 1) + 1) + min_sizes.append(dim_size) + max_sizes.append(min_sizes[d] + stride[d] - 1) + + for i in range(len(output_size)): + size = output_size[i] + min_size = min_sizes[i] + max_size = max_sizes[i] + if size < min_size or size > max_size: + raise ValueError( + f"requested an output size of {output_size}, but valid sizes range " + f"from {min_sizes} to {max_sizes} (for an input of {input.size()[2:]})") + + res = [] + for d in range(num_spatial_dims): + res.append(output_size[d] - min_sizes[d]) + + ret = res + return ret + +class ConvTranspose1d(_ConvTransposeNd): + """Applies a 1D transposed convolution operator over an input image + composed of several input planes. + + This module can be seen as the gradient of Conv1d with respect to its input. + It is also known as a fractionally-strided convolution or + a deconvolution (although it is not an actual deconvolution operation). + + | :attr:`stride` controls the stride for the cross-correlation. + | If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides + for :attr:`padding` number of points. + | If :attr:`output_padding` is non-zero, then the output is implicitly zero-padded on one side + for :attr:`output_padding` number of points. + | :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm. + It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. + | :attr:`groups` controls the connections between inputs and outputs. `in_channels` and `out_channels` + must both be divisible by `groups`. + | At groups=1, all inputs are convolved to all outputs. + | At groups=2, the operation becomes equivalent to having two conv layers + side by side, each seeing half the input channels, + and producing half the output channels, and both subsequently concatenated. + At groups=`in_channels`, each input channel is convolved with its own set of filters + (of size `out_channels // in_channels`). + + .. note:: + + Depending of the size of your kernel, several (of the last) + columns of the input might be lost, because it is a valid `cross-correlation`_, + and not a full `cross-correlation`_. + It is up to the user to add proper padding. + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution + padding (int or tuple, optional): Zero-padding added to both sides of the input + output_padding (int or tuple, optional): Zero-padding added to one side of the output + groups (int, optional): Number of blocked connections from input channels to output channels + bias (bool, optional): If True, adds a learnable bias to the output + dilation (int or tuple, optional): Spacing between kernel elements + + Shape: + - Input: :math:`(N, C_{in}, L_{in})` + - Output: :math:`(N, C_{out}, L_{out})` where + :math:`L_{out} = (L_{in} - 1) * stride - 2 * padding + kernel\_size + output\_padding` + + Attributes: + weight (Tensor): the learnable weights of the module of shape + (in_channels, out_channels, kernel_size[0], kernel_size[1]) + bias (Tensor): the learnable bias of the module of shape (out_channels) + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode: str = 'zeros'): + kernel_size = _single(kernel_size) + stride = _single(stride) + padding = _single(padding) + dilation = _single(dilation) + output_padding = _single(output_padding) + super(ConvTranspose1d, self).__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + True, output_padding, groups, bias, padding_mode) + + pad_mode = 'pad' + pad = padding + if isinstance(padding, tuple): + pad = (0, 0, padding[0], padding[0]) + elif isinstance(padding, int): + pad = (0, 0) + (padding,) * 2 + if not isinstance(padding, (int, tuple)): + pad_mode = padding + pad = (0,) * 4 + + # cause Conv2DTranspose's out_channel refers to Conv2D's out_channel. + self.conv2d_transpose = mops.Conv2DTranspose(out_channel=self.out_channels, + kernel_size=(1,) + self.kernel_size, + mode=1, + pad_mode=pad_mode, + pad=pad, + stride=(1,) + self.stride, + dilation=(1,) + self.dilation, + group=self.groups) + self.h_add = _deconv_output_length(pad_mode, 1, 1, 1, pad[0] + pad[1]) + self.w_add = _deconv_output_length(pad_mode, kernel_size[0], stride[0], dilation[0], pad[2] + pad[3]) + + def forward(self, input, output_size=None): + if self.padding_mode != 'zeros': + raise ValueError('Only `zeros` padding mode is supported for ConvTranspose2d') + + assert isinstance(self.padding, tuple) + # One cannot replace List by Tuple or Sequence in "_output_padding" because + # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`. + num_spatial_dims = 1 + output_padding = self._output_padding( + input, output_size, self.stride, self.padding, self.kernel_size, # type: ignore[arg-type] + num_spatial_dims, self.dilation) # type: ignore[arg-type] + input = mops.expand_dims(input, 2) + n, _, h, w = input.shape + conv2d_trans_ret = self.conv2d_transpose(input, self.weight.expand_dims(2), + (n, self.out_channels, + h + self.h_add, + w * self.stride[0] + self.w_add)) + if self.bias is not None: + conv2d_trans_ret = mops.bias_add(conv2d_trans_ret, self.bias) + + conv2d_trans_ret = conv2d_trans_ret.squeeze(2) + conv2d_trans_ret = ops.pad(conv2d_trans_ret, (0,) + output_padding, value=0.) + return conv2d_trans_ret + + +def _deconv_output_length(pad_mode, filter_size, stride_size, dilation_size, padding): + """Calculate the width and height of output.""" + length = 0 + filter_size = filter_size + (filter_size - 1) * (dilation_size - 1) + if pad_mode == 'valid': + if filter_size - stride_size > 0: + length = filter_size - stride_size + elif pad_mode == 'pad': + length = - padding + filter_size - stride_size + + return length + +class ConvTranspose2d(_ConvTransposeNd): + r"""Applies a 2D transposed convolution operator over an input image + composed of several input planes. + + This module can be seen as the gradient of Conv2d with respect to its input. + It is also known as a fractionally-strided convolution or + a deconvolution (although it is not an actual deconvolution operation). + + | :attr:`stride` controls the stride for the cross-correlation. + | If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides + for :attr:`padding` number of points. + | If :attr:`output_padding` is non-zero, then the output is implicitly zero-padded on one side + for :attr:`output_padding` number of points. + | :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm. + It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. + | :attr:`groups` controls the connections between inputs and outputs. `in_channels` and `out_channels` + must both be divisible by `groups`. + | At groups=1, all inputs are convolved to all outputs. + | At groups=2, the operation becomes equivalent to having two conv layers + side by side, each seeing half the input channels, + and producing half the output channels, and both subsequently concatenated. + At groups=`in_channels`, each input channel is convolved with its own set of filters + (of size `out_channels // in_channels`). + + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding` + can either be: + + - a single ``int`` -- in which case the same value is used for the height and width dimensions + - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension, + and the second `int` for the width dimension + + .. note:: + + Depending of the size of your kernel, several (of the last) + columns of the input might be lost, because it is a valid `cross-correlation`_, + and not a full `cross-correlation`_. + It is up to the user to add proper padding. + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution + padding (int or tuple, optional): Zero-padding added to both sides of the input + output_padding (int or tuple, optional): Zero-padding added to one side of the output + groups (int, optional): Number of blocked connections from input channels to output channels + bias (bool, optional): If True, adds a learnable bias to the output + dilation (int or tuple, optional): Spacing between kernel elements + + Shape: + - Input: :math:`(N, C_{in}, H_{in}, W_{in})` + - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where + :math:`H_{out} = (H_{in} - 1) * stride[0] - 2 * padding[0] + kernel\_size[0] + output\_padding[0]` + :math:`W_{out} = (W_{in} - 1) * stride[1] - 2 * padding[1] + kernel\_size[1] + output\_padding[1]` + + Attributes: + weight (Tensor): the learnable weights of the module of shape + (in_channels, out_channels, kernel_size[0], kernel_size[1]) + bias (Tensor): the learnable bias of the module of shape (out_channels) + + Examples:: + + >>> # With square kernels and equal stride + >>> m = nn.ConvTranspose2d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)) + >>> input = autograd.Variable(core.randn(20, 16, 50, 100)) + >>> output = m(input) + >>> # exact output size can be also specified as an argument + >>> input = autograd.Variable(core.randn(1, 16, 12, 12)) + >>> downsample = nn.Conv2d(16, 16, 3, stride=2, padding=1) + >>> upsample = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1) + >>> h = downsample(input) + >>> h.size() + core.Size([1, 16, 6, 6]) + >>> output = upsample(h, output_size=input.size()) + >>> output.size() + core.Size([1, 16, 12, 12]) + + .. _cross-correlation: + https://en.wikipedia.org/wiki/Cross-correlation + + .. _link: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, output_padding=0, groups=1, bias=True, dilation=1, + padding_mode='zeros', dtype=None): + factory_kwargs = {'dtype': dtype} + kernel_size = _pair(kernel_size) + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + output_padding = _pair(output_padding) + super().__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + True, output_padding, groups, bias, padding_mode, **factory_kwargs) + + pad_mode = 'pad' + pad = padding + if isinstance(padding, tuple): + pad = (padding[0], padding[0], padding[1], padding[1]) + elif isinstance(padding, int): + pad = (padding,) * 4 + if not isinstance(padding, (int, tuple)): + pad_mode = padding + pad = (0,) * 4 + + # cause Conv2DTranspose's out_channel refers to Conv2D's out_channel. + self.conv2d_transpose = mops.Conv2DTranspose(out_channel=in_channels, + kernel_size=kernel_size, + mode=1, + pad_mode=pad_mode, + pad=pad, + stride=stride, + dilation=dilation, + group=groups) + + self.h_add = _deconv_output_length(pad_mode, kernel_size[0], stride[0], dilation[0], pad[0] + pad[1]) + self.w_add = _deconv_output_length(pad_mode, kernel_size[1], stride[1], dilation[1], pad[2] + pad[3]) + + def forward(self, input, output_size=None): + if self.padding_mode != 'zeros': + raise ValueError('Only `zeros` padding mode is supported for ConvTranspose2d') + + assert isinstance(self.padding, tuple) + # One cannot replace List by Tuple or Sequence in "_output_padding" because + # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`. + num_spatial_dims = 2 + output_padding = self._output_padding( + input, output_size, self.stride, self.padding, self.kernel_size, # type: ignore[arg-type] + num_spatial_dims, self.dilation) # type: ignore[arg-type] + + n, _, h, w = input.shape + conv2d_trans_ret = self.conv2d_transpose(input, self.weight, + (n, self.out_channels, + h * self.stride[0] + self.h_add, + w * self.stride[1] + self.w_add)) + if self.bias is not None: + conv2d_trans_ret = mops.bias_add(conv2d_trans_ret, self.bias) + + conv2d_trans_ret = ops.pad(conv2d_trans_ret, output_padding, value=0.) + + return conv2d_trans_ret + + +# class ConvTranspose3d(_ConvTransposeNd): +# r"""Applies a 3D transposed convolution operator over an input image composed of several input +# planes. +# The transposed convolution operator multiplies each input value element-wise by a learnable kernel, +# and sums over the outputs from all input feature planes. + +# This module can be seen as the gradient of Conv3d with respect to its input. +# It is also known as a fractionally-strided convolution or +# a deconvolution (although it is not an actual deconvolution operation). + +# | :attr:`stride` controls the stride for the cross-correlation. +# | If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides +# for :attr:`padding` number of points. +# | If :attr:`output_padding` is non-zero, then the output is implicitly zero-padded on one side +# for :attr:`output_padding` number of points. +# | :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm. +# It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. +# | :attr:`groups` controls the connections between inputs and outputs. `in_channels` and `out_channels` +# must both be divisible by `groups`. +# | At groups=1, all inputs are convolved to all outputs. +# | At groups=2, the operation becomes equivalent to having two conv layers +# side by side, each seeing half the input channels, +# and producing half the output channels, and both subsequently concatenated. +# At groups=`in_channels`, each input channel is convolved with its own set of filters +# (of size `out_channels // in_channels`). + +# The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding` +# can either be: + +# - a single ``int`` -- in which case the same value is used for the depth, height and width dimensions +# - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension, +# the second `int` for the height dimension and the third `int` for the width dimension + +# .. note:: + +# Depending of the size of your kernel, several (of the last) +# columns of the input might be lost, because it is a valid `cross-correlation`_, +# and not a full `cross-correlation`_. +# It is up to the user to add proper padding. + +# Args: +# in_channels (int): Number of channels in the input image +# out_channels (int): Number of channels produced by the convolution +# kernel_size (int or tuple): Size of the convolving kernel +# stride (int or tuple, optional): Stride of the convolution +# padding (int or tuple, optional): Zero-padding added to both sides of the input +# output_padding (int or tuple, optional): Zero-padding added to one side of the output +# groups (int, optional): Number of blocked connections from input channels to output channels +# bias (bool, optional): If True, adds a learnable bias to the output +# dilation (int or tuple, optional): Spacing between kernel elements + +# Shape: +# - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` +# - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` where +# :math:`D_{out} = (D_{in} - 1) * stride[0] - 2 * padding[0] + kernel\_size[0] + output\_padding[0]` +# :math:`H_{out} = (H_{in} - 1) * stride[1] - 2 * padding[1] + kernel\_size[1] + output\_padding[1]` +# :math:`W_{out} = (W_{in} - 1) * stride[2] - 2 * padding[2] + kernel\_size[2] + output\_padding[2]` + +# Attributes: +# weight (Tensor): the learnable weights of the module of shape +# (in_channels, out_channels, kernel_size[0], kernel_size[1], kernel_size[2]) +# bias (Tensor): the learnable bias of the module of shape (out_channels) + +# Examples:: + +# >>> # With square kernels and equal stride +# >>> m = nn.ConvTranspose3d(16, 33, 3, stride=2) +# >>> # non-square kernels and unequal stride and with padding +# >>> m = nn.Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(0, 4, 2)) +# >>> input = autograd.Variable(core.randn(20, 16, 10, 50, 100)) +# >>> output = m(input) + +# .. _cross-correlation: +# https://en.wikipedia.org/wiki/Cross-correlation + +# .. _link: +# https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md +# """ + +# def __init__(self, in_channels, out_channels, kernel_size, stride=1, +# padding=0, output_padding=0, groups=1, bias=True, dilation=1): +# kernel_size = _triple(kernel_size) +# stride = _triple(stride) +# padding = _triple(padding) +# dilation = _triple(dilation) +# output_padding = _triple(output_padding) +# super(ConvTranspose3d, self).__init__( +# in_channels, out_channels, kernel_size, stride, padding, dilation, +# True, output_padding, groups, bias) + +# def forward(self, input, output_size=None): +# output_padding = self._output_padding(input, output_size) +# return F.conv_transpose3d( +# input, self.weight, self.bias, self.stride, self.padding, +# output_padding, self.groups, self.dilation) + + +# TODO: Conv2dLocal +# TODO: Conv2dMap +# TODO: ConvTranspose2dMap diff --git a/mindnlp/core/nn/modules/distance.py b/mindnlp/core/nn/modules/distance.py new file mode 100644 index 000000000..41518d66d --- /dev/null +++ b/mindnlp/core/nn/modules/distance.py @@ -0,0 +1,94 @@ +"""distance""" +from mindnlp.core import Tensor + +from .module import Module +from .. import functional as F + + +__all__ = ['CosineSimilarity'] + + +class PairwiseDistance(Module): + r""" + Computes the pairwise distance between input vectors, or between columns of input matrices. + + Distances are computed using ``p``-norm, with constant ``eps`` added to avoid division by zero + if ``p`` is negative, i.e.: + + .. math :: + \mathrm{dist}\left(x, y\right) = \left\Vert x-y + \epsilon e \right\Vert_p, + + where :math:`e` is the vector of ones and the ``p``-norm is given by. + + .. math :: + \Vert x \Vert _p = \left( \sum_{i=1}^n \vert x_i \vert ^ p \right) ^ {1/p}. + + Args: + p (real, optional): the norm degree. Can be negative. Default: 2 + eps (float, optional): Small value to avoid division by zero. + Default: 1e-6 + keepdim (bool, optional): Determines whether or not to keep the vector dimension. + Default: False + Shape: + - Input1: :math:`(N, D)` or :math:`(D)` where `N = batch dimension` and `D = vector dimension` + - Input2: :math:`(N, D)` or :math:`(D)`, same shape as the Input1 + - Output: :math:`(N)` or :math:`()` based on input dimension. + If :attr:`keepdim` is ``True``, then :math:`(N, 1)` or :math:`(1)` based on input dimension. + + Examples:: + >>> pdist = nn.PairwiseDistance(p=2) + >>> input1 = core.randn(100, 128) + >>> input2 = core.randn(100, 128) + >>> output = pdist(input1, input2) + """ + + __constants__ = ['norm', 'eps', 'keepdim'] + norm: float + eps: float + keepdim: bool + + def __init__(self, p: float = 2., eps: float = 1e-6, keepdim: bool = False) -> None: + super().__init__() + self.norm = p + self.eps = eps + self.keepdim = keepdim + + def forward(self, x1: Tensor, x2: Tensor) -> Tensor: + return F.pairwise_distance(x1, x2, self.norm, self.eps, self.keepdim) + + + + +class CosineSimilarity(Module): + r"""Returns cosine similarity between :math:`x_1` and :math:`x_2`, computed along `dim`. + + .. math :: + \text{similarity} = \dfrac{x_1 \cdot x_2}{\max(\Vert x_1 \Vert _2 \cdot \Vert x_2 \Vert _2, \epsilon)}. + + Args: + dim (int, optional): Dimension where cosine similarity is computed. Default: 1 + eps (float, optional): Small value to avoid division by zero. + Default: 1e-8 + Shape: + - Input1: :math:`(\ast_1, D, \ast_2)` where D is at position `dim` + - Input2: :math:`(\ast_1, D, \ast_2)`, same number of dimensions as x1, matching x1 size at dimension `dim`, + and broadcastable with x1 at other dimensions. + - Output: :math:`(\ast_1, \ast_2)` + Examples:: + >>> input1 = core.randn(100, 128) + >>> input2 = core.randn(100, 128) + >>> cos = nn.CosineSimilarity(dim=1, eps=1e-6) + >>> output = cos(input1, input2) + """ + + __constants__ = ['dim', 'eps'] + dim: int + eps: float + + def __init__(self, dim: int = 1, eps: float = 1e-8) -> None: + super().__init__() + self.dim = dim + self.eps = eps + + def forward(self, x1: Tensor, x2: Tensor) -> Tensor: + return F.cosine_similarity(x1, x2, self.dim, self.eps) diff --git a/mindnlp/core/nn/modules/dropout.py b/mindnlp/core/nn/modules/dropout.py new file mode 100644 index 000000000..550b12763 --- /dev/null +++ b/mindnlp/core/nn/modules/dropout.py @@ -0,0 +1,299 @@ +"""dropout""" +from mindnlp.core import Tensor + +from .module import Module +from .. import functional as F + + +__all__ = ['Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'FeatureAlphaDropout'] + +class _DropoutNd(Module): + __constants__ = ['p', 'inplace'] + p: float + inplace: bool + + def __init__(self, p: float = 0.5, inplace: bool = False) -> None: + super().__init__() + if p < 0 or p > 1: + raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}") + self.p = float(p) + self.inplace = inplace + + def extra_repr(self) -> str: + return f'p={self.p}, inplace={self.inplace}' + + + +class Dropout(_DropoutNd): + r"""During training, randomly zeroes some of the elements of the input tensor with probability :attr:`p`. + + The zeroed elements are chosen independently for each forward call and are sampled from a Bernoulli distribution. + + Each channel will be zeroed out independently on every forward call. + + This has proven to be an effective technique for regularization and + preventing the co-adaptation of neurons as described in the paper + `Improving neural networks by preventing co-adaptation of feature + detectors`_ . + + Furthermore, the outputs are scaled by a factor of :math:`\frac{1}{1-p}` during + training. This means that during evaluation the module simply computes an + identity function. + + Args: + p: probability of an element to be zeroed. Default: 0.5 + inplace: If set to ``True``, will do this operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(*)`. Input can be of any shape + - Output: :math:`(*)`. Output is of the same shape as input + + Examples:: + + >>> m = nn.Dropout(p=0.2) + >>> input = core.randn(20, 16) + >>> output = m(input) + + .. _Improving neural networks by preventing co-adaptation of feature + detectors: https://arxiv.org/abs/1207.0580 + """ + + def forward(self, input: Tensor) -> Tensor: + return F.dropout(input, self.p, self.training) + + +class Dropout1d(_DropoutNd): + r"""Randomly zero out entire channels. + + A channel is a 1D feature map, + e.g., the :math:`j`-th channel of the :math:`i`-th sample in the + batched input is a 1D tensor :math:`\text{input}[i, j]`. + + Each channel will be zeroed out independently on every forward call with + probability :attr:`p` using samples from a Bernoulli distribution. + + Usually the input comes from :class:`nn.Conv1d` modules. + + As described in the paper + `Efficient Object Localization Using Convolutional Networks`_ , + if adjacent pixels within feature maps are strongly correlated + (as is normally the case in early convolution layers) then i.i.d. dropout + will not regularize the activations and will otherwise just result + in an effective learning rate decrease. + + In this case, :func:`nn.Dropout1d` will help promote independence between + feature maps and should be used instead. + + Args: + p (float, optional): probability of an element to be zero-ed. + inplace (bool, optional): If set to ``True``, will do this operation + in-place + + Shape: + - Input: :math:`(N, C, L)` or :math:`(C, L)`. + - Output: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input). + + Examples:: + + >>> m = nn.Dropout1d(p=0.2) + >>> input = core.randn(20, 16, 32) + >>> output = m(input) + + .. _Efficient Object Localization Using Convolutional Networks: + https://arxiv.org/abs/1411.4280 + """ + + def forward(self, input: Tensor) -> Tensor: + return F.dropout1d(input, self.p, self.training) + + +class Dropout2d(_DropoutNd): + r"""Randomly zero out entire channels. + + A channel is a 2D feature map, + e.g., the :math:`j`-th channel of the :math:`i`-th sample in the + batched input is a 2D tensor :math:`\text{input}[i, j]`. + + Each channel will be zeroed out independently on every forward call with + probability :attr:`p` using samples from a Bernoulli distribution. + + Usually the input comes from :class:`nn.Conv2d` modules. + + As described in the paper + `Efficient Object Localization Using Convolutional Networks`_ , + if adjacent pixels within feature maps are strongly correlated + (as is normally the case in early convolution layers) then i.i.d. dropout + will not regularize the activations and will otherwise just result + in an effective learning rate decrease. + + In this case, :func:`nn.Dropout2d` will help promote independence between + feature maps and should be used instead. + + Args: + p (float, optional): probability of an element to be zero-ed. + inplace (bool, optional): If set to ``True``, will do this operation + in-place + + .. warning :: + Due to historical reasons, this class will perform 1D channel-wise dropout + for 3D inputs (as done by :class:`nn.Dropout1d`). Thus, it currently does NOT + support inputs without a batch dimension of shape :math:`(C, H, W)`. This + behavior will change in a future release to interpret 3D inputs as no-batch-dim + inputs. To maintain the old behavior, switch to :class:`nn.Dropout1d`. + + Shape: + - Input: :math:`(N, C, H, W)` or :math:`(N, C, L)`. + - Output: :math:`(N, C, H, W)` or :math:`(N, C, L)` (same shape as input). + + Examples:: + + >>> m = nn.Dropout2d(p=0.2) + >>> input = core.randn(20, 16, 32, 32) + >>> output = m(input) + + .. _Efficient Object Localization Using Convolutional Networks: + https://arxiv.org/abs/1411.4280 + """ + + def forward(self, input: Tensor) -> Tensor: + return F.dropout2d(input, self.p, self.training) + + + + +class Dropout3d(_DropoutNd): + r"""Randomly zero out entire channels. + + A channel is a 3D feature map, + e.g., the :math:`j`-th channel of the :math:`i`-th sample in the + batched input is a 3D tensor :math:`\text{input}[i, j]`. + + Each channel will be zeroed out independently on every forward call with + probability :attr:`p` using samples from a Bernoulli distribution. + + Usually the input comes from :class:`nn.Conv3d` modules. + + As described in the paper + `Efficient Object Localization Using Convolutional Networks`_ , + if adjacent pixels within feature maps are strongly correlated + (as is normally the case in early convolution layers) then i.i.d. dropout + will not regularize the activations and will otherwise just result + in an effective learning rate decrease. + + In this case, :func:`nn.Dropout3d` will help promote independence between + feature maps and should be used instead. + + Args: + p (float, optional): probability of an element to be zeroed. + inplace (bool, optional): If set to ``True``, will do this operation + in-place + + Shape: + - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`. + - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input). + + Examples:: + + >>> m = nn.Dropout3d(p=0.2) + >>> input = core.randn(20, 16, 4, 32, 32) + >>> output = m(input) + + .. _Efficient Object Localization Using Convolutional Networks: + https://arxiv.org/abs/1411.4280 + """ + + def forward(self, input: Tensor) -> Tensor: + return F.dropout3d(input, self.p, self.training, self.inplace) + + +class AlphaDropout(_DropoutNd): + r"""Applies Alpha Dropout over the input. + + Alpha Dropout is a type of Dropout that maintains the self-normalizing + property. + For an input with zero mean and unit standard deviation, the output of + Alpha Dropout maintains the original mean and standard deviation of the + input. + Alpha Dropout goes hand-in-hand with SELU activation function, which ensures + that the outputs have zero mean and unit standard deviation. + + During training, it randomly masks some of the elements of the input + tensor with probability *p* using samples from a bernoulli distribution. + The elements to masked are randomized on every forward call, and scaled + and shifted to maintain zero mean and unit standard deviation. + + During evaluation the module simply computes an identity function. + + More details can be found in the paper `Self-Normalizing Neural Networks`_ . + + Args: + p (float): probability of an element to be dropped. Default: 0.5 + inplace (bool, optional): If set to ``True``, will do this operation + in-place + + Shape: + - Input: :math:`(*)`. Input can be of any shape + - Output: :math:`(*)`. Output is of the same shape as input + + Examples:: + + >>> m = nn.AlphaDropout(p=0.2) + >>> input = core.randn(20, 16) + >>> output = m(input) + + .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515 + """ + + def forward(self, input: Tensor) -> Tensor: + return F.alpha_dropout(input, self.p, self.training) + + +class FeatureAlphaDropout(_DropoutNd): + r"""Randomly masks out entire channels. + + A channel is a feature map, + e.g. the :math:`j`-th channel of the :math:`i`-th sample in the batch input + is a tensor :math:`\text{input}[i, j]` of the input tensor). Instead of + setting activations to zero, as in regular Dropout, the activations are set + to the negative saturation value of the SELU activation function. More details + can be found in the paper `Self-Normalizing Neural Networks`_ . + + Each element will be masked independently for each sample on every forward + call with probability :attr:`p` using samples from a Bernoulli distribution. + The elements to be masked are randomized on every forward call, and scaled + and shifted to maintain zero mean and unit variance. + + Usually the input comes from :class:`nn.AlphaDropout` modules. + + As described in the paper + `Efficient Object Localization Using Convolutional Networks`_ , + if adjacent pixels within feature maps are strongly correlated + (as is normally the case in early convolution layers) then i.i.d. dropout + will not regularize the activations and will otherwise just result + in an effective learning rate decrease. + + In this case, :func:`nn.AlphaDropout` will help promote independence between + feature maps and should be used instead. + + Args: + p (float, optional): probability of an element to be zeroed. Default: 0.5 + inplace (bool, optional): If set to ``True``, will do this operation + in-place + + Shape: + - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`. + - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input). + + Examples:: + + >>> m = nn.FeatureAlphaDropout(p=0.2) + >>> input = core.randn(20, 16, 4, 32, 32) + >>> output = m(input) + + .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515 + .. _Efficient Object Localization Using Convolutional Networks: + https://arxiv.org/abs/1411.4280 + """ + + def forward(self, input: Tensor) -> Tensor: + return F.feature_alpha_dropout(input, self.p, self.training) diff --git a/mindnlp/core/nn/modules/flatten.py b/mindnlp/core/nn/modules/flatten.py new file mode 100644 index 000000000..8a2eed8ba --- /dev/null +++ b/mindnlp/core/nn/modules/flatten.py @@ -0,0 +1,150 @@ +"""flatten""" +from typing import Tuple, Union, List +from mindnlp.core import Tensor + +from .module import Module +from ...ops import flatten, unflatten + +__all__ = ['Flatten', 'Unflatten'] + + +_size = Union[List[int], Tuple[int, ...]] + +class Flatten(Module): + r""" + Flattens a contiguous range of dims into a tensor. + + For use with :class:`~nn.Sequential`, see :meth:`core.flatten` for details. + + Shape: + - Input: :math:`(*, S_{\text{start}},..., S_{i}, ..., S_{\text{end}}, *)`,' + where :math:`S_{i}` is the size at dimension :math:`i` and :math:`*` means any + number of dimensions including none. + - Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`. + + Args: + start_dim: first dim to flatten (default = 1). + end_dim: last dim to flatten (default = -1). + + Examples:: + >>> input = core.randn(32, 1, 5, 5) + >>> # With default parameters + >>> m = nn.Flatten() + >>> output = m(input) + >>> output.size() + core.Size([32, 25]) + >>> # With non-default parameters + >>> m = nn.Flatten(0, 2) + >>> output = m(input) + >>> output.size() + core.Size([160, 5]) + """ + + __constants__ = ['start_dim', 'end_dim'] + start_dim: int + end_dim: int + + def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None: + super().__init__() + self.start_dim = start_dim + self.end_dim = end_dim + + def forward(self, input: Tensor) -> Tensor: + return flatten(input, self.start_dim, self.end_dim) + + def extra_repr(self) -> str: + return f'start_dim={self.start_dim}, end_dim={self.end_dim}' + + + + +class Unflatten(Module): + r""" + Unflattens a tensor dim expanding it to a desired shape. For use with :class:`~nn.Sequential`. + + * :attr:`dim` specifies the dimension of the input tensor to be unflattened, and it can + be either `int` or `str` when `Tensor` or `NamedTensor` is used, respectively. + + * :attr:`unflattened_size` is the new shape of the unflattened dimension of the tensor and it can be + a `tuple` of ints or a `list` of ints or `core.Size` for `Tensor` input; a `NamedShape` + (tuple of `(name, size)` tuples) for `NamedTensor` input. + + Shape: + - Input: :math:`(*, S_{\text{dim}}, *)`, where :math:`S_{\text{dim}}` is the size at + dimension :attr:`dim` and :math:`*` means any number of dimensions including none. + - Output: :math:`(*, U_1, ..., U_n, *)`, where :math:`U` = :attr:`unflattened_size` and + :math:`\prod_{i=1}^n U_i = S_{\text{dim}}`. + + Args: + dim (Union[int, str]): Dimension to be unflattened + unflattened_size (Union[core.Size, Tuple, List, NamedShape]): New shape of the unflattened dimension + + Examples: + >>> input = core.randn(2, 50) + >>> # With tuple of ints + >>> m = nn.Sequential( + >>> nn.Linear(50, 50), + >>> nn.Unflatten(1, (2, 5, 5)) + >>> ) + >>> output = m(input) + >>> output.size() + core.Size([2, 2, 5, 5]) + >>> # With core.Size + >>> m = nn.Sequential( + >>> nn.Linear(50, 50), + >>> nn.Unflatten(1, core.Size([2, 5, 5])) + >>> ) + >>> output = m(input) + >>> output.size() + core.Size([2, 2, 5, 5]) + >>> # With namedshape (tuple of tuples) + >>> input = core.randn(2, 50, names=('N', 'features')) + >>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5))) + >>> output = unflatten(input) + >>> output.size() + core.Size([2, 2, 5, 5]) + """ + + NamedShape = Tuple[Tuple[str, int]] + + __constants__ = ['dim', 'unflattened_size'] + dim: Union[int, str] + unflattened_size: Union[_size, NamedShape] + + def __init__(self, dim: Union[int, str], unflattened_size: Union[_size, NamedShape]) -> None: + super().__init__() + + if isinstance(dim, int): + self._require_tuple_int(unflattened_size) + elif isinstance(dim, str): + self._require_tuple_tuple(unflattened_size) + else: + raise TypeError("invalid argument type for dim parameter") + + self.dim = dim + self.unflattened_size = unflattened_size + + def _require_tuple_tuple(self, input): + if (isinstance(input, tuple)): + for idx, elem in enumerate(input): + if not isinstance(elem, tuple): + raise TypeError("unflattened_size must be tuple of tuples, " + + f"but found element of type {type(elem).__name__} at pos {idx}") + return + raise TypeError("unflattened_size must be a tuple of tuples, " + + f"but found type {type(input).__name__}") + + def _require_tuple_int(self, input): + if (isinstance(input, (tuple, list))): + for idx, elem in enumerate(input): + if not isinstance(elem, int): + raise TypeError("unflattened_size must be tuple of ints, " + + f"but found element of type {type(elem).__name__} at pos {idx}") + return + raise TypeError(f"unflattened_size must be a tuple of ints, but found type {type(input).__name__}") + + def forward(self, input: Tensor) -> Tensor: + return unflatten(input, self.dim, self.unflattened_size) + + def extra_repr(self) -> str: + return f'dim={self.dim}, unflattened_size={self.unflattened_size}' diff --git a/mindnlp/core/nn/modules/fold.py b/mindnlp/core/nn/modules/fold.py new file mode 100644 index 000000000..3949a0887 --- /dev/null +++ b/mindnlp/core/nn/modules/fold.py @@ -0,0 +1,305 @@ +"""fold module""" +from mindnlp.core import Tensor +from .module import Module +from .. import functional as F + +from ..common_types import _size_any_t + +__all__ = ['Fold', 'Unfold'] + +class Fold(Module): + r"""Combines an array of sliding local blocks into a large containing tensor. + + Consider a batched :attr:`input` tensor containing sliding local blocks, + e.g., patches of images, of shape :math:`(N, C \times \prod(\text{kernel\_size}), L)`, + where :math:`N` is batch dimension, :math:`C \times \prod(\text{kernel\_size})` + is the number of values within a block (a block has :math:`\prod(\text{kernel\_size})` + spatial locations each containing a :math:`C`-channeled vector), and + :math:`L` is the total number of blocks. (This is exactly the + same specification as the output shape of :class:`~core.nn.Unfold`.) This + operation combines these local blocks into the large :attr:`output` tensor + of shape :math:`(N, C, \text{output\_size}[0], \text{output\_size}[1], \dots)` + by summing the overlapping values. Similar to :class:`~core.nn.Unfold`, the + arguments must satisfy + + .. math:: + L = \prod_d \left\lfloor\frac{\text{output\_size}[d] + 2 \times \text{padding}[d] % + - \text{dilation}[d] \times (\text{kernel\_size}[d] - 1) - 1}{\text{stride}[d]} + 1\right\rfloor, + + where :math:`d` is over all spatial dimensions. + + * :attr:`output_size` describes the spatial shape of the large containing + tensor of the sliding local blocks. It is useful to resolve the ambiguity + when multiple input shapes map to same number of sliding blocks, e.g., + with ``stride > 0``. + + The :attr:`padding`, :attr:`stride` and :attr:`dilation` arguments specify + how the sliding blocks are retrieved. + + * :attr:`stride` controls the stride for the sliding blocks. + + * :attr:`padding` controls the amount of implicit zero-paddings on both + sides for :attr:`padding` number of points for each dimension before + reshaping. + + * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm. + It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. + + Args: + output_size (int or tuple): the shape of the spatial dimensions of the + output (i.e., ``output.sizes()[2:]``) + kernel_size (int or tuple): the size of the sliding blocks + dilation (int or tuple, optional): a parameter that controls the + stride of elements within the + neighborhood. Default: 1 + padding (int or tuple, optional): implicit zero padding to be added on + both sides of input. Default: 0 + stride (int or tuple): the stride of the sliding blocks in the input + spatial dimensions. Default: 1 + + * If :attr:`output_size`, :attr:`kernel_size`, :attr:`dilation`, + :attr:`padding` or :attr:`stride` is an int or a tuple of length 1 then + their values will be replicated across all spatial dimensions. + + * For the case of two output spatial dimensions this operation is sometimes + called ``col2im``. + + .. note:: + :class:`~core.nn.Fold` calculates each combined value in the resulting + large tensor by summing all values from all containing blocks. + :class:`~core.nn.Unfold` extracts the values in the local blocks by + copying from the large tensor. So, if the blocks overlap, they are not + inverses of each other. + + In general, folding and unfolding operations are related as + follows. Consider :class:`~core.nn.Fold` and + :class:`~core.nn.Unfold` instances created with the same + parameters: + + >>> fold_params = dict(kernel_size=..., dilation=..., padding=..., stride=...) + >>> fold = nn.Fold(output_size=..., **fold_params) + >>> unfold = nn.Unfold(**fold_params) + + Then for any (supported) ``input`` tensor the following + equality holds: + + :: + + fold(unfold(input)) == divisor * input + + where ``divisor`` is a tensor that depends only on the shape + and dtype of the ``input``: + + >>> # xdoctest: +SKIP + >>> input_ones = core.ones(input.shape, dtype=input.dtype) + >>> divisor = fold(unfold(input_ones)) + + When the ``divisor`` tensor contains no zero elements, then + ``fold`` and ``unfold`` operations are inverses of each + other (up to constant divisor). + + .. warning:: + Currently, only unbatched (3D) or batched (4D) image-like output tensors are supported. + + Shape: + - Input: :math:`(N, C \times \prod(\text{kernel\_size}), L)` or :math:`(C \times \prod(\text{kernel\_size}), L)` + - Output: :math:`(N, C, \text{output\_size}[0], \text{output\_size}[1], \dots)` + or :math:`(C, \text{output\_size}[0], \text{output\_size}[1], \dots)` as described above + + Examples:: + + >>> fold = nn.Fold(output_size=(4, 5), kernel_size=(2, 2)) + >>> input = core.randn(1, 3 * 2 * 2, 12) + >>> output = fold(input) + >>> output.size() + core.Size([1, 3, 4, 5]) + + .. _link: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + + """ + + __constants__ = ['output_size', 'kernel_size', 'dilation', 'padding', + 'stride'] + output_size: _size_any_t + kernel_size: _size_any_t + dilation: _size_any_t + padding: _size_any_t + stride: _size_any_t + + def __init__( + self, + output_size: _size_any_t, + kernel_size: _size_any_t, + dilation: _size_any_t = 1, + padding: _size_any_t = 0, + stride: _size_any_t = 1 + ) -> None: + super().__init__() + self.output_size = output_size + self.kernel_size = kernel_size + self.dilation = dilation + self.padding = padding + self.stride = stride + + def forward(self, input: Tensor) -> Tensor: + return F.fold(input, self.output_size, self.kernel_size, self.dilation, + self.padding, self.stride) + + def extra_repr(self) -> str: + return 'output_size={output_size}, kernel_size={kernel_size}, ' \ + 'dilation={dilation}, padding={padding}, stride={stride}'.format( + **self.__dict__ + ) + + + +class Unfold(Module): + r"""Extracts sliding local blocks from a batched input tensor. + + Consider a batched :attr:`input` tensor of shape :math:`(N, C, *)`, + where :math:`N` is the batch dimension, :math:`C` is the channel dimension, + and :math:`*` represent arbitrary spatial dimensions. This operation flattens + each sliding :attr:`kernel_size`-sized block within the spatial dimensions + of :attr:`input` into a column (i.e., last dimension) of a 3-D :attr:`output` + tensor of shape :math:`(N, C \times \prod(\text{kernel\_size}), L)`, where + :math:`C \times \prod(\text{kernel\_size})` is the total number of values + within each block (a block has :math:`\prod(\text{kernel\_size})` spatial + locations each containing a :math:`C`-channeled vector), and :math:`L` is + the total number of such blocks: + + .. math:: + L = \prod_d \left\lfloor\frac{\text{spatial\_size}[d] + 2 \times \text{padding}[d] % + - \text{dilation}[d] \times (\text{kernel\_size}[d] - 1) - 1}{\text{stride}[d]} + 1\right\rfloor, + + where :math:`\text{spatial\_size}` is formed by the spatial dimensions + of :attr:`input` (:math:`*` above), and :math:`d` is over all spatial + dimensions. + + Therefore, indexing :attr:`output` at the last dimension (column dimension) + gives all values within a certain block. + + The :attr:`padding`, :attr:`stride` and :attr:`dilation` arguments specify + how the sliding blocks are retrieved. + + * :attr:`stride` controls the stride for the sliding blocks. + + * :attr:`padding` controls the amount of implicit zero-paddings on both + sides for :attr:`padding` number of points for each dimension before + reshaping. + + * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm. + It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. + + Args: + kernel_size (int or tuple): the size of the sliding blocks + dilation (int or tuple, optional): a parameter that controls the + stride of elements within the + neighborhood. Default: 1 + padding (int or tuple, optional): implicit zero padding to be added on + both sides of input. Default: 0 + stride (int or tuple, optional): the stride of the sliding blocks in the input + spatial dimensions. Default: 1 + + * If :attr:`kernel_size`, :attr:`dilation`, :attr:`padding` or + :attr:`stride` is an int or a tuple of length 1, their values will be + replicated across all spatial dimensions. + + * For the case of two input spatial dimensions this operation is sometimes + called ``im2col``. + + .. note:: + :class:`~core.nn.Fold` calculates each combined value in the resulting + large tensor by summing all values from all containing blocks. + :class:`~core.nn.Unfold` extracts the values in the local blocks by + copying from the large tensor. So, if the blocks overlap, they are not + inverses of each other. + + In general, folding and unfolding operations are related as + follows. Consider :class:`~core.nn.Fold` and + :class:`~core.nn.Unfold` instances created with the same + parameters: + + >>> fold_params = dict(kernel_size=..., dilation=..., padding=..., stride=...) + >>> fold = nn.Fold(output_size=..., **fold_params) + >>> unfold = nn.Unfold(**fold_params) + + Then for any (supported) ``input`` tensor the following + equality holds: + + :: + + fold(unfold(input)) == divisor * input + + where ``divisor`` is a tensor that depends only on the shape + and dtype of the ``input``: + + >>> # xdoctest: +SKIP + >>> input_ones = core.ones(input.shape, dtype=input.dtype) + >>> divisor = fold(unfold(input_ones)) + + When the ``divisor`` tensor contains no zero elements, then + ``fold`` and ``unfold`` operations are inverses of each + other (up to constant divisor). + + .. warning:: + Currently, only 4-D input tensors (batched image-like tensors) are + supported. + + Shape: + - Input: :math:`(N, C, *)` + - Output: :math:`(N, C \times \prod(\text{kernel\_size}), L)` as described above + + Examples:: + + >>> unfold = nn.Unfold(kernel_size=(2, 3)) + >>> input = core.randn(2, 5, 3, 4) + >>> output = unfold(input) + >>> # each patch contains 30 values (2x3=6 vectors, each of 5 channels) + >>> # 4 blocks (2x3 kernels) in total in the 3x4 input + >>> output.size() + core.Size([2, 30, 4]) + + >>> # xdoctest: +IGNORE_WANT + >>> # Convolution is equivalent with Unfold + Matrix Multiplication + Fold (or view to output shape) + >>> inp = core.randn(1, 3, 10, 12) + >>> w = core.randn(2, 3, 4, 5) + >>> inp_unf = core.nn.functional.unfold(inp, (4, 5)) + >>> out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2) + >>> out = core.nn.functional.fold(out_unf, (7, 8), (1, 1)) + >>> # or equivalently (and avoiding a copy), + >>> # out = out_unf.view(1, 2, 7, 8) + >>> (core.nn.functional.conv2d(inp, w) - out).abs().max() + tensor(1.9073e-06) + + .. _link: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + + """ + + __constants__ = ['kernel_size', 'dilation', 'padding', 'stride'] + kernel_size: _size_any_t + dilation: _size_any_t + padding: _size_any_t + stride: _size_any_t + + def __init__( + self, + kernel_size: _size_any_t, + dilation: _size_any_t = 1, + padding: _size_any_t = 0, + stride: _size_any_t = 1 + ) -> None: + super().__init__() + self.kernel_size = kernel_size + self.dilation = dilation + self.padding = padding + self.stride = stride + + def forward(self, input: Tensor) -> Tensor: + return F.unfold(input, self.kernel_size, self.dilation, + self.padding, self.stride) + + def extra_repr(self) -> str: + return 'kernel_size={kernel_size}, dilation={dilation}, padding={padding},' \ + ' stride={stride}'.format(**self.__dict__) diff --git a/mindnlp/core/nn/modules/instancenorm.py b/mindnlp/core/nn/modules/instancenorm.py new file mode 100644 index 000000000..390d3c33a --- /dev/null +++ b/mindnlp/core/nn/modules/instancenorm.py @@ -0,0 +1,361 @@ +# mypy: allow-untyped-defs + +import warnings + +from mindnlp import core.nn.functional as F +from mindnlp.core import Tensor + +from .batchnorm import _NormBase + + +__all__ = [ + "InstanceNorm1d", + "InstanceNorm2d", + "InstanceNorm3d", + "LazyInstanceNorm1d", + "LazyInstanceNorm2d", + "LazyInstanceNorm3d", +] + + +class _InstanceNorm(_NormBase): + def __init__( + self, + num_features: int, + eps: float = 1e-5, + momentum: float = 0.1, + affine: bool = False, + track_running_stats: bool = False, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + num_features, eps, momentum, affine, track_running_stats, **factory_kwargs + ) + + def _check_input_dim(self, input): + raise NotImplementedError + + def _get_no_batch_dim(self): + raise NotImplementedError + + def _handle_no_batch_input(self, input): + return self._apply_instance_norm(input.unsqueeze(0)).squeeze(0) + + def _apply_instance_norm(self, input): + return F.instance_norm( + input, + self.running_mean, + self.running_var, + self.weight, + self.bias, + self.training or not self.track_running_stats, + self.momentum if self.momentum is not None else 0.0, + self.eps, + ) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + # at version 1: removed running_mean and running_var when + # track_running_stats=False (default) + if version is None and not self.track_running_stats: + running_stats_keys = [] + for name in ("running_mean", "running_var"): + key = prefix + name + if key in state_dict: + running_stats_keys.append(key) + if len(running_stats_keys) > 0: + error_msgs.append( + "Unexpected running stats buffer(s) {names} for {klass} " + "with track_running_stats=False. If state_dict is a " + "checkpoint saved before 0.4.0, this may be expected " + "because {klass} does not track running stats by default " + "since 0.4.0. Please remove these keys from state_dict. If " + "the running stats are actually needed, instead set " + "track_running_stats=True in {klass} to enable them. See " + "the documentation of {klass} for details.".format( + names=" and ".join(f'"{k}"' for k in running_stats_keys), + klass=self.__class__.__name__, + ) + ) + for key in running_stats_keys: + state_dict.pop(key) + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def forward(self, input: Tensor) -> Tensor: + self._check_input_dim(input) + + feature_dim = input.dim() - self._get_no_batch_dim() + if input.size(feature_dim) != self.num_features: + if self.affine: + raise ValueError( + f"expected input's size at dim={feature_dim} to match num_features" + f" ({self.num_features}), but got: {input.size(feature_dim)}." + ) + else: + warnings.warn( + f"input's size at dim={feature_dim} does not match num_features. " + "You can silence this warning by not passing in num_features, " + "which is not used because affine=False" + ) + + if input.dim() == self._get_no_batch_dim(): + return self._handle_no_batch_input(input) + + return self._apply_instance_norm(input) + + +class InstanceNorm1d(_InstanceNorm): + r"""Applies Instance Normalization. + + This operation applies Instance Normalization + over a 2D (unbatched) or 3D (batched) input as described in the paper + `Instance Normalization: The Missing Ingredient for Fast Stylization + `__. + + .. math:: + + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + The mean and standard-deviation are calculated per-dimension separately + for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors + of size `C` (where `C` is the number of features or channels of the input) if :attr:`affine` is ``True``. + The variance is calculated via the biased estimator, equivalent to + `torch.var(input, unbiased=False)`. + + By default, this layer uses instance statistics computed from input data in + both training and evaluation modes. + + If :attr:`track_running_stats` is set to ``True``, during training this + layer keeps running estimates of its computed mean and variance, which are + then used for normalization during evaluation. The running estimates are + kept with a default :attr:`momentum` of 0.1. + + .. note:: + This :attr:`momentum` argument is different from one used in optimizer + classes and the conventional notion of momentum. Mathematically, the + update rule for running statistics here is + :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, + where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the + new observed value. + + .. note:: + :class:`InstanceNorm1d` and :class:`LayerNorm` are very similar, but + have some subtle differences. :class:`InstanceNorm1d` is applied + on each channel of channeled data like multidimensional time series, but + :class:`LayerNorm` is usually applied on entire sample and often in NLP + tasks. Additionally, :class:`LayerNorm` applies elementwise affine + transform, while :class:`InstanceNorm1d` usually don't apply affine + transform. + + Args: + num_features: number of features or channels :math:`C` of the input + eps: a value added to the denominator for numerical stability. Default: 1e-5 + momentum: the value used for the running_mean and running_var computation. Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters, initialized the same way as done for batch normalization. + Default: ``False``. + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics and always uses batch + statistics in both training and eval modes. Default: ``False`` + + Shape: + - Input: :math:`(N, C, L)` or :math:`(C, L)` + - Output: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input) + + Examples:: + + >>> # Without Learnable Parameters + >>> m = nn.InstanceNorm1d(100) + >>> # With Learnable Parameters + >>> m = nn.InstanceNorm1d(100, affine=True) + >>> input = torch.randn(20, 100, 40) + >>> output = m(input) + """ + + def _get_no_batch_dim(self): + return 2 + + def _check_input_dim(self, input): + if input.dim() not in (2, 3): + raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)") + + + +class InstanceNorm2d(_InstanceNorm): + r"""Applies Instance Normalization. + + This operation applies Instance Normalization + over a 4D input (a mini-batch of 2D inputs + with additional channel dimension) as described in the paper + `Instance Normalization: The Missing Ingredient for Fast Stylization + `__. + + .. math:: + + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + The mean and standard-deviation are calculated per-dimension separately + for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors + of size `C` (where `C` is the input size) if :attr:`affine` is ``True``. + The standard-deviation is calculated via the biased estimator, equivalent to + `torch.var(input, unbiased=False)`. + + By default, this layer uses instance statistics computed from input data in + both training and evaluation modes. + + If :attr:`track_running_stats` is set to ``True``, during training this + layer keeps running estimates of its computed mean and variance, which are + then used for normalization during evaluation. The running estimates are + kept with a default :attr:`momentum` of 0.1. + + .. note:: + This :attr:`momentum` argument is different from one used in optimizer + classes and the conventional notion of momentum. Mathematically, the + update rule for running statistics here is + :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, + where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the + new observed value. + + .. note:: + :class:`InstanceNorm2d` and :class:`LayerNorm` are very similar, but + have some subtle differences. :class:`InstanceNorm2d` is applied + on each channel of channeled data like RGB images, but + :class:`LayerNorm` is usually applied on entire sample and often in NLP + tasks. Additionally, :class:`LayerNorm` applies elementwise affine + transform, while :class:`InstanceNorm2d` usually don't apply affine + transform. + + Args: + num_features: :math:`C` from an expected input of size + :math:`(N, C, H, W)` or :math:`(C, H, W)` + eps: a value added to the denominator for numerical stability. Default: 1e-5 + momentum: the value used for the running_mean and running_var computation. Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters, initialized the same way as done for batch normalization. + Default: ``False``. + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics and always uses batch + statistics in both training and eval modes. Default: ``False`` + + Shape: + - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)` + - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input) + + Examples:: + + >>> # Without Learnable Parameters + >>> m = nn.InstanceNorm2d(100) + >>> # With Learnable Parameters + >>> m = nn.InstanceNorm2d(100, affine=True) + >>> input = torch.randn(20, 100, 35, 45) + >>> output = m(input) + """ + + def _get_no_batch_dim(self): + return 3 + + def _check_input_dim(self, input): + if input.dim() not in (3, 4): + raise ValueError(f"expected 3D or 4D input (got {input.dim()}D input)") + + + +class InstanceNorm3d(_InstanceNorm): + r"""Applies Instance Normalization. + + This operation applies Instance Normalization + over a 5D input (a mini-batch of 3D inputs with additional channel dimension) as described in the paper + `Instance Normalization: The Missing Ingredient for Fast Stylization + `__. + + .. math:: + + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + The mean and standard-deviation are calculated per-dimension separately + for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors + of size C (where C is the input size) if :attr:`affine` is ``True``. + The standard-deviation is calculated via the biased estimator, equivalent to + `torch.var(input, unbiased=False)`. + + By default, this layer uses instance statistics computed from input data in + both training and evaluation modes. + + If :attr:`track_running_stats` is set to ``True``, during training this + layer keeps running estimates of its computed mean and variance, which are + then used for normalization during evaluation. The running estimates are + kept with a default :attr:`momentum` of 0.1. + + .. note:: + This :attr:`momentum` argument is different from one used in optimizer + classes and the conventional notion of momentum. Mathematically, the + update rule for running statistics here is + :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, + where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the + new observed value. + + .. note:: + :class:`InstanceNorm3d` and :class:`LayerNorm` are very similar, but + have some subtle differences. :class:`InstanceNorm3d` is applied + on each channel of channeled data like 3D models with RGB color, but + :class:`LayerNorm` is usually applied on entire sample and often in NLP + tasks. Additionally, :class:`LayerNorm` applies elementwise affine + transform, while :class:`InstanceNorm3d` usually don't apply affine + transform. + + Args: + num_features: :math:`C` from an expected input of size + :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` + eps: a value added to the denominator for numerical stability. Default: 1e-5 + momentum: the value used for the running_mean and running_var computation. Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters, initialized the same way as done for batch normalization. + Default: ``False``. + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics and always uses batch + statistics in both training and eval modes. Default: ``False`` + + Shape: + - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` + - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input) + + Examples:: + + >>> # Without Learnable Parameters + >>> m = nn.InstanceNorm3d(100) + >>> # With Learnable Parameters + >>> m = nn.InstanceNorm3d(100, affine=True) + >>> input = torch.randn(20, 100, 35, 45, 10) + >>> output = m(input) + """ + + def _get_no_batch_dim(self): + return 4 + + def _check_input_dim(self, input): + if input.dim() not in (4, 5): + raise ValueError(f"expected 4D or 5D input (got {input.dim()}D input)") + diff --git a/mindnlp/core/nn/modules/linear.py b/mindnlp/core/nn/modules/linear.py new file mode 100644 index 000000000..43ee37f9b --- /dev/null +++ b/mindnlp/core/nn/modules/linear.py @@ -0,0 +1,94 @@ +"""linear""" +from typing import Any +import math +from mindnlp.core import Tensor +from ..parameter import Parameter +from .module import Module +from .. import init +from .. import functional as F +from ... import ops + +class Linear(Module): + r"""Applies a linear transformation to the incoming data: :math:`y = Ax + b` + + Args: + in_features: size of each input sample + out_features: size of each output sample + bias: If set to False, the layer will not learn an additive bias. + Default: True + + Shape: + - Input: :math:`(N, in\_features)` + - Output: :math:`(N, out\_features)` + + Attributes: + weight: the learnable weights of the module of shape + (out_features x in_features) + bias: the learnable bias of the module of shape (out_features) + + Examples:: + + >>> m = nn.Linear(20, 30) + >>> input = autograd.Variable(core.randn(128, 20)) + >>> output = m(input) + >>> print(output.size()) + """ + + def __init__(self, in_features, out_features, bias=True, dtype=None) -> None: + factory_kwargs = {'dtype': dtype} + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = Parameter(ops.empty((out_features, in_features), **factory_kwargs)) + if bias: + self.bias = Parameter(ops.empty(out_features, **factory_kwargs)) + else: + self.register_parameter('bias', None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with + # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see + # https://github.com/pytorch/pytorch/issues/57109 + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + init.uniform_(self.bias, -bound, bound) + + def forward(self, input): + return F.linear(input, self.weight, self.bias) + + def __repr__(self): + return self.__class__.__name__ + ' (' \ + + str(self.in_features) + ' -> ' \ + + str(self.out_features) + ')' + + +class Identity(Module): + r"""A placeholder identity operator that is argument-insensitive. + + Args: + args: any argument (unused) + kwargs: any keyword argument (unused) + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + Examples:: + + >>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False) + >>> input = core.randn(128, 20) + >>> output = m(input) + >>> print(output.size()) + core.Size([128, 20]) + + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__() + + def forward(self, input: Tensor) -> Tensor: + return input diff --git a/mindnlp/core/nn/modules/loss.py b/mindnlp/core/nn/modules/loss.py new file mode 100644 index 000000000..0ce846fbd --- /dev/null +++ b/mindnlp/core/nn/modules/loss.py @@ -0,0 +1,1816 @@ +"""loss""" +from typing import Callable, Optional +from typing_extensions import deprecated +from mindnlp.core import Tensor + +from .distance import PairwiseDistance +from .module import Module +from .. import functional as F +from .. import _reduction as _Reduction + +__all__ = ['L1Loss', 'NLLLoss', 'NLLLoss2d', 'PoissonNLLLoss', 'GaussianNLLLoss', 'KLDivLoss', + 'MSELoss', 'BCELoss', 'BCEWithLogitsLoss', 'HingeEmbeddingLoss', 'MultiLabelMarginLoss', + 'SmoothL1Loss', 'HuberLoss', 'SoftMarginLoss', 'CrossEntropyLoss', 'MultiLabelSoftMarginLoss', + 'CosineEmbeddingLoss', 'MarginRankingLoss', 'MultiMarginLoss', 'TripletMarginLoss', + 'TripletMarginWithDistanceLoss', 'CTCLoss'] + +class _Loss(Module): + reduction: str + + def __init__(self, size_average=None, reduce=None, reduction: str = 'mean') -> None: + super().__init__() + if size_average is not None or reduce is not None: + self.reduction: str = _Reduction.legacy_get_string(size_average, reduce) + else: + self.reduction = reduction + + +class _WeightedLoss(_Loss): + def __init__(self, weight: Optional[Tensor] = None, size_average=None, reduce=None, reduction: str = 'mean') -> None: + super().__init__(size_average, reduce, reduction) + self.register_buffer('weight', weight) + self.weight: Optional[Tensor] + + + +class L1Loss(_Loss): + r"""Creates a criterion that measures the mean absolute error (MAE) between each element in + the input :math:`x` and target :math:`y`. + + The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: + + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = \left| x_n - y_n \right|, + + where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'`` + (default ``'mean'``), then: + + .. math:: + \ell(x, y) = + \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} + \end{cases} + + :math:`x` and :math:`y` are tensors of arbitrary shapes with a total + of :math:`n` elements each. + + The sum operation still operates over all the elements, and divides by :math:`n`. + + The division by :math:`n` can be avoided if one sets ``reduction = 'sum'``. + + Supports real-valued and complex-valued inputs. + + Args: + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Target: :math:`(*)`, same shape as the input. + - Output: scalar. If :attr:`reduction` is ``'none'``, then + :math:`(*)`, same shape as the input. + + Examples:: + + >>> loss = nn.L1Loss() + >>> input = core.randn(3, 5, requires_grad=True) + >>> target = core.randn(3, 5) + >>> output = loss(input, target) + >>> output.backward() + """ + __constants__ = ['reduction'] + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.l1_loss(input, target, reduction=self.reduction) + + + + +class NLLLoss(_WeightedLoss): + r"""The negative log likelihood loss. It is useful to train a classification + problem with `C` classes. + + If provided, the optional argument :attr:`weight` should be a 1D Tensor assigning + weight to each of the classes. This is particularly useful when you have an + unbalanced training set. + + The `input` given through a forward call is expected to contain + log-probabilities of each class. `input` has to be a Tensor of size either + :math:`(minibatch, C)` or :math:`(minibatch, C, d_1, d_2, ..., d_K)` + with :math:`K \geq 1` for the `K`-dimensional case. The latter is useful for + higher dimension inputs, such as computing NLL loss per-pixel for 2D images. + + Obtaining log-probabilities in a neural network is easily achieved by + adding a `LogSoftmax` layer in the last layer of your network. + You may use `CrossEntropyLoss` instead, if you prefer not to add an extra + layer. + + The `target` that this loss expects should be a class index in the range :math:`[0, C-1]` + where `C = number of classes`; if `ignore_index` is specified, this loss also accepts + this class index (this index may not necessarily be in the class range). + + The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: + + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = - w_{y_n} x_{n,y_n}, \quad + w_{c} = \text{weight}[c] \cdot \mathbb{1}\{c \not= \text{ignore\_index}\}, + + where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight, and + :math:`N` is the batch size. If :attr:`reduction` is not ``'none'`` + (default ``'mean'``), then + + .. math:: + \ell(x, y) = \begin{cases} + \sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n}} l_n, & + \text{if reduction} = \text{`mean';}\\ + \sum_{n=1}^N l_n, & + \text{if reduction} = \text{`sum'.} + \end{cases} + + Args: + weight (Tensor, optional): a manual rescaling weight given to each + class. If given, it has to be a Tensor of size `C`. Otherwise, it is + treated as if having all ones. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``None`` + ignore_index (int, optional): Specifies a target value that is ignored + and does not contribute to the input gradient. When + :attr:`size_average` is ``True``, the loss is averaged over + non-ignored targets. + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``None`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will + be applied, ``'mean'``: the weighted mean of the output is taken, + ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in + the meantime, specifying either of those two args will override + :attr:`reduction`. Default: ``'mean'`` + + Shape:: + - Input: :math:`(N, C)` or :math:`(C)`, where `C = number of classes`, `N = batch size`, or + :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1` + in the case of `K`-dimensional loss. + - Target: :math:`(N)` or :math:`()`, where each value is + :math:`0 \leq \text{targets}[i] \leq C-1`, or + :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of + K-dimensional loss. + - Output: If :attr:`reduction` is ``'none'``, shape :math:`(N)` or + :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of K-dimensional loss. + Otherwise, scalar. + + Examples:: + + >>> log_softmax = nn.LogSoftmax(dim=1) + >>> loss_fn = nn.NLLLoss() + >>> # input to NLLLoss is of size N x C = 3 x 5 + >>> input = core.randn(3, 5, requires_grad=True) + >>> # each element in target must have 0 <= value < C + >>> target = core.tensor([1, 0, 4]) + >>> loss = loss_fn(log_softmax(input), target) + >>> loss.backward() + >>> + >>> + >>> # 2D loss example (used, for example, with image inputs) + >>> N, C = 5, 4 + >>> loss_fn = nn.NLLLoss() + >>> data = core.randn(N, 16, 10, 10) + >>> conv = nn.Conv2d(16, C, (3, 3)) + >>> log_softmax = nn.LogSoftmax(dim=1) + >>> # output of conv forward is of shape [N, C, 8, 8] + >>> output = log_softmax(conv(data)) + >>> # each element in target must have 0 <= value < C + >>> target = core.empty(N, 8, 8, dtype=core.long).random_(0, C) + >>> # input to NLLLoss is of size N x C x height (8) x width (8) + >>> loss = loss_fn(output, target) + >>> loss.backward() + """ + __constants__ = ['ignore_index', 'reduction'] + ignore_index: int + + def __init__(self, weight: Optional[Tensor] = None, size_average=None, ignore_index: int = -100, + reduce=None, reduction: str = 'mean') -> None: + super().__init__(weight, size_average, reduce, reduction) + self.ignore_index = ignore_index + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.nll_loss(input, target, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction) + + + +@deprecated( + "`NLLLoss2d` has been deprecated. " + "Please use `NLLLoss` instead as a drop-in replacement and see " + "https://pycore.org/docs/main/nn.html#core.nn.NLLLoss for more details.", + category=FutureWarning, +) +class NLLLoss2d(NLLLoss): + def __init__(self, weight: Optional[Tensor] = None, size_average=None, ignore_index: int = -100, + reduce=None, reduction: str = 'mean') -> None: + super().__init__(weight, size_average, ignore_index, reduce, reduction) + + + +class PoissonNLLLoss(_Loss): + r"""Negative log likelihood loss with Poisson distribution of target. + + The loss can be described as: + + .. math:: + \text{target} \sim \mathrm{Poisson}(\text{input}) + + \text{loss}(\text{input}, \text{target}) = \text{input} - \text{target} * \log(\text{input}) + + \log(\text{target!}) + + The last term can be omitted or approximated with Stirling formula. The + approximation is used for target values more than 1. For targets less or + equal to 1 zeros are added to the loss. + + Args: + log_input (bool, optional): if ``True`` the loss is computed as + :math:`\exp(\text{input}) - \text{target}*\text{input}`, if ``False`` the loss is + :math:`\text{input} - \text{target}*\log(\text{input}+\text{eps})`. + full (bool, optional): whether to compute full loss, i. e. to add the + Stirling approximation term + + .. math:: + \text{target}*\log(\text{target}) - \text{target} + 0.5 * \log(2\pi\text{target}). + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + eps (float, optional): Small value to avoid evaluation of :math:`\log(0)` when + :attr:`log_input = False`. Default: 1e-8 + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Examples:: + + >>> loss = nn.PoissonNLLLoss() + >>> log_input = core.randn(5, 2, requires_grad=True) + >>> target = core.randn(5, 2) + >>> output = loss(log_input, target) + >>> output.backward() + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Target: :math:`(*)`, same shape as the input. + - Output: scalar by default. If :attr:`reduction` is ``'none'``, then :math:`(*)`, + the same shape as the input. + """ + __constants__ = ['log_input', 'full', 'eps', 'reduction'] + log_input: bool + full: bool + eps: float + + def __init__(self, log_input: bool = True, full: bool = False, size_average=None, + eps: float = 1e-8, reduce=None, reduction: str = 'mean') -> None: + super().__init__(size_average, reduce, reduction) + self.log_input = log_input + self.full = full + self.eps = eps + + def forward(self, log_input: Tensor, target: Tensor) -> Tensor: + return F.poisson_nll_loss(log_input, target, log_input=self.log_input, full=self.full, + eps=self.eps, reduction=self.reduction) + + + + +class GaussianNLLLoss(_Loss): + r"""Gaussian negative log likelihood loss. + + The targets are treated as samples from Gaussian distributions with + expectations and variances predicted by the neural network. For a + ``target`` tensor modelled as having Gaussian distribution with a tensor + of expectations ``input`` and a tensor of positive variances ``var`` the loss is: + + .. math:: + \text{loss} = \frac{1}{2}\left(\log\left(\text{max}\left(\text{var}, + \ \text{eps}\right)\right) + \frac{\left(\text{input} - \text{target}\right)^2} + {\text{max}\left(\text{var}, \ \text{eps}\right)}\right) + \text{const.} + + where :attr:`eps` is used for stability. By default, the constant term of + the loss function is omitted unless :attr:`full` is ``True``. If ``var`` is not the same + size as ``input`` (due to a homoscedastic assumption), it must either have a final dimension + of 1 or have one fewer dimension (with all other sizes being the same) for correct broadcasting. + + Args: + full (bool, optional): include the constant term in the loss + calculation. Default: ``False``. + eps (float, optional): value used to clamp ``var`` (see note below), for + stability. Default: 1e-6. + reduction (str, optional): specifies the reduction to apply to the + output:``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction + will be applied, ``'mean'``: the output is the average of all batch + member losses, ``'sum'``: the output is the sum of all batch member + losses. Default: ``'mean'``. + + Shape: + - Input: :math:`(N, *)` or :math:`(*)` where :math:`*` means any number of additional + dimensions + - Target: :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input + but with one dimension equal to 1 (to allow for broadcasting) + - Var: :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input but + with one dimension equal to 1, or same shape as the input but with one fewer + dimension (to allow for broadcasting) + - Output: scalar if :attr:`reduction` is ``'mean'`` (default) or + ``'sum'``. If :attr:`reduction` is ``'none'``, then :math:`(N, *)`, same + shape as the input + + Examples:: + >>> loss = nn.GaussianNLLLoss() + >>> input = core.randn(5, 2, requires_grad=True) + >>> target = core.randn(5, 2) + >>> var = core.ones(5, 2, requires_grad=True) # heteroscedastic + >>> output = loss(input, target, var) + >>> output.backward() + + >>> loss = nn.GaussianNLLLoss() + >>> input = core.randn(5, 2, requires_grad=True) + >>> target = core.randn(5, 2) + >>> var = core.ones(5, 1, requires_grad=True) # homoscedastic + >>> output = loss(input, target, var) + >>> output.backward() + + Note: + The clamping of ``var`` is ignored with respect to autograd, and so the + gradients are unaffected by it. + + Reference: + Nix, D. A. and Weigend, A. S., "Estimating the mean and variance of the + target probability distribution", Proceedings of 1994 IEEE International + Conference on Neural Networks (ICNN'94), Orlando, FL, USA, 1994, pp. 55-60 + vol.1, doi: 10.1109/ICNN.1994.374138. + """ + __constants__ = ['full', 'eps', 'reduction'] + full: bool + eps: float + + def __init__(self, *, full: bool = False, eps: float = 1e-6, reduction: str = 'mean') -> None: + super().__init__(None, None, reduction) + self.full = full + self.eps = eps + + def forward(self, input: Tensor, target: Tensor, var: Tensor) -> Tensor: + return F.gaussian_nll_loss(input, target, var, full=self.full, eps=self.eps, reduction=self.reduction) + + + + +class KLDivLoss(_Loss): + r"""The Kullback-Leibler divergence loss. + + For tensors of the same shape :math:`y_{\text{pred}},\ y_{\text{true}}`, + where :math:`y_{\text{pred}}` is the :attr:`input` and :math:`y_{\text{true}}` is the + :attr:`target`, we define the **pointwise KL-divergence** as + + .. math:: + + L(y_{\text{pred}},\ y_{\text{true}}) + = y_{\text{true}} \cdot \log \frac{y_{\text{true}}}{y_{\text{pred}}} + = y_{\text{true}} \cdot (\log y_{\text{true}} - \log y_{\text{pred}}) + + To avoid underflow issues when computing this quantity, this loss expects the argument + :attr:`input` in the log-space. The argument :attr:`target` may also be provided in the + log-space if :attr:`log_target`\ `= True`. + + To summarise, this function is roughly equivalent to computing + + .. code-block:: python + + if not log_target: # default + loss_pointwise = target * (target.log() - input) + else: + loss_pointwise = target.exp() * (target - input) + + and then reducing this result depending on the argument :attr:`reduction` as + + .. code-block:: python + + if reduction == "mean": # default + loss = loss_pointwise.mean() + elif reduction == "batchmean": # mathematically correct + loss = loss_pointwise.sum() / input.size(0) + elif reduction == "sum": + loss = loss_pointwise.sum() + else: # reduction == "none" + loss = loss_pointwise + + .. note:: + As all the other losses in PyTorch, this function expects the first argument, + :attr:`input`, to be the output of the model (e.g. the neural network) + and the second, :attr:`target`, to be the observations in the dataset. + This differs from the standard mathematical notation :math:`KL(P\ ||\ Q)` where + :math:`P` denotes the distribution of the observations and :math:`Q` denotes the model. + + .. warning:: + :attr:`reduction`\ `= "mean"` doesn't return the true KL divergence value, please use + :attr:`reduction`\ `= "batchmean"` which aligns with the mathematical definition. + + Args: + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to `False`, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is `False`. Default: `True` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is `False`, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: `True` + reduction (str, optional): Specifies the reduction to apply to the output. Default: `"mean"` + log_target (bool, optional): Specifies whether `target` is the log space. Default: `False` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Target: :math:`(*)`, same shape as the input. + - Output: scalar by default. If :attr:`reduction` is `'none'`, then :math:`(*)`, + same shape as the input. + + Examples:: + >>> kl_loss = nn.KLDivLoss(reduction="batchmean") + >>> # input should be a distribution in the log space + >>> input = F.log_softmax(core.randn(3, 5, requires_grad=True), dim=1) + >>> # Sample a batch of distributions. Usually this would come from the dataset + >>> target = F.softmax(core.rand(3, 5), dim=1) + >>> output = kl_loss(input, target) + + >>> kl_loss = nn.KLDivLoss(reduction="batchmean", log_target=True) + >>> log_target = F.log_softmax(core.rand(3, 5), dim=1) + >>> output = kl_loss(input, log_target) + """ + __constants__ = ['reduction'] + + def __init__(self, size_average=None, reduce=None, reduction: str = 'mean', log_target: bool = False) -> None: + super().__init__(size_average, reduce, reduction) + self.log_target = log_target + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.kl_div(input, target, reduction=self.reduction, log_target=self.log_target) + + + + +class MSELoss(_Loss): + r"""Creates a criterion that measures the mean squared error (squared L2 norm) between + each element in the input :math:`x` and target :math:`y`. + + The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: + + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = \left( x_n - y_n \right)^2, + + where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'`` + (default ``'mean'``), then: + + .. math:: + \ell(x, y) = + \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} + \end{cases} + + :math:`x` and :math:`y` are tensors of arbitrary shapes with a total + of :math:`n` elements each. + + The mean operation still operates over all the elements, and divides by :math:`n`. + + The division by :math:`n` can be avoided if one sets ``reduction = 'sum'``. + + Args: + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Target: :math:`(*)`, same shape as the input. + + Examples:: + + >>> loss = nn.MSELoss() + >>> input = core.randn(3, 5, requires_grad=True) + >>> target = core.randn(3, 5) + >>> output = loss(input, target) + >>> output.backward() + """ + __constants__ = ['reduction'] + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.mse_loss(input, target, reduction=self.reduction) + + +class BCELoss(_WeightedLoss): + r"""Creates a criterion that measures the Binary Cross Entropy between the target and + the input probabilities: + + The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: + + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = - w_n \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log (1 - x_n) \right], + + where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'`` + (default ``'mean'``), then + + .. math:: + \ell(x, y) = \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} + \end{cases} + + This is used for measuring the error of a reconstruction in for example + an auto-encoder. Note that the targets :math:`y` should be numbers + between 0 and 1. + + Notice that if :math:`x_n` is either 0 or 1, one of the log terms would be + mathematically undefined in the above loss equation. PyTorch chooses to set + :math:`\log (0) = -\infty`, since :math:`\lim_{x\to 0} \log (x) = -\infty`. + However, an infinite term in the loss equation is not desirable for several reasons. + + For one, if either :math:`y_n = 0` or :math:`(1 - y_n) = 0`, then we would be + multiplying 0 with infinity. Secondly, if we have an infinite loss value, then + we would also have an infinite term in our gradient, since + :math:`\lim_{x\to 0} \frac{d}{dx} \log (x) = \infty`. + This would make BCELoss's backward method nonlinear with respect to :math:`x_n`, + and using it for things like linear regression would not be straight-forward. + + Our solution is that BCELoss clamps its log function outputs to be greater than + or equal to -100. This way, we can always have a finite loss value and a linear + backward method. + + + Args: + weight (Tensor, optional): a manual rescaling weight given to the loss + of each batch element. If given, has to be a Tensor of size `nbatch`. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Target: :math:`(*)`, same shape as the input. + - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same + shape as input. + + Examples:: + + >>> m = nn.Sigmoid() + >>> loss = nn.BCELoss() + >>> input = core.randn(3, 2, requires_grad=True) + >>> target = core.rand(3, 2, requires_grad=False) + >>> output = loss(m(input), target) + >>> output.backward() + """ + __constants__ = ['reduction'] + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction) + + +class BCEWithLogitsLoss(_Loss): + r"""This loss combines a `Sigmoid` layer and the `BCELoss` in one single + class. This version is more numerically stable than using a plain `Sigmoid` + followed by a `BCELoss` as, by combining the operations into one layer, + we take advantage of the log-sum-exp trick for numerical stability. + + The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: + + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = - w_n \left[ y_n \cdot \log \sigma(x_n) + + (1 - y_n) \cdot \log (1 - \sigma(x_n)) \right], + + where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'`` + (default ``'mean'``), then + + .. math:: + \ell(x, y) = \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} + \end{cases} + + This is used for measuring the error of a reconstruction in for example + an auto-encoder. Note that the targets `t[i]` should be numbers + between 0 and 1. + + It's possible to trade off recall and precision by adding weights to positive examples. + In the case of multi-label classification the loss can be described as: + + .. math:: + \ell_c(x, y) = L_c = \{l_{1,c},\dots,l_{N,c}\}^\top, \quad + l_{n,c} = - w_{n,c} \left[ p_c y_{n,c} \cdot \log \sigma(x_{n,c}) + + (1 - y_{n,c}) \cdot \log (1 - \sigma(x_{n,c})) \right], + + where :math:`c` is the class number (:math:`c > 1` for multi-label binary classification, + :math:`c = 1` for single-label binary classification), + :math:`n` is the number of the sample in the batch and + :math:`p_c` is the weight of the positive answer for the class :math:`c`. + + :math:`p_c > 1` increases the recall, :math:`p_c < 1` increases the precision. + + For example, if a dataset contains 100 positive and 300 negative examples of a single class, + then ``pos_weight`` for the class should be equal to :math:`\frac{300}{100}=3`. + The loss would act as if the dataset contains :math:`3\times 100=300` positive examples. + + Examples:: + + >>> target = core.ones([10, 64], dtype=core.float32) # 64 classes, batch size = 10 + >>> output = core.full([10, 64], 1.5) # A prediction (logit) + >>> pos_weight = core.ones([64]) # All weights are equal to 1 + >>> criterion = core.nn.BCEWithLogitsLoss(pos_weight=pos_weight) + >>> criterion(output, target) # -log(sigmoid(1.5)) + tensor(0.20...) + + In the above example, the ``pos_weight`` tensor's elements correspond to the 64 distinct classes + in a multi-label binary classification scenario. Each element in ``pos_weight`` is designed to adjust the + loss function based on the imbalance between negative and positive samples for the respective class. + This approach is useful in datasets with varying levels of class imbalance, ensuring that the loss + calculation accurately accounts for the distribution in each class. + + Args: + weight (Tensor, optional): a manual rescaling weight given to the loss + of each batch element. If given, has to be a Tensor of size `nbatch`. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + pos_weight (Tensor, optional): a weight of positive examples to be broadcasted with target. + Must be a tensor with equal size along the class dimension to the number of classes. + Pay close attention to PyTorch's broadcasting semantics in order to achieve the desired + operations. For a target of size [B, C, H, W] (where B is batch size) pos_weight of + size [B, C, H, W] will apply different pos_weights to each element of the batch or + [C, H, W] the same pos_weights across the batch. To apply the same positive weight + along all spacial dimensions for a 2D multi-class target [C, H, W] use: [C, 1, 1]. + Default: ``None`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Target: :math:`(*)`, same shape as the input. + - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same + shape as input. + + Examples:: + + >>> loss = nn.BCEWithLogitsLoss() + >>> input = core.randn(3, requires_grad=True) + >>> target = core.empty(3).random_(2) + >>> output = loss(input, target) + >>> output.backward() + """ + def __init__(self, weight: Optional[Tensor] = None, size_average=None, reduce=None, reduction: str = 'mean', + pos_weight: Optional[Tensor] = None) -> None: + super().__init__(size_average, reduce, reduction) + self.register_buffer('weight', weight) + self.register_buffer('pos_weight', pos_weight) + self.weight: Optional[Tensor] + self.pos_weight: Optional[Tensor] + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.binary_cross_entropy_with_logits(input, target, + self.weight, + pos_weight=self.pos_weight, + reduction=self.reduction) + + + + +class HingeEmbeddingLoss(_Loss): + r"""Measures the loss given an input tensor :math:`x` and a labels tensor :math:`y` + (containing 1 or -1). + This is usually used for measuring whether two inputs are similar or + dissimilar, e.g. using the L1 pairwise distance as :math:`x`, and is typically + used for learning nonlinear embeddings or semi-supervised learning. + + The loss function for :math:`n`-th sample in the mini-batch is + + .. math:: + l_n = \begin{cases} + x_n, & \text{if}\; y_n = 1,\\ + \max \{0, margin - x_n\}, & \text{if}\; y_n = -1, + \end{cases} + + and the total loss functions is + + .. math:: + \ell(x, y) = \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} + \end{cases} + + where :math:`L = \{l_1,\dots,l_N\}^\top`. + + Args: + margin (float, optional): Has a default value of `1`. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(*)` where :math:`*` means, any number of dimensions. The sum operation + operates over all the elements. + - Target: :math:`(*)`, same shape as the input + - Output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the input + """ + __constants__ = ['margin', 'reduction'] + margin: float + + def __init__(self, margin: float = 1.0, size_average=None, reduce=None, reduction: str = 'mean') -> None: + super().__init__(size_average, reduce, reduction) + self.margin = margin + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.hinge_embedding_loss(input, target, margin=self.margin, reduction=self.reduction) + + + + +class MultiLabelMarginLoss(_Loss): + r"""Creates a criterion that optimizes a multi-class multi-classification + hinge loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`) + and output :math:`y` (which is a 2D `Tensor` of target class indices). + For each sample in the mini-batch: + + .. math:: + \text{loss}(x, y) = \sum_{ij}\frac{\max(0, 1 - (x[y[j]] - x[i]))}{\text{x.size}(0)} + + where :math:`x \in \left\{0, \; \cdots , \; \text{x.size}(0) - 1\right\}`, \ + :math:`y \in \left\{0, \; \cdots , \; \text{y.size}(0) - 1\right\}`, \ + :math:`0 \leq y[j] \leq \text{x.size}(0)-1`, \ + and :math:`i \neq y[j]` for all :math:`i` and :math:`j`. + + :math:`y` and :math:`x` must have the same size. + + The criterion only considers a contiguous block of non-negative targets that + starts at the front. + + This allows for different samples to have variable amounts of target classes. + + Args: + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(C)` or :math:`(N, C)` where `N` is the batch size and `C` + is the number of classes. + - Target: :math:`(C)` or :math:`(N, C)`, label targets padded by -1 ensuring same shape as the input. + - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(N)`. + + Examples:: + + >>> loss = nn.MultiLabelMarginLoss() + >>> x = core.FloatTensor([[0.1, 0.2, 0.4, 0.8]]) + >>> # for target y, only consider labels 3 and 0, not after label -1 + >>> y = core.LongTensor([[3, 0, -1, 1]]) + >>> # 0.25 * ((1-(0.1-0.2)) + (1-(0.1-0.4)) + (1-(0.8-0.2)) + (1-(0.8-0.4))) + >>> loss(x, y) + tensor(0.85...) + + """ + __constants__ = ['reduction'] + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.multilabel_margin_loss(input, target, reduction=self.reduction) + + + + +class SmoothL1Loss(_Loss): + r"""Creates a criterion that uses a squared term if the absolute + element-wise error falls below beta and an L1 term otherwise. + It is less sensitive to outliers than :class:`core.nn.MSELoss` and in some cases + prevents exploding gradients (e.g. see the paper `Fast R-CNN`_ by Ross Girshick). + + For a batch of size :math:`N`, the unreduced loss can be described as: + + .. math:: + \ell(x, y) = L = \{l_1, ..., l_N\}^T + + with + + .. math:: + l_n = \begin{cases} + 0.5 (x_n - y_n)^2 / beta, & \text{if } |x_n - y_n| < beta \\ + |x_n - y_n| - 0.5 * beta, & \text{otherwise } + \end{cases} + + If `reduction` is not `none`, then: + + .. math:: + \ell(x, y) = + \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} + \end{cases} + + .. note:: + Smooth L1 loss can be seen as exactly :class:`L1Loss`, but with the :math:`|x - y| < beta` + portion replaced with a quadratic function such that its slope is 1 at :math:`|x - y| = beta`. + The quadratic segment smooths the L1 loss near :math:`|x - y| = 0`. + + .. note:: + Smooth L1 loss is closely related to :class:`HuberLoss`, being + equivalent to :math:`huber(x, y) / beta` (note that Smooth L1's beta hyper-parameter is + also known as delta for Huber). This leads to the following differences: + + * As beta -> 0, Smooth L1 loss converges to :class:`L1Loss`, while :class:`HuberLoss` + converges to a constant 0 loss. When beta is 0, Smooth L1 loss is equivalent to L1 loss. + * As beta -> :math:`+\infty`, Smooth L1 loss converges to a constant 0 loss, while + :class:`HuberLoss` converges to :class:`MSELoss`. + * For Smooth L1 loss, as beta varies, the L1 segment of the loss has a constant slope of 1. + For :class:`HuberLoss`, the slope of the L1 segment is beta. + + .. _`Fast R-CNN`: https://arxiv.org/abs/1504.08083 + + Args: + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + beta (float, optional): Specifies the threshold at which to change between L1 and L2 loss. + The value must be non-negative. Default: 1.0 + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Target: :math:`(*)`, same shape as the input. + - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same shape as the input. + """ + __constants__ = ['reduction'] + + def __init__(self, size_average=None, reduce=None, reduction: str = 'mean', beta: float = 1.0) -> None: + super().__init__(size_average, reduce, reduction) + self.beta = beta + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.smooth_l1_loss(input, target, reduction=self.reduction, beta=self.beta) + + + + +class HuberLoss(_Loss): + r"""Creates a criterion that uses a squared term if the absolute + element-wise error falls below delta and a delta-scaled L1 term otherwise. + This loss combines advantages of both :class:`L1Loss` and :class:`MSELoss`; the + delta-scaled L1 region makes the loss less sensitive to outliers than :class:`MSELoss`, + while the L2 region provides smoothness over :class:`L1Loss` near 0. See + `Huber loss `_ for more information. + + For a batch of size :math:`N`, the unreduced loss can be described as: + + .. math:: + \ell(x, y) = L = \{l_1, ..., l_N\}^T + + with + + .. math:: + l_n = \begin{cases} + 0.5 (x_n - y_n)^2, & \text{if } |x_n - y_n| < delta \\ + delta * (|x_n - y_n| - 0.5 * delta), & \text{otherwise } + \end{cases} + + If `reduction` is not `none`, then: + + .. math:: + \ell(x, y) = + \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} + \end{cases} + + .. note:: + When delta is set to 1, this loss is equivalent to :class:`SmoothL1Loss`. + In general, this loss differs from :class:`SmoothL1Loss` by a factor of delta (AKA beta + in Smooth L1). + See :class:`SmoothL1Loss` for additional discussion on the differences in behavior + between the two losses. + + Args: + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'`` + delta (float, optional): Specifies the threshold at which to change between delta-scaled L1 and L2 loss. + The value must be positive. Default: 1.0 + + Shape: + - Input: :math:`(*)` where :math:`*` means any number of dimensions. + - Target: :math:`(*)`, same shape as the input. + - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same shape as the input. + """ + __constants__ = ['reduction', 'delta'] + + def __init__(self, reduction: str = 'mean', delta: float = 1.0) -> None: + super().__init__(reduction=reduction) + self.delta = delta + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.huber_loss(input, target, reduction=self.reduction, delta=self.delta) + + + + +class SoftMarginLoss(_Loss): + r"""Creates a criterion that optimizes a two-class classification + logistic loss between input tensor :math:`x` and target tensor :math:`y` + (containing 1 or -1). + + .. math:: + \text{loss}(x, y) = \sum_i \frac{\log(1 + \exp(-y[i]*x[i]))}{\text{x.nelement}()} + + Args: + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Target: :math:`(*)`, same shape as the input. + - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same + shape as input. + + """ + __constants__ = ['reduction'] + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.soft_margin_loss(input, target, reduction=self.reduction) + + + + +class CrossEntropyLoss(_WeightedLoss): + r"""This criterion computes the cross entropy loss between input logits + and target. + + It is useful when training a classification problem with `C` classes. + If provided, the optional argument :attr:`weight` should be a 1D `Tensor` + assigning weight to each of the classes. + This is particularly useful when you have an unbalanced training set. + + The `input` is expected to contain the unnormalized logits for each class (which do `not` need + to be positive or sum to 1, in general). + `input` has to be a Tensor of size :math:`(C)` for unbatched input, + :math:`(minibatch, C)` or :math:`(minibatch, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1` for the + `K`-dimensional case. The last being useful for higher dimension inputs, such + as computing cross entropy loss per-pixel for 2D images. + + The `target` that this criterion expects should contain either: + + - Class indices in the range :math:`[0, C)` where :math:`C` is the number of classes; if + `ignore_index` is specified, this loss also accepts this class index (this index + may not necessarily be in the class range). The unreduced (i.e. with :attr:`reduction` + set to ``'none'``) loss for this case can be described as: + + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = - w_{y_n} \log \frac{\exp(x_{n,y_n})}{\sum_{c=1}^C \exp(x_{n,c})} + \cdot \mathbb{1}\{y_n \not= \text{ignore\_index}\} + + where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight, + :math:`C` is the number of classes, and :math:`N` spans the minibatch dimension as well as + :math:`d_1, ..., d_k` for the `K`-dimensional case. If + :attr:`reduction` is not ``'none'`` (default ``'mean'``), then + + .. math:: + \ell(x, y) = \begin{cases} + \sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n} \cdot \mathbb{1}\{y_n \not= \text{ignore\_index}\}} l_n, & + \text{if reduction} = \text{`mean';}\\ + \sum_{n=1}^N l_n, & + \text{if reduction} = \text{`sum'.} + \end{cases} + + Note that this case is equivalent to applying :class:`~core.nn.LogSoftmax` + on an input, followed by :class:`~core.nn.NLLLoss`. + + - Probabilities for each class; useful when labels beyond a single class per minibatch item + are required, such as for blended labels, label smoothing, etc. The unreduced (i.e. with + :attr:`reduction` set to ``'none'``) loss for this case can be described as: + + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = - \sum_{c=1}^C w_c \log \frac{\exp(x_{n,c})}{\sum_{i=1}^C \exp(x_{n,i})} y_{n,c} + + where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight, + :math:`C` is the number of classes, and :math:`N` spans the minibatch dimension as well as + :math:`d_1, ..., d_k` for the `K`-dimensional case. If + :attr:`reduction` is not ``'none'`` (default ``'mean'``), then + + .. math:: + \ell(x, y) = \begin{cases} + \frac{\sum_{n=1}^N l_n}{N}, & + \text{if reduction} = \text{`mean';}\\ + \sum_{n=1}^N l_n, & + \text{if reduction} = \text{`sum'.} + \end{cases} + + .. note:: + The performance of this criterion is generally better when `target` contains class + indices, as this allows for optimized computation. Consider providing `target` as + class probabilities only when a single class label per minibatch item is too restrictive. + + Args: + weight (Tensor, optional): a manual rescaling weight given to each class. + If given, has to be a Tensor of size `C` and floating point dtype + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + ignore_index (int, optional): Specifies a target value that is ignored + and does not contribute to the input gradient. When :attr:`size_average` is + ``True``, the loss is averaged over non-ignored targets. Note that + :attr:`ignore_index` is only applicable when the target contains class indices. + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will + be applied, ``'mean'``: the weighted mean of the output is taken, + ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in + the meantime, specifying either of those two args will override + :attr:`reduction`. Default: ``'mean'`` + label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount + of smoothing when computing the loss, where 0.0 means no smoothing. The targets + become a mixture of the original ground truth and a uniform distribution as described in + `Rethinking the Inception Architecture for Computer Vision `__. Default: :math:`0.0`. + + Shape: + - Input: Shape :math:`(C)`, :math:`(N, C)` or :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1` + in the case of `K`-dimensional loss. + - Target: If containing class indices, shape :math:`()`, :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` with + :math:`K \geq 1` in the case of K-dimensional loss where each value should be between :math:`[0, C)`. + If containing class probabilities, same shape as the input and each value should be between :math:`[0, 1]`. + - Output: If reduction is 'none', shape :math:`()`, :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` + in the case of K-dimensional loss, depending on the shape of the input. Otherwise, scalar. + + + where: + + .. math:: + \begin{aligned} + C ={} & \text{number of classes} \\ + N ={} & \text{batch size} \\ + \end{aligned} + + Examples:: + + >>> # Example of target with class indices + >>> loss = nn.CrossEntropyLoss() + >>> input = core.randn(3, 5, requires_grad=True) + >>> target = core.empty(3, dtype=core.long).random_(5) + >>> output = loss(input, target) + >>> output.backward() + >>> + >>> # Example of target with class probabilities + >>> input = core.randn(3, 5, requires_grad=True) + >>> target = core.randn(3, 5).softmax(dim=1) + >>> output = loss(input, target) + >>> output.backward() + """ + __constants__ = ['ignore_index', 'reduction', 'label_smoothing'] + ignore_index: int + label_smoothing: float + + def __init__(self, weight: Optional[Tensor] = None, size_average=None, ignore_index: int = -100, + reduce=None, reduction: str = 'mean', label_smoothing: float = 0.0) -> None: + super().__init__(weight, size_average, reduce, reduction) + self.ignore_index = ignore_index + self.label_smoothing = label_smoothing + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.cross_entropy(input, target, weight=self.weight, + ignore_index=self.ignore_index, reduction=self.reduction, + label_smoothing=self.label_smoothing) + + + + +class MultiLabelSoftMarginLoss(_WeightedLoss): + r"""Creates a criterion that optimizes a multi-label one-versus-all + loss based on max-entropy, between input :math:`x` and target :math:`y` of size + :math:`(N, C)`. + For each sample in the minibatch: + + .. math:: + loss(x, y) = - \frac{1}{C} * \sum_i y[i] * \log((1 + \exp(-x[i]))^{-1}) + + (1-y[i]) * \log\left(\frac{\exp(-x[i])}{(1 + \exp(-x[i]))}\right) + + where :math:`i \in \left\{0, \; \cdots , \; \text{x.nElement}() - 1\right\}`, + :math:`y[i] \in \left\{0, \; 1\right\}`. + + Args: + weight (Tensor, optional): a manual rescaling weight given to each + class. If given, it has to be a Tensor of size `C`. Otherwise, it is + treated as if having all ones. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(N, C)` where `N` is the batch size and `C` is the number of classes. + - Target: :math:`(N, C)`, label targets must have the same shape as the input. + - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(N)`. + """ + __constants__ = ['reduction'] + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.multilabel_soft_margin_loss(input, target, weight=self.weight, reduction=self.reduction) + + + + +class CosineEmbeddingLoss(_Loss): + r"""Creates a criterion that measures the loss given input tensors + :math:`x_1`, :math:`x_2` and a `Tensor` label :math:`y` with values 1 or -1. + Use (:math:`y=1`) to maximize the cosine similarity of two inputs, and (:math:`y=-1`) otherwise. + This is typically used for learning nonlinear + embeddings or semi-supervised learning. + + The loss function for each sample is: + + .. math:: + \text{loss}(x, y) = + \begin{cases} + 1 - \cos(x_1, x_2), & \text{if } y = 1 \\ + \max(0, \cos(x_1, x_2) - \text{margin}), & \text{if } y = -1 + \end{cases} + + Args: + margin (float, optional): Should be a number from :math:`-1` to :math:`1`, + :math:`0` to :math:`0.5` is suggested. If :attr:`margin` is missing, the + default value is :math:`0`. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input1: :math:`(N, D)` or :math:`(D)`, where `N` is the batch size and `D` is the embedding dimension. + - Input2: :math:`(N, D)` or :math:`(D)`, same shape as Input1. + - Target: :math:`(N)` or :math:`()`. + - Output: If :attr:`reduction` is ``'none'``, then :math:`(N)`, otherwise scalar. + + Examples:: + + >>> loss = nn.CosineEmbeddingLoss() + >>> input1 = core.randn(3, 5, requires_grad=True) + >>> input2 = core.randn(3, 5, requires_grad=True) + >>> target = core.ones(3) + >>> output = loss(input1, input2, target) + >>> output.backward() + """ + __constants__ = ['margin', 'reduction'] + margin: float + + def __init__(self, margin: float = 0., size_average=None, reduce=None, reduction: str = 'mean') -> None: + super().__init__(size_average, reduce, reduction) + self.margin = margin + + def forward(self, input1: Tensor, input2: Tensor, target: Tensor) -> Tensor: + return F.cosine_embedding_loss(input1, input2, target, margin=self.margin, reduction=self.reduction) + + + + +class MarginRankingLoss(_Loss): + r"""Creates a criterion that measures the loss given + inputs :math:`x1`, :math:`x2`, two 1D mini-batch or 0D `Tensors`, + and a label 1D mini-batch or 0D `Tensor` :math:`y` (containing 1 or -1). + + If :math:`y = 1` then it assumed the first input should be ranked higher + (have a larger value) than the second input, and vice-versa for :math:`y = -1`. + + The loss function for each pair of samples in the mini-batch is: + + .. math:: + \text{loss}(x1, x2, y) = \max(0, -y * (x1 - x2) + \text{margin}) + + Args: + margin (float, optional): Has a default value of :math:`0`. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input1: :math:`(N)` or :math:`()` where `N` is the batch size. + - Input2: :math:`(N)` or :math:`()`, same shape as the Input1. + - Target: :math:`(N)` or :math:`()`, same shape as the inputs. + - Output: scalar. If :attr:`reduction` is ``'none'`` and Input size is not :math:`()`, then :math:`(N)`. + + Examples:: + + >>> loss = nn.MarginRankingLoss() + >>> input1 = core.randn(3, requires_grad=True) + >>> input2 = core.randn(3, requires_grad=True) + >>> target = core.randn(3).sign() + >>> output = loss(input1, input2, target) + >>> output.backward() + """ + __constants__ = ['margin', 'reduction'] + margin: float + + def __init__(self, margin: float = 0., size_average=None, reduce=None, reduction: str = 'mean') -> None: + super().__init__(size_average, reduce, reduction) + self.margin = margin + + def forward(self, input1: Tensor, input2: Tensor, target: Tensor) -> Tensor: + return F.margin_ranking_loss(input1, input2, target, margin=self.margin, reduction=self.reduction) + + + + +class MultiMarginLoss(_WeightedLoss): + r"""Creates a criterion that optimizes a multi-class classification hinge + loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`) and + output :math:`y` (which is a 1D tensor of target class indices, + :math:`0 \leq y \leq \text{x.size}(1)-1`): + + For each mini-batch sample, the loss in terms of the 1D input :math:`x` and scalar + output :math:`y` is: + + .. math:: + \text{loss}(x, y) = \frac{\sum_i \max(0, \text{margin} - x[y] + x[i])^p}{\text{x.size}(0)} + + where :math:`i \in \left\{0, \; \cdots , \; \text{x.size}(0) - 1\right\}` + and :math:`i \neq y`. + + Optionally, you can give non-equal weighting on the classes by passing + a 1D :attr:`weight` tensor into the constructor. + + The loss function then becomes: + + .. math:: + \text{loss}(x, y) = \frac{\sum_i w[y] * \max(0, \text{margin} - x[y] + x[i])^p}{\text{x.size}(0)} + + Args: + p (int, optional): Has a default value of :math:`1`. :math:`1` and :math:`2` + are the only supported values. + margin (float, optional): Has a default value of :math:`1`. + weight (Tensor, optional): a manual rescaling weight given to each + class. If given, it has to be a Tensor of size `C`. Otherwise, it is + treated as if having all ones. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(N, C)` or :math:`(C)`, where :math:`N` is the batch size and :math:`C` is the number of classes. + - Target: :math:`(N)` or :math:`()`, where each value is :math:`0 \leq \text{targets}[i] \leq C-1`. + - Output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the target. + + Examples:: + + >>> loss = nn.MultiMarginLoss() + >>> x = core.tensor([[0.1, 0.2, 0.4, 0.8]]) + >>> y = core.tensor([3]) + >>> # 0.25 * ((1-(0.8-0.1)) + (1-(0.8-0.2)) + (1-(0.8-0.4))) + >>> loss(x, y) + tensor(0.32...) + """ + __constants__ = ['p', 'margin', 'reduction'] + margin: float + p: int + + def __init__(self, p: int = 1, margin: float = 1., weight: Optional[Tensor] = None, size_average=None, + reduce=None, reduction: str = 'mean') -> None: + super().__init__(weight, size_average, reduce, reduction) + if p not in (1, 2): + raise ValueError("only p == 1 and p == 2 supported") + if weight is not None and weight.dim() != 1 : + raise ValueError( + f"MultiMarginLoss: expected weight to be None or 1D tensor, got {weight.dim()}D instead" + ) + self.p = p + self.margin = margin + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.multi_margin_loss(input, target, p=self.p, margin=self.margin, + weight=self.weight, reduction=self.reduction) + + + + +class TripletMarginLoss(_Loss): + r"""Creates a criterion that measures the triplet loss given an input + tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater than :math:`0`. + This is used for measuring a relative similarity between samples. A triplet + is composed by `a`, `p` and `n` (i.e., `anchor`, `positive examples` and `negative + examples` respectively). The shapes of all input tensors should be + :math:`(N, D)`. + + The distance swap is described in detail in the paper `Learning shallow + convolutional feature descriptors with triplet losses`_ by + V. Balntas, E. Riba et al. + + The loss function for each sample in the mini-batch is: + + .. math:: + L(a, p, n) = \max \{d(a_i, p_i) - d(a_i, n_i) + {\rm margin}, 0\} + + + where + + .. math:: + d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_p + + The norm is calculated using the specified p value and a small constant :math:`\varepsilon` is + added for numerical stability. + + See also :class:`~core.nn.TripletMarginWithDistanceLoss`, which computes the + triplet margin loss for input tensors using a custom distance function. + + Args: + margin (float, optional): Default: :math:`1`. + p (int, optional): The norm degree for pairwise distance. Default: :math:`2`. + eps (float, optional): Small constant for numerical stability. Default: :math:`1e-6`. + swap (bool, optional): The distance swap is described in detail in the paper + `Learning shallow convolutional feature descriptors with triplet losses` by + V. Balntas, E. Riba et al. Default: ``False``. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(N, D)` or :math:`(D)` where :math:`D` is the vector dimension. + - Output: A Tensor of shape :math:`(N)` if :attr:`reduction` is ``'none'`` and + input shape is :math:`(N, D)`; a scalar otherwise. + + Examples:: + + >>> triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7) + >>> anchor = core.randn(100, 128, requires_grad=True) + >>> positive = core.randn(100, 128, requires_grad=True) + >>> negative = core.randn(100, 128, requires_grad=True) + >>> output = triplet_loss(anchor, positive, negative) + >>> output.backward() + + .. _Learning shallow convolutional feature descriptors with triplet losses: + http://www.bmva.org/bmvc/2016/papers/paper119/index.html + """ + __constants__ = ['margin', 'p', 'eps', 'swap', 'reduction'] + margin: float + p: float + eps: float + swap: bool + + def __init__(self, margin: float = 1.0, p: float = 2., eps: float = 1e-6, swap: bool = False, size_average=None, + reduce=None, reduction: str = 'mean'): + super().__init__(size_average, reduce, reduction) + if margin <= 0: + raise ValueError( + f"TripletMarginLoss: expected margin to be greater than 0, got {margin} instead" + ) + self.margin = margin + self.p = p + self.eps = eps + self.swap = swap + + def forward(self, anchor: Tensor, positive: Tensor, negative: Tensor) -> Tensor: + return F.triplet_margin_loss(anchor, positive, negative, margin=self.margin, p=self.p, + eps=self.eps, swap=self.swap, reduction=self.reduction) + + + + +class TripletMarginWithDistanceLoss(_Loss): + r"""Creates a criterion that measures the triplet loss given input + tensors :math:`a`, :math:`p`, and :math:`n` (representing anchor, + positive, and negative examples, respectively), and a nonnegative, + real-valued function ("distance function") used to compute the relationship + between the anchor and positive example ("positive distance") and the + anchor and negative example ("negative distance"). + + The unreduced loss (i.e., with :attr:`reduction` set to ``'none'``) + can be described as: + + .. math:: + \ell(a, p, n) = L = \{l_1,\dots,l_N\}^\top, \quad + l_i = \max \{d(a_i, p_i) - d(a_i, n_i) + {\rm margin}, 0\} + + where :math:`N` is the batch size; :math:`d` is a nonnegative, real-valued function + quantifying the closeness of two tensors, referred to as the :attr:`distance_function`; + and :math:`margin` is a nonnegative margin representing the minimum difference + between the positive and negative distances that is required for the loss to + be 0. The input tensors have :math:`N` elements each and can be of any shape + that the distance function can handle. + + If :attr:`reduction` is not ``'none'`` + (default ``'mean'``), then: + + .. math:: + \ell(x, y) = + \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} + \end{cases} + + See also :class:`~core.nn.TripletMarginLoss`, which computes the triplet + loss for input tensors using the :math:`l_p` distance as the distance function. + + Args: + distance_function (Callable, optional): A nonnegative, real-valued function that + quantifies the closeness of two tensors. If not specified, + `nn.PairwiseDistance` will be used. Default: ``None`` + margin (float, optional): A nonnegative margin representing the minimum difference + between the positive and negative distances required for the loss to be 0. Larger + margins penalize cases where the negative examples are not distant enough from the + anchors, relative to the positives. Default: :math:`1`. + swap (bool, optional): Whether to use the distance swap described in the paper + `Learning shallow convolutional feature descriptors with triplet losses` by + V. Balntas, E. Riba et al. If True, and if the positive example is closer to the + negative example than the anchor is, swaps the positive example and the anchor in + the loss computation. Default: ``False``. + reduction (str, optional): Specifies the (optional) reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'`` + + + Shape: + - Input: :math:`(N, *)` where :math:`*` represents any number of additional dimensions + as supported by the distance function. + - Output: A Tensor of shape :math:`(N)` if :attr:`reduction` is ``'none'``, or a scalar + otherwise. + + Examples:: + + >>> # Initialize embeddings + >>> embedding = nn.Embedding(1000, 128) + >>> anchor_ids = core.randint(0, 1000, (1,)) + >>> positive_ids = core.randint(0, 1000, (1,)) + >>> negative_ids = core.randint(0, 1000, (1,)) + >>> anchor = embedding(anchor_ids) + >>> positive = embedding(positive_ids) + >>> negative = embedding(negative_ids) + >>> + >>> # Built-in Distance Function + >>> triplet_loss = \ + >>> nn.TripletMarginWithDistanceLoss(distance_function=nn.PairwiseDistance()) + >>> output = triplet_loss(anchor, positive, negative) + >>> output.backward() + >>> + >>> # Custom Distance Function + >>> def l_infinity(x1, x2): + >>> return core.max(core.abs(x1 - x2), dim=1).values + >>> + >>> # xdoctest: +SKIP("FIXME: Would call backwards a second time") + >>> triplet_loss = ( + >>> nn.TripletMarginWithDistanceLoss(distance_function=l_infinity, margin=1.5)) + >>> output = triplet_loss(anchor, positive, negative) + >>> output.backward() + >>> + >>> # Custom Distance Function (Lambda) + >>> triplet_loss = ( + >>> nn.TripletMarginWithDistanceLoss( + >>> distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y))) + >>> output = triplet_loss(anchor, positive, negative) + >>> output.backward() + + Reference: + V. Balntas, et al.: Learning shallow convolutional feature descriptors with triplet losses: + http://www.bmva.org/bmvc/2016/papers/paper119/index.html + """ + __constants__ = ['margin', 'swap', 'reduction'] + margin: float + swap: bool + + def __init__(self, *, distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = None, + margin: float = 1.0, swap: bool = False, reduction: str = 'mean'): + super().__init__(size_average=None, reduce=None, reduction=reduction) + if margin <= 0: + raise ValueError( + f"TripletMarginWithDistanceLoss: expected margin to be greater than 0, got {margin} instead" + ) + self.distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = \ + distance_function if distance_function is not None else PairwiseDistance() + self.margin = margin + self.swap = swap + + def forward(self, anchor: Tensor, positive: Tensor, negative: Tensor) -> Tensor: + return F.triplet_margin_with_distance_loss(anchor, positive, negative, + distance_function=self.distance_function, + margin=self.margin, swap=self.swap, reduction=self.reduction) + + + + +class CTCLoss(_Loss): + r"""The Connectionist Temporal Classification loss. + + Calculates loss between a continuous (unsegmented) time series and a target sequence. CTCLoss sums over the + probability of possible alignments of input to target, producing a loss value which is differentiable + with respect to each input node. The alignment of input to target is assumed to be "many-to-one", which + limits the length of the target sequence such that it must be :math:`\leq` the input length. + + Args: + blank (int, optional): blank label. Default :math:`0`. + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the output losses will be divided by the target lengths and + then the mean over the batch is taken, ``'sum'``: the output losses will be summed. + Default: ``'mean'`` + zero_infinity (bool, optional): + Whether to zero infinite losses and the associated gradients. + Default: ``False`` + Infinite losses mainly occur when the inputs are too short + to be aligned to the targets. + + Shape: + - Log_probs: Tensor of size :math:`(T, N, C)` or :math:`(T, C)`, + where :math:`T = \text{input length}`, + :math:`N = \text{batch size}`, and + :math:`C = \text{number of classes (including blank)}`. + The logarithmized probabilities of the outputs (e.g. obtained with + :func:`core.nn.functional.log_softmax`). + - Targets: Tensor of size :math:`(N, S)` or + :math:`(\operatorname{sum}(\text{target\_lengths}))`, + where :math:`N = \text{batch size}` and + :math:`S = \text{max target length, if shape is } (N, S)`. + It represents the target sequences. Each element in the target + sequence is a class index. And the target index cannot be blank (default=0). + In the :math:`(N, S)` form, targets are padded to the + length of the longest sequence, and stacked. + In the :math:`(\operatorname{sum}(\text{target\_lengths}))` form, + the targets are assumed to be un-padded and + concatenated within 1 dimension. + - Input_lengths: Tuple or tensor of size :math:`(N)` or :math:`()`, + where :math:`N = \text{batch size}`. It represents the lengths of the + inputs (must each be :math:`\leq T`). And the lengths are specified + for each sequence to achieve masking under the assumption that sequences + are padded to equal lengths. + - Target_lengths: Tuple or tensor of size :math:`(N)` or :math:`()`, + where :math:`N = \text{batch size}`. It represents lengths of the targets. + Lengths are specified for each sequence to achieve masking under the + assumption that sequences are padded to equal lengths. If target shape is + :math:`(N,S)`, target_lengths are effectively the stop index + :math:`s_n` for each target sequence, such that ``target_n = targets[n,0:s_n]`` for + each target in a batch. Lengths must each be :math:`\leq S` + If the targets are given as a 1d tensor that is the concatenation of individual + targets, the target_lengths must add up to the total length of the tensor. + - Output: scalar if :attr:`reduction` is ``'mean'`` (default) or + ``'sum'``. If :attr:`reduction` is ``'none'``, then :math:`(N)` if input is batched or + :math:`()` if input is unbatched, where :math:`N = \text{batch size}`. + + Examples:: + + >>> # Target are to be padded + >>> T = 50 # Input sequence length + >>> C = 20 # Number of classes (including blank) + >>> N = 16 # Batch size + >>> S = 30 # Target sequence length of longest target in batch (padding length) + >>> S_min = 10 # Minimum target length, for demonstration purposes + >>> + >>> # Initialize random batch of input vectors, for *size = (T,N,C) + >>> input = core.randn(T, N, C).log_softmax(2).detach().requires_grad_() + >>> + >>> # Initialize random batch of targets (0 = blank, 1:C = classes) + >>> target = core.randint(low=1, high=C, size=(N, S), dtype=core.long) + >>> + >>> input_lengths = core.full(size=(N,), fill_value=T, dtype=core.long) + >>> target_lengths = core.randint(low=S_min, high=S, size=(N,), dtype=core.long) + >>> ctc_loss = nn.CTCLoss() + >>> loss = ctc_loss(input, target, input_lengths, target_lengths) + >>> loss.backward() + >>> + >>> + >>> # Target are to be un-padded + >>> T = 50 # Input sequence length + >>> C = 20 # Number of classes (including blank) + >>> N = 16 # Batch size + >>> + >>> # Initialize random batch of input vectors, for *size = (T,N,C) + >>> input = core.randn(T, N, C).log_softmax(2).detach().requires_grad_() + >>> input_lengths = core.full(size=(N,), fill_value=T, dtype=core.long) + >>> + >>> # Initialize random batch of targets (0 = blank, 1:C = classes) + >>> target_lengths = core.randint(low=1, high=T, size=(N,), dtype=core.long) + >>> target = core.randint(low=1, high=C, size=(sum(target_lengths),), dtype=core.long) + >>> ctc_loss = nn.CTCLoss() + >>> loss = ctc_loss(input, target, input_lengths, target_lengths) + >>> loss.backward() + >>> + >>> + >>> # Target are to be un-padded and unbatched (effectively N=1) + >>> T = 50 # Input sequence length + >>> C = 20 # Number of classes (including blank) + >>> + >>> # Initialize random batch of input vectors, for *size = (T,C) + >>> # xdoctest: +SKIP("FIXME: error in doctest") + >>> input = core.randn(T, C).log_softmax(1).detach().requires_grad_() + >>> input_lengths = core.tensor(T, dtype=core.long) + >>> + >>> # Initialize random batch of targets (0 = blank, 1:C = classes) + >>> target_lengths = core.randint(low=1, high=T, size=(), dtype=core.long) + >>> target = core.randint(low=1, high=C, size=(target_lengths,), dtype=core.long) + >>> ctc_loss = nn.CTCLoss() + >>> loss = ctc_loss(input, target, input_lengths, target_lengths) + >>> loss.backward() + + Reference: + A. Graves et al.: Connectionist Temporal Classification: + Labelling Unsegmented Sequence Data with Recurrent Neural Networks: + https://www.cs.toronto.edu/~graves/icml_2006.pdf + + Note: + In order to use CuDNN, the following must be satisfied: :attr:`targets` must be + in concatenated format, all :attr:`input_lengths` must be `T`. :math:`blank=0`, + :attr:`target_lengths` :math:`\leq 256`, the integer arguments must be of + dtype :attr:`core.int32`. + + The regular implementation uses the (more common in PyTorch) `core.long` dtype. + + + Note: + In some circumstances when using the CUDA backend with CuDNN, this operator + may select a nondeterministic algorithm to increase performance. If this is + undesirable, you can try to make the operation deterministic (potentially at + a performance cost) by setting ``core.backends.cudnn.deterministic = + True``. + Please see the notes on :doc:`/notes/randomness` for background. + """ + __constants__ = ['blank', 'reduction'] + blank: int + zero_infinity: bool + + def __init__(self, blank: int = 0, reduction: str = 'mean', zero_infinity: bool = False): + super().__init__(reduction=reduction) + self.blank = blank + self.zero_infinity = zero_infinity + + def forward(self, log_probs: Tensor, targets: Tensor, input_lengths: Tensor, target_lengths: Tensor) -> Tensor: + return F.ctc_loss(log_probs, targets, input_lengths, target_lengths, self.blank, self.reduction, + self.zero_infinity) diff --git a/mindnlp/core/nn/modules/module.py b/mindnlp/core/nn/modules/module.py new file mode 100644 index 000000000..c259a903e --- /dev/null +++ b/mindnlp/core/nn/modules/module.py @@ -0,0 +1,2271 @@ +"""Module""" +import warnings +import weakref +import functools +from typing import Dict, Optional, Callable, Set, overload, TypeVar, Any, Iterator, Tuple, Union, \ + Mapping, List +import itertools +from collections import OrderedDict, namedtuple +import mindspore +try: + from mindspore.common._stub_tensor import StubTensor +except: + class StubTensor: pass + +from mindnlp import core +from mindnlp.core import device, dtype, Tensor + +from ..parameter import Parameter +from ...utils import hooks +from ...utils.hooks import RemovableHandle + +_grad_t = Union[Tuple[Tensor, ...], Tensor] +T = TypeVar('T', bound='Module') + +class _IncompatibleKeys(namedtuple('IncompatibleKeys', ['missing_keys', 'unexpected_keys'])): + def __repr__(self): + if not self.missing_keys and not self.unexpected_keys: + return '' + return super().__repr__() + + __str__ = __repr__ + +def _addindent(s_, numSpaces): + s = s_.split('\n') + # don't do anything for single-line stuff + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(numSpaces * ' ') + line for line in s] + s = '\n'.join(s) + s = first + '\n' + s + return s + +_EXTRA_STATE_KEY_SUFFIX = '_extra_state' + +_global_buffer_registration_hooks: Dict[int, Callable] = OrderedDict() +_global_module_registration_hooks: Dict[int, Callable] = OrderedDict() +_global_parameter_registration_hooks: Dict[int, Callable] = OrderedDict() + + +_global_backward_pre_hooks: Dict[int, Callable] = OrderedDict() +_global_backward_hooks: Dict[int, Callable] = OrderedDict() +_global_is_full_backward_hook: Optional[bool] = None +_global_forward_pre_hooks: Dict[int, Callable] = OrderedDict() +_global_forward_hooks: Dict[int, Callable] = OrderedDict() +_global_forward_hooks_always_called: Dict[int, bool] = OrderedDict() +_global_forward_hooks_with_kwargs: Dict[int, bool] = OrderedDict() + + +class _WrappedHook: + def __init__(self, hook: Callable, module: Optional["Module"] = None): + self.hook: Callable = hook + functools.update_wrapper(self, hook) + + self.with_module: bool = False + + if module is not None: + self.module: weakref.ReferenceType[Module] = weakref.ref(module) + self.with_module = True + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + if self.with_module: + module = self.module() + if module is None: + raise RuntimeError("You are trying to call the hook of a dead Module!") + return self.hook(module, *args, **kwargs) + return self.hook(*args, **kwargs) + + def __getstate__(self) -> Dict: + result = {"hook": self.hook, "with_module": self.with_module} + if self.with_module: + result["module"] = self.module() + + return result + + def __setstate__(self, state: Dict): + self.hook = state["hook"] + self.with_module = state["with_module"] + + if self.with_module: + if state["module"] is None: + raise RuntimeError("You are trying to revive the hook of a dead Module!") + self.module = weakref.ref(state["module"]) + + +def register_module_buffer_registration_hook( + hook: Callable[..., None], +) -> RemovableHandle: + r"""Register a buffer registration hook common to all modules. + + .. warning :: + + This adds global state to the `nn.Module` module + + The hook will be called every time :func:`register_buffer` is invoked. + It should have the following signature:: + + hook(module, name, buffer) -> None or new buffer + + The hook can modify the input or return a single modified value in the hook. + + Returns: + :class:`core.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = RemovableHandle(_global_buffer_registration_hooks) + _global_buffer_registration_hooks[handle.id] = hook + return handle + + +def register_module_module_registration_hook( + hook: Callable[..., None], +) -> RemovableHandle: + r"""Register a module registration hook common to all modules. + + .. warning :: + + This adds global state to the `nn.Module` module + + The hook will be called every time :func:`register_module` is invoked. + It should have the following signature:: + + hook(module, name, submodule) -> None or new submodule + + The hook can modify the input or return a single modified value in the hook. + + Returns: + :class:`core.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = RemovableHandle(_global_module_registration_hooks) + _global_module_registration_hooks[handle.id] = hook + return handle + + +def register_module_parameter_registration_hook( + hook: Callable[..., None], +) -> RemovableHandle: + r"""Register a parameter registration hook common to all modules. + + .. warning :: + + This adds global state to the `nn.Module` module + + The hook will be called every time :func:`register_parameter` is invoked. + It should have the following signature:: + + hook(module, name, param) -> None or new parameter + + The hook can modify the input or return a single modified value in the hook. + + Returns: + :class:`core.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = RemovableHandle(_global_parameter_registration_hooks) + _global_parameter_registration_hooks[handle.id] = hook + return handle + + +def register_module_forward_pre_hook(hook: Callable[..., None]) -> RemovableHandle: + r"""Register a forward pre-hook common to all modules. + + .. warning :: + + This adds global state to the `nn.module` module + and it is only intended for debugging/profiling purposes. + + The hook will be called every time before :func:`forward` is invoked. + It should have the following signature:: + + hook(module, input) -> None or modified input + + The input contains only the positional arguments given to the module. + Keyword arguments won't be passed to the hooks and only to the ``forward``. + The hook can modify the input. User can either return a tuple or a + single modified value in the hook. We will wrap the value into a tuple + if a single value is returned(unless that value is already a tuple). + + This hook has precedence over the specific module hooks registered with + ``register_forward_pre_hook``. + + Returns: + :class:`core.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = RemovableHandle(_global_forward_pre_hooks) + _global_forward_pre_hooks[handle.id] = hook + return handle + + +def register_module_forward_hook( + hook: Callable[..., None], + *, + with_kwargs: bool = False, + always_call: bool = False, +) -> RemovableHandle: + r"""Register a global forward hook for all the modules. + + .. warning :: + + This adds global state to the `nn.module` module + and it is only intended for debugging/profiling purposes. + + The hook will be called every time after :func:`forward` has computed an output. + It should have the following signature:: + + hook(module, input, output) -> None or modified output + + The input contains only the positional arguments given to the module. + Keyword arguments won't be passed to the hooks and only to the ``forward``. + You can optionally modify the output of the module by returning a new value + that will replace the output from the :func:`forward` function. + + Parameters: + hook (Callable): The user defined hook to be registered. + always_call (bool): If ``True`` the ``hook`` will be run regardless of + whether an exception is raised while calling the Module. + Default: ``False`` + Returns: + :class:`core.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + + This hook will be executed before specific module hooks registered with + ``register_forward_hook``. + """ + handle = RemovableHandle( + _global_forward_hooks, extra_dict=_global_forward_hooks_always_called + ) + _global_forward_hooks[handle.id] = hook + if with_kwargs: + _global_forward_hooks_with_kwargs[handle.id] = True + if always_call: + _global_forward_hooks_always_called[handle.id] = True + return handle + + +def register_module_backward_hook( + hook: Callable[["Module", _grad_t, _grad_t], Union[None, _grad_t]], +) -> RemovableHandle: + r"""Register a backward hook common to all the modules. + + This function is deprecated in favor of + :func:`core.nn.modules.module.register_module_full_backward_hook` + and the behavior of this function will change in future versions. + + Returns: + :class:`core.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + + """ + global _global_is_full_backward_hook + if _global_is_full_backward_hook is True: + raise RuntimeError( + "Cannot use both regular backward hooks and full backward hooks as a " + "global Module hook. Please use only one of them." + ) + + _global_is_full_backward_hook = False + + handle = RemovableHandle(_global_backward_hooks) + _global_backward_hooks[handle.id] = hook + return handle + + +def register_module_full_backward_pre_hook( + hook: Callable[["Module", _grad_t], Union[None, _grad_t]], +) -> RemovableHandle: + r"""Register a backward pre-hook common to all the modules. + + .. warning :: + This adds global state to the `nn.module` module + and it is only intended for debugging/profiling purposes. + + Hooks registered using this function behave in the same way as those + registered by :meth:`core.nn.Module.register_full_backward_pre_hook`. + Refer to its documentation for more details. + + Hooks registered using this function will be called before hooks registered + using :meth:`core.nn.Module.register_full_backward_pre_hook`. + + Returns: + :class:`core.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + + """ + handle = RemovableHandle(_global_backward_pre_hooks) + _global_backward_pre_hooks[handle.id] = hook + return handle + + +def register_module_full_backward_hook( + hook: Callable[["Module", _grad_t, _grad_t], Union[None, _grad_t]], +) -> RemovableHandle: + r"""Register a backward hook common to all the modules. + + .. warning :: + This adds global state to the `nn.module` module + and it is only intended for debugging/profiling purposes. + + Hooks registered using this function behave in the same way as those + registered by :meth:`core.nn.Module.register_full_backward_hook`. + Refer to its documentation for more details. + + Hooks registered using this function will be called before hooks registered + using :meth:`core.nn.Module.register_full_backward_hook`. + + Returns: + :class:`core.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + + """ + global _global_is_full_backward_hook + if _global_is_full_backward_hook is False: + raise RuntimeError( + "Cannot use both regular backward hooks and full backward hooks as a " + "global Module hook. Please use only one of them." + ) + + _global_is_full_backward_hook = True + + handle = RemovableHandle(_global_backward_hooks) + _global_backward_hooks[handle.id] = hook + return handle + + +# Trick mypy into not applying contravariance rules to inputs by defining +# forward as a value, rather than a function. See also +# https://github.com/python/mypy/issues/8795 +def _forward_unimplemented(self, *input: Any) -> None: + r"""Define the computation performed at every call. + + Should be overridden by all subclasses. + + .. note:: + Although the recipe for forward pass needs to be defined within + this function, one should call the :class:`Module` instance afterwards + instead of this since the former takes care of running the + registered hooks while the latter silently ignores them. + """ + raise NotImplementedError( + f'Module [{type(self).__name__}] is missing the required "forward" function' + ) + +class Module: + r"""Base class for all neural network modules. + + Your models should also subclass this class. + + Modules can also contain other Modules, allowing to nest them in + a tree structure. You can assign the submodules as regular attributes:: + + import minispore.nn as nn + import minispore.nn.functional as F + + class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + self.conv1 = nn.Conv2d(1, 20, 5) + self.conv2 = nn.Conv2d(20, 20, 5) + + def forward(self, x): + x = F.relu(self.conv1(x)) + return F.relu(self.conv2(x)) + """ + + __ms_class__ = False + training: bool + _parameters: Dict[str, Optional[Parameter]] + _buffers: Dict[str, Optional[Tensor]] + _non_persistent_buffers_set: Set[str] + _backward_pre_hooks: Dict[int, Callable] + _backward_hooks: Dict[int, Callable] + _is_full_backward_hook: Optional[bool] + _forward_hooks: Dict[int, Callable] + # Marks whether the corresponding _forward_hooks accept kwargs or not. + # As JIT does not support Set[int], this dict is used as a set, where all + # hooks represented in this dict accept kwargs. + _forward_hooks_with_kwargs: Dict[int, bool] + # forward hooks that should always be called even if an exception is raised + _forward_hooks_always_called: Dict[int, bool] + _forward_pre_hooks: Dict[int, Callable] + # Marks whether the corresponding _forward_hooks accept kwargs or not. + # As JIT does not support Set[int], this dict is used as a set, where all + # hooks represented in this dict accept kwargs. + _forward_pre_hooks_with_kwargs: Dict[int, bool] + _state_dict_hooks: Dict[int, Callable] + _load_state_dict_pre_hooks: Dict[int, Callable] + _state_dict_pre_hooks: Dict[int, Callable] + _load_state_dict_post_hooks: Dict[int, Callable] + _modules: Dict[str, Optional['Module']] + call_super_init: bool = False + _compiled_call_impl : Optional[Callable] = None + + def __init__(self): + """ + Calls super().__setattr__('a', a) instead of the typical self.a = a + to avoid Module.__setattr__ overhead. Module's __setattr__ has special + handling for parameters, submodules, and buffers but simply calls into + super().__setattr__ for all other attributes. + """ + super().__setattr__('training', True) + super().__setattr__('_parameters', OrderedDict()) + super().__setattr__('_buffers', OrderedDict()) + super().__setattr__('_non_persistent_buffers_set', set()) + super().__setattr__('_backward_pre_hooks', OrderedDict()) + super().__setattr__('_backward_hooks', OrderedDict()) + super().__setattr__('_is_full_backward_hook', None) + super().__setattr__('_forward_hooks', OrderedDict()) + super().__setattr__('_forward_hooks_with_kwargs', OrderedDict()) + super().__setattr__('_forward_hooks_always_called', OrderedDict()) + super().__setattr__('_forward_pre_hooks', OrderedDict()) + super().__setattr__('_forward_pre_hooks_with_kwargs', OrderedDict()) + super().__setattr__('_state_dict_hooks', OrderedDict()) + super().__setattr__('_state_dict_pre_hooks', OrderedDict()) + super().__setattr__('_load_state_dict_pre_hooks', OrderedDict()) + super().__setattr__('_load_state_dict_post_hooks', OrderedDict()) + super().__setattr__('_modules', OrderedDict()) + + def forward(self, *input, **kwargs): + """Defines the computation performed at every call. + + Should be overriden by all subclasses. + + .. note:: + Although the recipe for forward pass needs to be defined within + this function, one should call the :class:`Module` instance afterwards + instead of this since the former takes care of running the + registered hooks while the latter silently ignores them. + """ + raise NotImplementedError + + def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool = True) -> None: + r"""Add a buffer to the module. + + This is typically used to register a buffer that should not to be + considered a model parameter. For example, BatchNorm's ``running_mean`` + is not a parameter, but is part of the module's state. Buffers, by + default, are persistent and will be saved alongside parameters. This + behavior can be changed by setting :attr:`persistent` to ``False``. The + only difference between a persistent buffer and a non-persistent buffer + is that the latter will not be a part of this module's + :attr:`state_dict`. + + Buffers can be accessed as attributes using given names. + + Args: + name (str): name of the buffer. The buffer can be accessed + from this module using the given name + tensor (Tensor or None): buffer to be registered. If ``None``, then operations + that run on buffers, such as :attr:`cuda`, are ignored. If ``None``, + the buffer is **not** included in the module's :attr:`state_dict`. + persistent (bool): whether the buffer is part of this module's + :attr:`state_dict`. + + Example:: + + >>> # xdoctest: +SKIP("undefined vars") + >>> self.register_buffer('running_mean', ops.zeros(num_features)) + + """ + if '_buffers' not in self.__dict__: + raise AttributeError( + "cannot assign buffer before Module.__init__() call") + elif not isinstance(name, str): + raise TypeError(f"buffer name should be a string. Got {type(name)}") + elif '.' in name: + raise KeyError("buffer name can't contain \".\"") + elif name == '': + raise KeyError("buffer name can't be empty string \"\"") + elif hasattr(self, name) and name not in self._buffers: + raise KeyError(f"attribute '{name}' already exists") + elif tensor is not None and not isinstance(tensor, core.Tensor): + raise TypeError(f"cannot assign '{type(tensor)}' object to buffer '{name}' " + "(torch Tensor or None required)" + ) + else: + for hook in _global_buffer_registration_hooks.values(): + output = hook(self, name, tensor) + if output is not None: + tensor = output + if isinstance(tensor, StubTensor): + tensor = mindspore.Tensor(tensor.stub_sync()) + self._buffers[name] = tensor + if persistent: + self._non_persistent_buffers_set.discard(name) + else: + self._non_persistent_buffers_set.add(name) + + def register_parameter(self, name: str, param: Optional[Parameter]) -> None: + r"""Add a parameter to the module. + + The parameter can be accessed as an attribute using given name. + + Args: + name (str): name of the parameter. The parameter can be accessed + from this module using the given name + param (Parameter or None): parameter to be added to the module. If + ``None``, then operations that run on parameters, such as :attr:`cuda`, + are ignored. If ``None``, the parameter is **not** included in the + module's :attr:`state_dict`. + """ + if '_parameters' not in self.__dict__: + raise AttributeError( + "cannot assign parameter before Module.__init__() call") + + elif not isinstance(name, str): + raise TypeError(f"parameter name should be a string. Got {type(name)}") + elif '.' in name: + raise KeyError("parameter name can't contain \".\"") + elif name == '': + raise KeyError("parameter name can't be empty string \"\"") + elif hasattr(self, name) and name not in self._parameters: + raise KeyError(f"attribute '{name}' already exists") + + if param is None: + self._parameters[name] = None + elif not isinstance(param, Parameter): + raise TypeError(f"cannot assign '{type(param)}' object to parameter '{name}' " + "(nn.Parameter or None required)" + ) + else: + for hook in _global_parameter_registration_hooks.values(): + output = hook(self, name, param) + if output is not None: + param = output + self._parameters[name] = param + + def add_module(self, name, module): + """Adds a child module to the current module. + + The module can be accessed as an attribute using the given name. + + Args: + name (string): name of the child module. The child module can be + accessed from this module using the given name + parameter (Module): child module to be added to the module. + """ + if not isinstance(module, Module) and module is not None: + raise TypeError("{} is not a Module subclass".format(type(module))) + if hasattr(self, name) and name not in self._modules: + raise KeyError("attribute '{}' already exists".format(name)) + self._modules[name] = module + + def get_parameter(self, target: str) -> "Parameter": + """Return the parameter given by ``target`` if it exists, otherwise throw an error. + + See the docstring for ``get_submodule`` for a more detailed + explanation of this method's functionality as well as how to + correctly specify ``target``. + + Args: + target: The fully-qualified string name of the Parameter + to look for. (See ``get_submodule`` for how to specify a + fully-qualified string.) + + Returns: + torch.nn.Parameter: The Parameter referenced by ``target`` + + Raises: + AttributeError: If the target string references an invalid + path or resolves to something that is not an + ``nn.Parameter`` + """ + module_path, _, param_name = target.rpartition(".") + + mod: core.nn.Module = self.get_submodule(module_path) + + if not hasattr(mod, param_name): + raise AttributeError( + mod._get_name() + " has no attribute `" + param_name + "`" + ) + + param: core.nn.Parameter = getattr(mod, param_name) + + if not isinstance(param, core.nn.Parameter): + raise AttributeError("`" + param_name + "` is not an nn.Parameter") + + return param + + def get_buffer(self, target: str) -> "Tensor": + """Return the buffer given by ``target`` if it exists, otherwise throw an error. + + See the docstring for ``get_submodule`` for a more detailed + explanation of this method's functionality as well as how to + correctly specify ``target``. + + Args: + target: The fully-qualified string name of the buffer + to look for. (See ``get_submodule`` for how to specify a + fully-qualified string.) + + Returns: + torch.Tensor: The buffer referenced by ``target`` + + Raises: + AttributeError: If the target string references an invalid + path or resolves to something that is not a + buffer + """ + module_path, _, buffer_name = target.rpartition(".") + + mod: core.nn.Module = self.get_submodule(module_path) + + if not hasattr(mod, buffer_name): + raise AttributeError( + mod._get_name() + " has no attribute `" + buffer_name + "`" + ) + + buffer: core.Tensor = getattr(mod, buffer_name) + + if buffer_name not in mod._buffers: + raise AttributeError("`" + buffer_name + "` is not a buffer") + + return buffer + + + def get_extra_state(self) -> Any: + """Return any extra state to include in the module's state_dict. + + Implement this and a corresponding :func:`set_extra_state` for your module + if you need to store extra state. This function is called when building the + module's `state_dict()`. + + Note that extra state should be picklable to ensure working serialization + of the state_dict. We only provide provide backwards compatibility guarantees + for serializing Tensors; other objects may break backwards compatibility if + their serialized pickled form changes. + + Returns: + object: Any extra state to store in the module's state_dict + """ + raise RuntimeError( + "Reached a code path in Module.get_extra_state() that should never be called. " + "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml " + "to report this bug.") + + + def set_extra_state(self, state: Any) -> None: + """Set extra state contained in the loaded `state_dict`. + + This function is called from :func:`load_state_dict` to handle any extra state + found within the `state_dict`. Implement this function and a corresponding + :func:`get_extra_state` for your module if you need to store extra state within its + `state_dict`. + + Args: + state (dict): Extra state from the `state_dict` + """ + raise RuntimeError( + "Reached a code path in Module.set_extra_state() that should never be called. " + "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml " + "to report this bug.") + + def _apply(self, fn, recurse=True): + if recurse: + for module in self.children(): + module._apply(fn) + + def compute_should_use_set_data(tensor, tensor_applied): + if core._has_compatible_shallow_copy_type(tensor, tensor_applied): + # If the new tensor has compatible tensor type as the existing tensor, + # the current behavior is to change the tensor in-place using `.data =`, + # and the future behavior is to overwrite the existing tensor. However, + # changing the current behavior is a BC-breaking change, and we want it + # to happen in future releases. So for now we introduce the + # `core.__future__.get_overwrite_module_params_on_conversion()` + # global flag to let the user control whether they want the future + # behavior of overwriting the existing tensor or not. + return True + else: + return False + + should_use_swap_tensors = False + + for key, param in self._parameters.items(): + if param is None: + continue + # Tensors stored in modules are graph leaves, and we don't want to + # track autograd history of `param_applied`, so we have to use + # `with core.no_grad():` + with core.no_grad(): + param_applied = fn(param) + p_should_use_set_data = compute_should_use_set_data(param, param_applied) + + # subclasses may have multiple child tensors so we need to use swap_tensors + # p_should_use_swap_tensors = ( + # should_use_swap_tensors or is_traceable_wrapper_subclass(param_applied) + # ) + p_should_use_swap_tensors = False + + param_grad = param.grad + if p_should_use_swap_tensors: + try: + if param_grad is not None: + # Accessing param.grad makes its at::Tensor's use_count 2, which will prevent swapping. + # Decrement use count of the gradient by setting to None + param.grad = None + param_applied = core.nn.Parameter( + param_applied, requires_grad=param.requires_grad + ) + core.utils.swap_tensors(param, param_applied) + except Exception as e: + if param_grad is not None: + param.grad = param_grad + raise RuntimeError( + f"_apply(): Couldn't swap {self._get_name()}.{key}" + ) from e + out_param = param + elif p_should_use_set_data: + param.data = param_applied + out_param = param + else: + assert isinstance(param, Parameter) + assert param.is_leaf + out_param = Parameter(param_applied, param.requires_grad) + self._parameters[key] = out_param + + if param_grad is not None: + with core.no_grad(): + grad_applied = fn(param_grad) + g_should_use_set_data = compute_should_use_set_data( + param_grad, grad_applied + ) + if p_should_use_swap_tensors: + grad_applied.requires_grad_(param_grad.requires_grad) + try: + core.utils.swap_tensors(param_grad, grad_applied) + except Exception as e: + raise RuntimeError( + f"_apply(): Couldn't swap {self._get_name()}.{key}.grad" + ) from e + out_param.grad = param_grad + elif g_should_use_set_data: + assert out_param.grad is not None + out_param.grad.data = grad_applied + else: + assert param_grad.is_leaf + out_param.grad = grad_applied.requires_grad_( + param_grad.requires_grad + ) + + for key, buf in self._buffers.items(): + if buf is not None: + self._buffers[key] = fn(buf) + + return self + + def apply(self, fn): + """Applies ``fn`` recursively to every submodule (as returned by ``.children()``) + as well as self. Typical use includes initializing the parameters of a model + (see also :ref:`torch-nn-init`). + + Args: + fn (:class:`Module` -> None): function to be applied to each submodule + + Returns: + Module: self + + Example: + >>> def init_weights(m): + >>> print(m) + >>> if type(m) == nn.Linear: + >>> m.weight.data.fill_(1.0) + >>> print(m.weight) + >>> + >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) + >>> net.apply(init_weights) + Linear (2 -> 2) + Parameter containing: + 1 1 + 1 1 + [core.Tensor of size 2x2] + Linear (2 -> 2) + Parameter containing: + 1 1 + 1 1 + [core.Tensor of size 2x2] + Sequential ( + (0): Linear (2 -> 2) + (1): Linear (2 -> 2) + ) + """ + for module in self.children(): + module.apply(fn) + fn(self) + return self + + def _wrapped_call_impl(self, *args, **kwargs): + if self._compiled_call_impl is not None: + return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] + return self._call_impl(*args, **kwargs) + + # torchrec tests the code consistency with the following code + # fmt: off + def _call_impl(self, *args, **kwargs): + forward_call = self.forward + # If we don't have any hooks, we want to skip the rest of the logic in + # this function, and just call forward. + if self.__ms_class__: + return forward_call(*args, **kwargs) + + if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks + or _global_backward_pre_hooks or _global_backward_hooks + or _global_forward_hooks or _global_forward_pre_hooks): + return forward_call(*args, **kwargs) + + try: + result = None + called_always_called_hooks = set() + + full_backward_hooks, non_full_backward_hooks = [], [] + backward_pre_hooks = [] + if self._backward_pre_hooks or _global_backward_pre_hooks: + backward_pre_hooks = self._get_backward_pre_hooks() + + if self._backward_hooks or _global_backward_hooks: + full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks() + + if _global_forward_pre_hooks or self._forward_pre_hooks: + for hook_id, hook in ( + *_global_forward_pre_hooks.items(), + *self._forward_pre_hooks.items(), + ): + if hook_id in self._forward_pre_hooks_with_kwargs: + args_kwargs_result = hook(self, args, kwargs) # type: ignore[misc] + if args_kwargs_result is not None: + if isinstance(args_kwargs_result, tuple) and len(args_kwargs_result) == 2: + args, kwargs = args_kwargs_result + else: + raise RuntimeError( + "forward pre-hook must return None or a tuple " + f"of (new_args, new_kwargs), but got {args_kwargs_result}." + ) + else: + args_result = hook(self, args) + if args_result is not None: + if not isinstance(args_result, tuple): + args_result = (args_result,) + args = args_result + + bw_hook = None + # if full_backward_hooks or backward_pre_hooks: + # bw_hook = BackwardHook(self, full_backward_hooks, backward_pre_hooks) + # args = bw_hook.setup_input_hook(args) + + result = forward_call(*args, **kwargs) + if _global_forward_hooks or self._forward_hooks: + for hook_id, hook in ( + *_global_forward_hooks.items(), + *self._forward_hooks.items(), + ): + # mark that always called hook is run + if hook_id in self._forward_hooks_always_called or hook_id in _global_forward_hooks_always_called: + called_always_called_hooks.add(hook_id) + + if hook_id in self._forward_hooks_with_kwargs: + hook_result = hook(self, args, kwargs, result) + else: + hook_result = hook(self, args, result) + + if hook_result is not None: + result = hook_result + + if bw_hook: + if not isinstance(result, (core.Tensor, tuple)): + warnings.warn("For backward hooks to be called," + " module output should be a Tensor or a tuple of Tensors" + f" but received {type(result)}") + result = bw_hook.setup_output_hook(result) + + # Handle the non-full backward hooks + if non_full_backward_hooks: + var = result + while not isinstance(var, core.Tensor): + if isinstance(var, dict): + var = next(v for v in var.values() if isinstance(v, core.Tensor)) + else: + var = var[0] + # grad_fn = var.grad_fn + # if grad_fn is not None: + # for hook in non_full_backward_hooks: + # grad_fn.register_hook(_WrappedHook(hook, self)) + # self._maybe_warn_non_full_backward_hook(args, result, grad_fn) + + return result + + except Exception: + # run always called hooks if they have not already been run + # For now only forward hooks have the always_call option but perhaps + # this functionality should be added to full backward hooks as well. + for hook_id, hook in _global_forward_hooks.items(): + if hook_id in _global_forward_hooks_always_called and hook_id not in called_always_called_hooks: # type: ignore[possibly-undefined] + try: + hook_result = hook(self, args, result) # type: ignore[possibly-undefined] + if hook_result is not None: + result = hook_result + except Exception as e: + warnings.warn("global module forward hook with ``always_call=True`` raised an exception " + f"that was silenced as another error was raised in forward: {str(e)}") + continue + + for hook_id, hook in self._forward_hooks.items(): + if hook_id in self._forward_hooks_always_called and hook_id not in called_always_called_hooks: # type: ignore[possibly-undefined] + try: + if hook_id in self._forward_hooks_with_kwargs: + hook_result = hook(self, args, kwargs, result) # type: ignore[possibly-undefined] + else: + hook_result = hook(self, args, result) # type: ignore[possibly-undefined] + if hook_result is not None: + result = hook_result + except Exception as e: + warnings.warn("module forward hook with ``always_call=True`` raised an exception " + f"that was silenced as another error was raised in forward: {str(e)}") + continue + # raise exception raised in try block + raise + # fmt: on + + __call__: Callable[..., Any] = _wrapped_call_impl + + def __getstate__(self): + state = self.__dict__.copy() + state.pop("_compiled_call_impl", None) + return state + + def __setstate__(self, state): + self.__dict__.update(state) + + # Support loading old checkpoints that don't have the following attrs: + if "_forward_pre_hooks" not in self.__dict__: + self._forward_pre_hooks = OrderedDict() + if "_forward_pre_hooks_with_kwargs" not in self.__dict__: + self._forward_pre_hooks_with_kwargs = OrderedDict() + if "_forward_hooks_with_kwargs" not in self.__dict__: + self._forward_hooks_with_kwargs = OrderedDict() + if "_forward_hooks_always_called" not in self.__dict__: + self._forward_hooks_always_called = OrderedDict() + if "_state_dict_hooks" not in self.__dict__: + self._state_dict_hooks = OrderedDict() + if "_state_dict_pre_hooks" not in self.__dict__: + self._state_dict_pre_hooks = OrderedDict() + if "_load_state_dict_pre_hooks" not in self.__dict__: + self._load_state_dict_pre_hooks = OrderedDict() + if "_load_state_dict_post_hooks" not in self.__dict__: + self._load_state_dict_post_hooks = OrderedDict() + if "_non_persistent_buffers_set" not in self.__dict__: + self._non_persistent_buffers_set = set() + if "_is_full_backward_hook" not in self.__dict__: + self._is_full_backward_hook = None + if "_backward_pre_hooks" not in self.__dict__: + self._backward_pre_hooks = OrderedDict() + + def __getattr__(self, name): + if '_parameters' in self.__dict__: + _parameters = self.__dict__['_parameters'] + if name in _parameters: + return _parameters[name] + if '_buffers' in self.__dict__: + _buffers = self.__dict__['_buffers'] + if name in _buffers: + return _buffers[name] + if '_modules' in self.__dict__: + modules = self.__dict__['_modules'] + if name in modules: + return modules[name] + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None: + def remove_from(*dicts_or_sets): + for d in dicts_or_sets: + if name in d: + if isinstance(d, dict): + del d[name] + else: + d.discard(name) + + params = self.__dict__.get('_parameters') + + if isinstance(value, Parameter): + if params is None: + raise AttributeError( + "cannot assign parameters before Module.__init__() call") + remove_from(self.__dict__, self._buffers, self._modules, self._non_persistent_buffers_set) + self.register_parameter(name, value) + elif params is not None and name in params: + if value is not None: + raise TypeError(f"cannot assign '{type(value)}' as parameter '{name}' " + "(core.nn.Parameter or None expected)" + ) + self.register_parameter(name, value) + else: + modules = self.__dict__.get('_modules') + if isinstance(value, Module): + if modules is None: + raise AttributeError( + "cannot assign module before Module.__init__() call") + remove_from(self.__dict__, self._parameters, self._buffers, self._non_persistent_buffers_set) + for hook in _global_module_registration_hooks.values(): + output = hook(self, name, value) + if output is not None: + value = output + modules[name] = value + elif modules is not None and name in modules: + if value is not None: + raise TypeError(f"cannot assign '{type(value)}' as child module '{name}' " + "(nn.Module or None expected)" + ) + for hook in _global_module_registration_hooks.values(): + output = hook(self, name, value) + if output is not None: + value = output + modules[name] = value + else: + buffers = self.__dict__.get('_buffers') + if buffers is not None and name in buffers: + if value is not None and not isinstance(value, Tensor): + raise TypeError(f"cannot assign '{type(value)}' as buffer '{name}' " + "(core.Tensor or None expected)" + ) + for hook in _global_buffer_registration_hooks.values(): + output = hook(self, name, value) + if output is not None: + value = output + buffers[name] = value + else: + super().__setattr__(name, value) + + def __delattr__(self, name): + if name in self._parameters: + del self._parameters[name] + elif name in self._buffers: + del self._buffers[name] + self._non_persistent_buffers_set.discard(name) + elif name in self._modules: + del self._modules[name] + else: + super().__delattr__(name) + + + def extra_repr(self) -> str: + r"""Set the extra representation of the module. + + To print customized extra information, you should re-implement + this method in your own modules. Both single-line and multi-line + strings are acceptable. + """ + return '' + + + def __repr__(self): + # We treat the extra repr like the sub-module, one item per line + extra_lines = [] + extra_repr = self.extra_repr() + # empty string will be split into list [''] + if extra_repr: + extra_lines = extra_repr.split('\n') + child_lines = [] + for key, module in self._modules.items(): + mod_str = repr(module) + mod_str = _addindent(mod_str, 2) + child_lines.append('(' + key + '): ' + mod_str) + lines = extra_lines + child_lines + + main_str = self._get_name() + '(' + if lines: + # simple one-liner info, which most builtin Modules will use + if len(extra_lines) == 1 and not child_lines: + main_str += extra_lines[0] + else: + main_str += '\n ' + '\n '.join(lines) + '\n' + + main_str += ')' + return main_str + + def __dir__(self): + module_attrs = dir(self.__class__) + attrs = list(self.__dict__.keys()) + parameters = list(self._parameters.keys()) + modules = list(self._modules.keys()) + buffers = list(self._buffers.keys()) + keys = module_attrs + attrs + parameters + modules + buffers + + # Eliminate attrs that are not legal Python variable names + keys = [key for key in keys if not key[0].isdigit()] + + return sorted(keys) + + def cuda(self: T, device: Optional[Union[int, device]] = None) -> T: + r"""Move all model parameters and buffers to the GPU. + + This also makes associated parameters and buffers different objects. So + it should be called before constructing optimizer if the module will + live on GPU while being optimized. + + .. note:: + This method modifies the module in-place. + + Args: + device (int, optional): if specified, all parameters will be + copied to that device + + Returns: + Module: self + """ + return self._apply(lambda t: t.cuda(device)) + + def npu(self: T, device: Optional[Union[int, device]] = None) -> T: + return self._apply(lambda t: t.npu(device)) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + r"""Copy parameters and buffers from :attr:`state_dict` into only this module, but not its descendants. + + This is called on every submodule + in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this + module in input :attr:`state_dict` is provided as :attr:`local_metadata`. + For state dicts without metadata, :attr:`local_metadata` is empty. + Subclasses can achieve class-specific backward compatible loading using + the version number at `local_metadata.get("version", None)`. + Additionally, :attr:`local_metadata` can also contain the key + `assign_to_params_buffers` that indicates whether keys should be + assigned their corresponding tensor in the state_dict. + + .. note:: + :attr:`state_dict` is not the same object as the input + :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So + it can be modified. + + Args: + state_dict (dict): a dict containing parameters and + persistent buffers. + prefix (str): the prefix for parameters and buffers used in this + module + local_metadata (dict): a dict containing the metadata for this module. + See + strict (bool): whether to strictly enforce that the keys in + :attr:`state_dict` with :attr:`prefix` match the names of + parameters and buffers in this module + missing_keys (list of str): if ``strict=True``, add missing keys to + this list + unexpected_keys (list of str): if ``strict=True``, add unexpected + keys to this list + error_msgs (list of str): error messages should be added to this + list, and will be reported together in + :meth:`~torch.nn.Module.load_state_dict` + """ + for hook in self._load_state_dict_pre_hooks.values(): + hook( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + persistent_buffers = { + k: v + for k, v in self._buffers.items() + if k not in self._non_persistent_buffers_set + } + local_name_params = itertools.chain( + self._parameters.items(), persistent_buffers.items() + ) + local_state = {k: v for k, v in local_name_params if v is not None} + assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False) + use_swap_tensors = core.__future__.get_swap_module_params_on_conversion() + + for name, param in local_state.items(): + key = prefix + name + if key in state_dict: + input_param = state_dict[key] + if not core.overrides.is_tensor_like(input_param): + error_msgs.append( + f'While copying the parameter named "{key}", ' + "expected torch.Tensor or Tensor-like object from checkpoint but " + f"received {type(input_param)}" + ) + continue + + # This is used to avoid copying uninitialized parameters into + # non-lazy modules, since they dont have the hook to do the checks + # in such case, it will error when accessing the .shape attribute. + is_param_lazy = core.nn.parameter.is_lazy(param) + # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ + if ( + not is_param_lazy + and len(param.shape) == 0 + and len(input_param.shape) == 1 + ): + input_param = input_param[0] + + if not is_param_lazy and input_param.shape != param.shape: + # local shape should match the one in checkpoint + error_msgs.append( + f"size mismatch for {key}: copying a param with shape {input_param.shape} from checkpoint, " + f"the shape in current model is {param.shape}." + ) + continue + + if ( + param.is_meta + and not input_param.is_meta + and not assign_to_params_buffers + ): + warnings.warn( + f"for {key}: copying from a non-meta parameter in the checkpoint to a meta " + "parameter in the current model, which is a no-op. (Did you mean to " + "pass `assign=True` to assign items in the state dictionary to their " + "corresponding key in the module instead of copying them in place?)" + ) + + try: + with core.no_grad(): + if use_swap_tensors: + new_input_param = param.module_load( + input_param, assign=assign_to_params_buffers + ) + if id(new_input_param) == id(input_param) or id( + new_input_param + ) == id(param): + raise RuntimeError( + "module_load returned one of self or other, please .detach() " + "the result if returning one of the inputs in module_load" + ) + if isinstance(param, core.nn.Parameter): + if not isinstance(new_input_param, core.nn.Parameter): + new_input_param = core.nn.Parameter( + new_input_param, + requires_grad=param.requires_grad, + ) + else: + new_input_param.requires_grad_(param.requires_grad) + core.utils.swap_tensors(param, new_input_param) + del new_input_param + elif assign_to_params_buffers: + # Shape checks are already done above + if isinstance(param, core.nn.Parameter): + if not isinstance(input_param, core.nn.Parameter): + input_param = core.nn.Parameter( + input_param, requires_grad=param.requires_grad + ) + else: + input_param.requires_grad_(param.requires_grad) + setattr(self, name, input_param) + else: + param.copy_(input_param) + except Exception as ex: + action = "swapping" if use_swap_tensors else "copying" + error_msgs.append( + f'While {action} the parameter named "{key}", ' + f"whose dimensions in the model are {param.size()} and " + f"whose dimensions in the checkpoint are {input_param.size()}, " + f"an exception occurred : {ex.args}." + ) + elif strict: + missing_keys.append(key) + + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if ( + getattr(self.__class__, "set_extra_state", Module.set_extra_state) + is not Module.set_extra_state + ): + if extra_state_key in state_dict: + self.set_extra_state(state_dict[extra_state_key]) + elif strict: + missing_keys.append(extra_state_key) + elif strict and (extra_state_key in state_dict): + unexpected_keys.append(extra_state_key) + + if strict: + for key in state_dict.keys(): + if key.startswith(prefix) and key != extra_state_key: + input_name = key[len(prefix) :].split(".", 1) + # Must be Module if it have attributes + if len(input_name) > 1: + if input_name[0] not in self._modules: + unexpected_keys.append(key) + elif input_name[0] not in local_state: + unexpected_keys.append(key) + + def load_state_dict(self, state_dict: Mapping[str, Any], + strict: bool = True, assign: bool = False): + r"""Copy parameters and buffers from :attr:`state_dict` into this module and its descendants. + + If :attr:`strict` is ``True``, then + the keys of :attr:`state_dict` must exactly match the keys returned + by this module's :meth:`~nn.Module.state_dict` function. + + Args: + state_dict (dict): a dict containing parameters and + persistent buffers. + strict (bool, optional): whether to strictly enforce that the keys + in :attr:`state_dict` match the keys returned by this module's + :meth:`~nn.Module.state_dict` function. Default: ``True`` + assign (bool, optional): When ``False``, the properties of the tensors + in the current module are preserved while when ``True``, the + properties of the Tensors in the state dict are preserved. The only + exception is the ``requires_grad`` field of :class:`~nn.Parameter`s + for which the value from the module is preserved. + Default: ``False`` + + Returns: + ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: + * **missing_keys** is a list of str containing the missing keys + * **unexpected_keys** is a list of str containing the unexpected keys + + Note: + If a parameter or buffer is registered as ``None`` and its corresponding key + exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a + ``RuntimeError``. + """ + if not isinstance(state_dict, Mapping): + raise TypeError(f"Expected state_dict to be dict-like, got {type(state_dict)}.") + + missing_keys: List[str] = [] + unexpected_keys: List[str] = [] + error_msgs: List[str] = [] + + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, '_metadata', None) + state_dict = OrderedDict(state_dict) + + if metadata is not None: + # mypy isn't aware that "_metadata" exists in state_dict + state_dict._metadata = metadata # type: ignore[attr-defined] + + def load(module, local_state_dict, prefix=''): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + if assign: + local_metadata['assign_to_params_buffers'] = assign + module._load_from_state_dict( + local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) + for name, child in module._modules.items(): + if child is not None: + child_prefix = prefix + name + '.' + child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)} + load(child, child_state_dict, child_prefix) # noqa: F821 + + # Note that the hook can modify missing_keys and unexpected_keys. + incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys) + for hook in module._load_state_dict_post_hooks.values(): + out = hook(module, incompatible_keys) + assert out is None, ( + "Hooks registered with ``register_load_state_dict_post_hook`` are not" + "expected to return new values, if incompatible_keys need to be modified," + "it should be done inplace." + ) + + load(self, state_dict) + del load + + if strict: + if len(unexpected_keys) > 0: + error_msgs.insert( + 0, 'Unexpected key(s) in state_dict: {}. '.format( + ', '.join(f'"{k}"' for k in unexpected_keys))) + if len(missing_keys) > 0: + error_msgs.insert( + 0, 'Missing key(s) in state_dict: {}. '.format( + ', '.join(f'"{k}"' for k in missing_keys))) + + if len(error_msgs) > 0: + raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( + self.__class__.__name__, "\n\t".join(error_msgs))) + return _IncompatibleKeys(missing_keys, unexpected_keys) + + + def _named_members(self, get_members_fn, prefix='', recurse=True, remove_duplicate: bool = True): + r"""Help yield various names + members of modules.""" + memo = set() + modules = self.named_modules(prefix=prefix, remove_duplicate=remove_duplicate) if recurse else [(prefix, self)] + for module_prefix, module in modules: + members = get_members_fn(module) + for k, v in members: + if v is None or v in memo: + continue + if remove_duplicate: + memo.add(v) + name = module_prefix + ('.' if module_prefix else '') + k + yield name, v + + def parameters(self, recurse: bool = True) -> Iterator[Parameter]: + r"""Return an iterator over module parameters. + + This is typically passed to an optimizer. + + Args: + recurse (bool): if True, then yields parameters of this module + and all submodules. Otherwise, yields only parameters that + are direct members of this module. + + Yields: + Parameter: module parameter + + Example:: + + >>> # xdoctest: +SKIP("undefined vars") + >>> for param in model.parameters(): + >>> print(type(param), param.shape) + (20L,) + (20L, 1L, 5L, 5L) + + """ + for name, param in self.named_parameters(recurse=recurse): + yield param + + def trainable_params(self, recurse: bool = True): + params = tuple() + for name, param in self.named_parameters(recurse=recurse): + if param.requires_grad: + params += (param,) + return params + + def get_submodule(self, target: str) -> "Module": + """Return the submodule given by ``target`` if it exists, otherwise throw an error. + + For example, let's say you have an ``nn.Module`` ``A`` that + looks like this: + + .. code-block:: text + + A( + (net_b): Module( + (net_c): Module( + (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) + ) + (linear): Linear(in_features=100, out_features=200, bias=True) + ) + ) + + (The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested + submodule ``net_b``, which itself has two submodules ``net_c`` + and ``linear``. ``net_c`` then has a submodule ``conv``.) + + To check whether or not we have the ``linear`` submodule, we + would call ``get_submodule("net_b.linear")``. To check whether + we have the ``conv`` submodule, we would call + ``get_submodule("net_b.net_c.conv")``. + + The runtime of ``get_submodule`` is bounded by the degree + of module nesting in ``target``. A query against + ``named_modules`` achieves the same result, but it is O(N) in + the number of transitive modules. So, for a simple check to see + if some submodule exists, ``get_submodule`` should always be + used. + + Args: + target: The fully-qualified string name of the submodule + to look for. (See above example for how to specify a + fully-qualified string.) + + Returns: + nn.Module: The submodule referenced by ``target`` + + Raises: + AttributeError: If the target string references an invalid + path or resolves to something that is not an + ``nn.Module`` + """ + if target == "": + return self + + atoms: List[str] = target.split(".") + mod: Module = self + + for item in atoms: + + if not hasattr(mod, item): + raise AttributeError(mod._get_name() + " has no " + "attribute `" + item + "`") + + mod = getattr(mod, item) + + if not isinstance(mod, Module): + raise AttributeError("`" + item + "` is not " + "an nn.Module") + + return mod + + def get_parameters(self, expand=True): + return self.parameters(expand) + + def named_parameters( + self, + prefix: str = '', + recurse: bool = True, + remove_duplicate: bool = True + ) -> Iterator[Tuple[str, Parameter]]: + r"""Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. + + Args: + prefix (str): prefix to prepend to all parameter names. + recurse (bool): if True, then yields parameters of this module + and all submodules. Otherwise, yields only parameters that + are direct members of this module. + remove_duplicate (bool, optional): whether to remove the duplicated + parameters in the result. Defaults to True. + + Yields: + (str, Parameter): Tuple containing the name and parameter + + Example:: + + >>> # xdoctest: +SKIP("undefined vars") + >>> for name, param in self.named_parameters(): + >>> if name in ['bias']: + >>> print(param.shape) + + """ + gen = self._named_members( + lambda module: module._parameters.items(), + prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate) + yield from gen + + def parameters_and_names(self, name_prefix='', expand=True): + return self.named_parameters(name_prefix, expand) + + def buffers(self, recurse: bool = True) -> Iterator[Tensor]: + r"""Return an iterator over module buffers. + + Args: + recurse (bool): if True, then yields buffers of this module + and all submodules. Otherwise, yields only buffers that + are direct members of this module. + + Yields: + core.Tensor: module buffer + + Example:: + + >>> # xdoctest: +SKIP("undefined vars") + >>> for buf in model.buffers(): + >>> print(type(buf), buf.shape) + (20L,) + (20L, 1L, 5L, 5L) + + """ + for _, buf in self.named_buffers(recurse=recurse): + yield buf + + + def named_buffers(self, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) -> Iterator[Tuple[str, Tensor]]: + r"""Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself. + + Args: + prefix (str): prefix to prepend to all buffer names. + recurse (bool, optional): if True, then yields buffers of this module + and all submodules. Otherwise, yields only buffers that + are direct members of this module. Defaults to True. + remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True. + + Yields: + (str, core.Tensor): Tuple containing the name and buffer + + Example:: + + >>> # xdoctest: +SKIP("undefined vars") + >>> for name, buf in self.named_buffers(): + >>> if name in ['running_var']: + >>> print(buf.shape) + + """ + gen = self._named_members( + lambda module: module._buffers.items(), + prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate) + yield from gen + + def _all_buffers(self, memo=None): + if memo is None: + memo = set() + for name, b in self._buffers.items(): + if b is not None and b not in memo: + memo.add(b) + yield b + for module in self.children(): + for b in module._all_buffers(memo): + yield b + + def children(self): + """Returns an iterator over immediate children modules. + + Yields: + Module: a child module + """ + for name, module in self.named_children(): + yield module + + def named_children(self): + """Returns an iterator over immediate children modules, yielding both + the name of the module as well as the module itself. + + Yields: + (string, Module): Tuple containing a name and child module + + Example: + >>> for name, module in model.named_children(): + >>> if name in ['conv4', 'conv5']: + >>> print(module) + """ + memo = set() + for name, module in self._modules.items(): + if module is not None and module not in memo: + memo.add(module) + yield name, module + + def modules(self): + """Returns an iterator over all modules in the network. + + Yields: + Module: a module in the network + + Note: + Duplicate modules are returned only once. In the following + example, ``l`` will be returned only once. + + >>> l = nn.Linear(2, 2) + >>> net = nn.Sequential(l, l) + >>> for idx, m in enumerate(net.modules()): + >>> print(idx, '->', m) + 0 -> Sequential ( + (0): Linear (2 -> 2) + (1): Linear (2 -> 2) + ) + 1 -> Linear (2 -> 2) + """ + for name, module in self.named_modules(): + yield module + + def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '', remove_duplicate: bool = True): + r"""Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself. + + Args: + memo: a memo to store the set of modules already added to the result + prefix: a prefix that will be added to the name of the module + remove_duplicate: whether to remove the duplicated module instances in the result + or not + + Yields: + (str, Module): Tuple of name and module + + Note: + Duplicate modules are returned only once. In the following + example, ``l`` will be returned only once. + + Example:: + + >>> l = nn.Linear(2, 2) + >>> net = nn.Sequential(l, l) + >>> for idx, m in enumerate(net.named_modules()): + ... print(idx, '->', m) + + 0 -> ('', Sequential( + (0): Linear(in_features=2, out_features=2, bias=True) + (1): Linear(in_features=2, out_features=2, bias=True) + )) + 1 -> ('0', Linear(in_features=2, out_features=2, bias=True)) + + """ + if memo is None: + memo = set() + if self not in memo: + if remove_duplicate: + memo.add(self) + yield prefix, self + for name, module in self._modules.items(): + if module is None: + continue + submodule_prefix = prefix + ('.' if prefix else '') + name + yield from module.named_modules(memo, submodule_prefix, remove_duplicate) + + def cells_and_names(self, cells=None, name_prefix=''): + return self.named_modules(cells, name_prefix) + + def jit(self, mode=True): + self.__ms_class__ = mode + for module in self.children(): + module.jit(mode) + return self + + def compile(self, *args, **kwargs): + self.jit() + def forward_fn(*args, **kwargs): + return self.forward(*args, **kwargs) + + # forward_fn = mindspore.jit(forward_fn, *args, **kwargs) + self._compiled_call_impl = forward_fn + + @property + def skip_syntax(self): + return self.__ms_class__ + + def train(self, mode=True): + """Sets the module in training mode. + + This has any effect only on modules such as Dropout or BatchNorm. + + Returns: + Module: self + """ + self.training = mode + for module in self.children(): + module.train(mode) + return self + + set_train = train + + def eval(self): + """Sets the module in evaluation mode. + + This has any effect only on modules such as Dropout or BatchNorm. + """ + return self.train(False) + + def requires_grad_(self: T, requires_grad: bool = True) -> T: + r"""Change if autograd should record operations on parameters in this module. + + This method sets the parameters' :attr:`requires_grad` attributes + in-place. + + This method is helpful for freezing part of the module for finetuning + or training parts of a model individually (e.g., GAN training). + + See :ref:`locally-disable-grad-doc` for a comparison between + `.requires_grad_()` and several similar mechanisms that may be confused with it. + + Args: + requires_grad (bool): whether autograd should record operations on + parameters in this module. Default: ``True``. + + Returns: + Module: self + """ + for p in self.parameters(): + p.requires_grad = requires_grad + return self + + + def _get_name(self): + return self.__class__.__name__ + + def to(self, *args, **kwargs): + r"""Move and/or cast the parameters and buffers. + + This can be called as + + .. function:: to(device=None, dtype=None, non_blocking=False) + :noindex: + + .. function:: to(dtype, non_blocking=False) + :noindex: + + .. function:: to(tensor, non_blocking=False) + :noindex: + + .. function:: to(memory_format=core.channels_last) + :noindex: + + Its signature is similar to :meth:`core.Tensor.to`, but only accepts + floating point or complex :attr:`dtype`\ s. In addition, this method will + only cast the floating point or complex parameters and buffers to :attr:`dtype` + (if given). The integral parameters and buffers will be moved + :attr:`device`, if that is given, but with dtypes unchanged. When + :attr:`non_blocking` is set, it tries to convert/move asynchronously + with respect to the host if possible, e.g., moving CPU Tensors with + pinned memory to CUDA devices. + + See below for examples. + + .. note:: + This method modifies the module in-place. + + Args: + device (:class:`core.device`): the desired device of the parameters + and buffers in this module + dtype (:class:`core.dtype`): the desired floating point or complex dtype of + the parameters and buffers in this module + tensor (core.Tensor): Tensor whose dtype and device are the desired + dtype and device for all parameters and buffers in this module + memory_format (:class:`core.memory_format`): the desired memory + format for 4D parameters and buffers in this module (keyword + only argument) + + Returns: + Module: self + + Examples:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> linear = nn.Linear(2, 2) + >>> linear.weight + Parameter containing: + tensor([[ 0.1913, -0.3420], + [-0.5113, -0.2325]]) + >>> linear.to(core.double) + Linear(in_features=2, out_features=2, bias=True) + >>> linear.weight + Parameter containing: + tensor([[ 0.1913, -0.3420], + [-0.5113, -0.2325]], dtype=core.float64) + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1) + >>> gpu1 = core.device("cuda:1") + >>> linear.to(gpu1, dtype=core.half, non_blocking=True) + Linear(in_features=2, out_features=2, bias=True) + >>> linear.weight + Parameter containing: + tensor([[ 0.1914, -0.3420], + [-0.5112, -0.2324]], dtype=core.float16, device='cuda:1') + >>> cpu = core.device("cpu") + >>> linear.to(cpu) + Linear(in_features=2, out_features=2, bias=True) + >>> linear.weight + Parameter containing: + tensor([[ 0.1914, -0.3420], + [-0.5112, -0.2324]], dtype=core.float16) + + >>> linear = nn.Linear(2, 2, bias=None).to(core.cdouble) + >>> linear.weight + Parameter containing: + tensor([[ 0.3741+0.j, 0.2382+0.j], + [ 0.5593+0.j, -0.4443+0.j]], dtype=core.complex128) + >>> linear(core.ones(3, 2, dtype=core.cdouble)) + tensor([[0.6122+0.j, 0.1150+0.j], + [0.6122+0.j, 0.1150+0.j], + [0.6122+0.j, 0.1150+0.j]], dtype=core.complex128) + + """ + device, dtype, non_blocking, convert_to_format = core._C._nn._parse_to( + *args, **kwargs + ) + + if dtype is not None: + if not (dtype.is_floating_point or dtype.is_complex): + raise TypeError( + "nn.Module.to only accepts floating point or complex " + f"dtypes, but got desired dtype={dtype}" + ) + if dtype.is_complex: + warnings.warn( + "Complex modules are a new feature under active development whose design may change, " + "and some modules might not work as expected when using complex tensors as parameters or buffers. " + "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml " + "if a complex module does not work as expected." + ) + + def convert(t): + try: + if convert_to_format is not None and t.dim() in (4, 5): + return t.to( + device, + dtype if t.is_floating_point() or t.is_complex() else None, + non_blocking, + memory_format=convert_to_format, + ) + return t.to( + device, + dtype if t.is_floating_point() or t.is_complex() else None, + non_blocking=non_blocking, + ) + except NotImplementedError as e: + if str(e) == "Cannot copy out of meta tensor; no data!": + raise NotImplementedError( + f"{e} Please use core.nn.Module.to_empty() instead of core.nn.Module.to() " + f"when moving module from meta to a different device." + ) from None + else: + raise + + return self._apply(convert) + + def half(self: T) -> T: + r"""Casts all floating point parameters and buffers to ``half`` datatype. + + .. note:: + This method modifies the module in-place. + + Returns: + Module: self + """ + return self._apply(lambda t: t.half() if t.is_floating_point() else t) + + def bfloat16(self: T) -> T: + r"""Casts all floating point parameters and buffers to ``bfloat16`` datatype. + + .. note:: + This method modifies the module in-place. + + Returns: + Module: self + """ + return self._apply(lambda t: t.bfloat16() if t.is_floating_point() else t) + + def float(self: T) -> T: + r"""Casts all floating point parameters and buffers to ``float`` datatype. + + .. note:: + This method modifies the module in-place. + + Returns: + Module: self + """ + return self._apply(lambda t: t.float() if t.is_floating_point() else t) + + + def double(self: T) -> T: + r"""Casts all floating point parameters and buffers to ``double`` datatype. + + .. note:: + This method modifies the module in-place. + + Returns: + Module: self + """ + return self._apply(lambda t: t.double() if t.is_floating_point() else t) + + + def half(self: T) -> T: + r"""Casts all floating point parameters and buffers to ``half`` datatype. + + .. note:: + This method modifies the module in-place. + + Returns: + Module: self + """ + return self._apply(lambda t: t.half() if t.is_floating_point() else t) + + + def _save_to_state_dict(self, destination, prefix, keep_vars): + r"""Save module state to the `destination` dictionary. + + The `destination` dictionary will contain the state + of the module, but not its descendants. This is called on every + submodule in :meth:`~nn.Module.state_dict`. + + In rare cases, subclasses can achieve class-specific behavior by + overriding this method with custom logic. + + Args: + destination (dict): a dict where state will be stored + prefix (str): the prefix for parameters and buffers used in this + module + """ + for name, param in self._parameters.items(): + if param is not None: + destination[prefix + name] = param + for name, buf in self._buffers.items(): + if buf is not None and name not in self._non_persistent_buffers_set: + destination[prefix + name] = buf + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state: + destination[extra_state_key] = self.get_extra_state() + + # The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns + # back that same object. But if they pass nothing, an `OrderedDict` is created and returned. + T_destination = TypeVar('T_destination', bound=Dict[str, Any]) + + @overload + def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination: + ... + + @overload + def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]: + ... + + def state_dict(self, *args, destination=None, prefix='', keep_vars=False): + r"""Return a dictionary containing references to the whole state of the module. + + Both parameters and persistent buffers (e.g. running averages) are + included. Keys are corresponding parameter and buffer names. + Parameters and buffers set to ``None`` are not included. + + .. note:: + The returned object is a shallow copy. It contains references + to the module's parameters and buffers. + + .. warning:: + Currently ``state_dict()`` also accepts positional arguments for + ``destination``, ``prefix`` and ``keep_vars`` in order. However, + this is being deprecated and keyword arguments will be enforced in + future releases. + + .. warning:: + Please avoid the use of argument ``destination`` as it is not + designed for end-users. + + Args: + destination (dict, optional): If provided, the state of module will + be updated into the dict and the same object is returned. + Otherwise, an ``OrderedDict`` will be created and returned. + Default: ``None``. + prefix (str, optional): a prefix added to parameter and buffer + names to compose the keys in state_dict. Default: ``''``. + keep_vars (bool, optional): by default the :class:`~core.Tensor` s + returned in the state dict are detached from autograd. If it's + set to ``True``, detaching will not be performed. + Default: ``False``. + + Returns: + dict: + a dictionary containing a whole state of the module + + Example:: + + >>> # xdoctest: +SKIP("undefined vars") + >>> module.state_dict().keys() + ['bias', 'weight'] + + """ + # TODO: Remove `args` and the parsing logic when BC allows. + if len(args) > 0: + if destination is None: + destination = args[0] + if len(args) > 1 and prefix == '': + prefix = args[1] + if len(args) > 2 and keep_vars is False: + keep_vars = args[2] + # DeprecationWarning is ignored by default + warnings.warn( + "Positional args are being deprecated, use kwargs instead.") + + if destination is None: + destination = OrderedDict() + destination._metadata = OrderedDict() + + local_metadata = {} + if hasattr(destination, "_metadata"): + destination._metadata[prefix[:-1]] = local_metadata + + for hook in self._state_dict_pre_hooks.values(): + hook(self, prefix, keep_vars) + self._save_to_state_dict(destination, prefix, keep_vars) + for name, module in self._modules.items(): + if module is not None: + module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars) + for hook in self._state_dict_hooks.values(): + hook_result = hook(self, destination, prefix, local_metadata) + if hook_result is not None: + destination = hook_result + return destination + + def _register_load_state_dict_pre_hook(self, hook, with_module=False): + r"""Register a pre-hook for the :meth:`~nn.Module.load_state_dict` method. + + These hooks will be called with arguments: `state_dict`, `prefix`, + `local_metadata`, `strict`, `missing_keys`, `unexpected_keys`, + `error_msgs`, before loading `state_dict` into `self`. These arguments + are exactly the same as those of `_load_from_state_dict`. + + If ``with_module`` is ``True``, then the first argument to the hook is + an instance of the module. + + Arguments: + hook (Callable): Callable hook that will be invoked before + loading the state dict. + with_module (bool, optional): Whether or not to pass the module + instance to the hook as the first parameter. + """ + handle = hooks.RemovableHandle(self._load_state_dict_pre_hooks) + self._load_state_dict_pre_hooks[handle.id] = _WrappedHook(hook, self if with_module else None) + return handle + + def register_load_state_dict_post_hook(self, hook): + r"""Register a post hook to be run after module's ``load_state_dict`` is called. + + It should have the following signature:: + hook(module, incompatible_keys) -> None + + The ``module`` argument is the current module that this hook is registered + on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting + of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys`` + is a ``list`` of ``str`` containing the missing keys and + ``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys. + + The given incompatible_keys can be modified inplace if needed. + + Note that the checks performed when calling :func:`load_state_dict` with + ``strict=True`` are affected by modifications the hook makes to + ``missing_keys`` or ``unexpected_keys``, as expected. Additions to either + set of keys will result in an error being thrown when ``strict=True``, and + clearing out both missing and unexpected keys will avoid an error. + + Returns: + :class:`utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(self._load_state_dict_post_hooks) + self._load_state_dict_post_hooks[handle.id] = hook + return handle + + def parameters_dict(self, recurse=True): + param_dict = OrderedDict() + for name, param in self.named_parameters(recurse=recurse, remove_duplicate=False): + param_dict[name] = param + return param_dict + + def register_forward_pre_hook( + self, + hook: Union[ + Callable[[T, Tuple[Any, ...]], Optional[Any]], + Callable[[T, Tuple[Any, ...], Dict[str, Any]], Optional[Tuple[Any, Dict[str, Any]]]], + ], + *, + prepend: bool = False, + with_kwargs: bool = False, + ) -> RemovableHandle: + r"""Registers a forward pre-hook on the module. + + The hook will be called every time before :func:`forward` is invoked. + + + If ``with_kwargs`` is false or not specified, the input contains only + the positional arguments given to the module. Keyword arguments won't be + passed to the hooks and only to the ``forward``. The hook can modify the + input. User can either return a tuple or a single modified value in the + hook. We will wrap the value into a tuple if a single value is returned + (unless that value is already a tuple). The hook should have the + following signature:: + + hook(module, args) -> None or modified input + + If ``with_kwargs`` is true, the forward pre-hook will be passed the + kwargs given to the forward function. And if the hook modifies the + input, both the args and kwargs should be returned. The hook should have + the following signature:: + + hook(module, args, kwargs) -> None or a tuple of modified input and kwargs + + Args: + hook (Callable): The user defined hook to be registered. + prepend (bool): If true, the provided ``hook`` will be fired before + all existing ``forward_pre`` hooks on this + :class:`nn.modules.Module`. Otherwise, the provided + ``hook`` will be fired after all existing ``forward_pre`` hooks + on this :class:`nn.modules.Module`. Note that global + ``forward_pre`` hooks registered with + :func:`register_module_forward_pre_hook` will fire before all + hooks registered by this method. + Default: ``False`` + with_kwargs (bool): If true, the ``hook`` will be passed the kwargs + given to the forward function. + Default: ``False`` + + Returns: + :class:`utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle( + self._forward_pre_hooks, + extra_dict=self._forward_pre_hooks_with_kwargs + ) + self._forward_pre_hooks[handle.id] = hook + if with_kwargs: + self._forward_pre_hooks_with_kwargs[handle.id] = True + + if prepend: + self._forward_pre_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined] + return handle + + + def register_forward_hook( + self, + hook: Union[ + Callable[[T, Tuple[Any, ...], Any], Optional[Any]], + Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]], + ], + *, + prepend: bool = False, + with_kwargs: bool = False, + ) -> RemovableHandle: + r"""Registers a forward hook on the module. + + The hook will be called every time after :func:`forward` has computed an output. + + If ``with_kwargs`` is ``False`` or not specified, the input contains only + the positional arguments given to the module. Keyword arguments won't be + passed to the hooks and only to the ``forward``. The hook can modify the + output. It can modify the input inplace but it will not have effect on + forward since this is called after :func:`forward` is called. The hook + should have the following signature:: + + hook(module, args, output) -> None or modified output + + If ``with_kwargs`` is ``True``, the forward hook will be passed the + ``kwargs`` given to the forward function and be expected to return the + output possibly modified. The hook should have the following signature:: + + hook(module, args, kwargs, output) -> None or modified output + + Args: + hook (Callable): The user defined hook to be registered. + prepend (bool): If ``True``, the provided ``hook`` will be fired + before all existing ``forward`` hooks on this + :class:`nn.modules.Module`. Otherwise, the provided + ``hook`` will be fired after all existing ``forward`` hooks on + this :class:`nn.modules.Module`. Note that global + ``forward`` hooks registered with + :func:`register_module_forward_hook` will fire before all hooks + registered by this method. + Default: ``False`` + with_kwargs (bool): If ``True``, the ``hook`` will be passed the + kwargs given to the forward function. + Default: ``False`` + + Returns: + :class:`utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle( + self._forward_hooks, + extra_dict=self._forward_hooks_with_kwargs + ) + self._forward_hooks[handle.id] = hook + if with_kwargs: + self._forward_hooks_with_kwargs[handle.id] = True + + if prepend: + self._forward_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined] + return handle + + def zero_grad(self, set_to_none: bool = True) -> None: + r"""Reset gradients of all model parameters. + + See similar function under :class:`core.optim.Optimizer` for more context. + + Args: + set_to_none (bool): instead of setting to zero, set the grads to None. + See :meth:`core.optim.Optimizer.zero_grad` for details. + """ + if getattr(self, "_is_replica", False): + warnings.warn( + "Calling .zero_grad() from a module created with nn.DataParallel() has no effect. " + "The parameters are copied (in a differentiable manner) from the original module. " + "This means they are not leaf nodes in autograd and so don't accumulate gradients. " + "If you need gradients in your forward method, consider using autograd.grad instead." + ) + + for p in self.parameters(): + if p.grad is not None: + p.grad = None diff --git a/mindnlp/core/nn/modules/normalization.py b/mindnlp/core/nn/modules/normalization.py new file mode 100644 index 000000000..3fcc64418 --- /dev/null +++ b/mindnlp/core/nn/modules/normalization.py @@ -0,0 +1,168 @@ +"""normalization""" +import numbers +from ..parameter import Parameter +from .module import Module +from ..functional import group_norm, layer_norm +from .. import init +from ... import ops + + +class LayerNorm(Module): + r"""Applies Layer Normalization over a mini-batch of inputs as described in + the paper `Layer Normalization`_ . + + .. math:: + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + The mean and standard-deviation are calculated separately over the last + certain number dimensions which have to be of the shape specified by + :attr:`normalized_shape`. + :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of + :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``. + + .. note:: + Unlike Batch Normalization and Instance Normalization, which applies + scalar scale and bias for each entire channel/plane with the + :attr:`affine` option, Layer Normalization applies per-element scale and + bias with :attr:`elementwise_affine`. + + This layer uses statistics computed from input data in both training and + evaluation modes. + + Args: + normalized_shape (int or list or core.Size): input shape from an expected input + of size + + .. math:: + [* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] + \times \ldots \times \text{normalized_shape}[-1]] + + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps: a value added to the denominator for numerical stability. Default: 1e-5 + elementwise_affine: a boolean value that when set to ``True``, this module + has learnable per-element affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, *)` + - Output: :math:`(N, *)` (same shape as input) + + Examples:: + + >>> input = core.randn(20, 5, 10, 10) + >>> # With Learnable Parameters + >>> m = nn.LayerNorm(input.size()[1:]) + >>> # Without Learnable Parameters + >>> m = nn.LayerNorm(input.size()[1:], elementwise_affine=False) + >>> # Normalize over last two dimensions + >>> m = nn.LayerNorm([10, 10]) + >>> # Normalize over last dimension of size 10 + >>> m = nn.LayerNorm(10) + >>> # Activating the module + >>> output = m(input) + + .. _`Layer Normalization`: https://arxiv.org/abs/1607.06450 + """ + def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True, bias: bool = True,dtype=None): + factory_kwargs = {'dtype': dtype} + super(LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + self.normalized_shape = tuple(normalized_shape) + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = Parameter(ops.empty(self.normalized_shape, **factory_kwargs)) + if bias: + self.bias = Parameter(ops.empty(self.normalized_shape, **factory_kwargs)) + else: + self.register_parameter('bias', None) + else: + self.register_parameter('weight', None) + self.register_parameter('bias', None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.elementwise_affine: + init.ones_(self.weight) + if self.bias is not None: + init.zeros_(self.bias) + + def forward(self, input): + return layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps) + + def extra_repr(self): + return '{normalized_shape}, eps={eps}, ' \ + 'elementwise_affine={elementwise_affine}'.format(**self.__dict__) + + +class GroupNorm(Module): + r"""Applies Group Normalization over a mini-batch of inputs as described in + the paper `Group Normalization`_ . + + .. math:: + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + The input channels are separated into :attr:`num_groups` groups, each containing + ``num_channels / num_groups`` channels. The mean and standard-deviation are calculated + separately over the each group. :math:`\gamma` and :math:`\beta` are learnable + per-channel affine transform parameter vectorss of size :attr:`num_channels` if + :attr:`affine` is ``True``. + + This layer uses statistics computed from input data in both training and + evaluation modes. + + Args: + num_groups (int): number of groups to separate the channels into + num_channels (int): number of channels expected in input + eps: a value added to the denominator for numerical stability. Default: 1e-5 + affine: a boolean value that when set to ``True``, this module + has learnable per-channel affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, num\_channels, *)` + - Output: :math:`(N, num\_channels, *)` (same shape as input) + + Examples:: + + >>> input = core.randn(20, 6, 10, 10) + >>> # Separate 6 channels into 3 groups + >>> m = nn.GroupNorm(3, 6) + >>> # Separate 6 channels into 6 groups (equivalent with InstanceNorm) + >>> m = nn.GroupNorm(6, 6) + >>> # Put all 6 channels into a single group (equivalent with LayerNorm) + >>> m = nn.GroupNorm(1, 6) + >>> # Activating the module + >>> output = m(input) + + .. _`Group Normalization`: https://arxiv.org/abs/1803.08494 + """ + def __init__(self, num_groups, num_channels, eps=1e-5, affine=True, dtype=None): + factory_kwargs = {'dtype': dtype} + super(GroupNorm, self).__init__() + self.num_groups = num_groups + self.num_channels = num_channels + self.eps = eps + self.affine = affine + if self.affine: + self.weight = Parameter(ops.empty(num_channels, **factory_kwargs)) + self.bias = Parameter(ops.empty(num_channels, **factory_kwargs)) + else: + self.register_parameter('weight', None) + self.register_parameter('bias', None) + + self.reset_parameters() + + def forward(self, input): + return group_norm(input, self.num_groups, self.weight, self.bias, self.eps) + + + def reset_parameters(self) -> None: + if self.affine: + init.ones_(self.weight) + init.zeros_(self.bias) + + def extra_repr(self): + return '{num_groups}, {num_channels}, eps={eps}, ' \ + 'affine={affine}'.format(**self.__dict__) diff --git a/mindnlp/core/nn/modules/padding.py b/mindnlp/core/nn/modules/padding.py new file mode 100644 index 000000000..5c5afab0e --- /dev/null +++ b/mindnlp/core/nn/modules/padding.py @@ -0,0 +1,253 @@ +"""padding""" +from typing import Sequence, Tuple +from mindnlp.core import Tensor + +from .module import Module +from ._utils import _pair, _quadruple, _ntuple +from ..common_types import _size_2_t, _size_4_t, _size_6_t +from .. import functional as F + +class _ConstantPadNd(Module): + __constants__ = ['padding', 'value'] + value: float + padding: Sequence[int] + + def __init__(self, value: float) -> None: + super().__init__() + self.value = value + + def forward(self, input: Tensor) -> Tensor: + return F.pad(input, self.padding, 'constant', self.value) + + def extra_repr(self) -> str: + return f'padding={self.padding}, value={self.value}' + +class ConstantPad1d(_ConstantPadNd): + r"""Pads the input tensor boundaries with a constant value. + + For `N`-dimensional padding, use :func:`core.nn.functional.pad()`. + + Args: + padding (int, tuple): the size of the padding. If is `int`, uses the same + padding in both boundaries. If a 2-`tuple`, uses + (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`) + + Shape: + - Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`. + - Output: :math:`(C, W_{out})` or :math:`(N, C, W_{out})`, where + + :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}` + """ + + padding: Tuple[int, int] + + def __init__(self, padding: _size_2_t, value: float): + super().__init__(value) + self.padding = _pair(padding) + +class ConstantPad2d(_ConstantPadNd): + r"""Pads the input tensor boundaries with a constant value. + + For `N`-dimensional padding, use :func:`core.nn.functional.pad()`. + + Args: + padding (int, tuple): the size of the padding. If is `int`, uses the same + padding in all boundaries. If a 4-`tuple`, uses (:math:`\text{padding\_left}`, + :math:`\text{padding\_right}`, :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`) + + Shape: + - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`. + - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where + + :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}` + + :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}` + + """ + + __constants__ = ['padding', 'value'] + padding: Tuple[int, int, int, int] + + def __init__(self, padding: _size_4_t, value: float) -> None: + super().__init__(value) + self.padding = _quadruple(padding) + +class ConstantPad3d(_ConstantPadNd): + r"""Pads the input tensor boundaries with a constant value. + + For `N`-dimensional padding, use :func:`core.nn.functional.pad()`. + + Args: + padding (int, tuple): the size of the padding. If is `int`, uses the same + padding in all boundaries. If a 6-`tuple`, uses + (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`, + :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`, + :math:`\text{padding\_front}`, :math:`\text{padding\_back}`) + + Shape: + - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`. + - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or + :math:`(C, D_{out}, H_{out}, W_{out})`, where + + :math:`D_{out} = D_{in} + \text{padding\_front} + \text{padding\_back}` + + :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}` + + :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}` + + """ + + padding: Tuple[int, int, int, int, int, int] + + def __init__(self, padding: _size_6_t, value: float) -> None: + super().__init__(value) + self.padding = _ntuple(6)(padding) + + +class ZeroPad1d(ConstantPad1d): + r"""Pads the input tensor boundaries with zero. + + For `N`-dimensional padding, use :func:`core.nn.functional.pad()`. + + Args: + padding (int, tuple): the size of the padding. If is `int`, uses the same + padding in both boundaries. If a 2-`tuple`, uses + (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`) + + Shape: + - Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`. + - Output: :math:`(C, W_{out})` or :math:`(N, C, W_{out})`, where + + :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}` + + Examples:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = nn.ZeroPad1d(2) + >>> input = core.randn(1, 2, 4) + >>> input + tensor([[[-1.0491, -0.7152, -0.0749, 0.8530], + [-1.3287, 1.8966, 0.1466, -0.2771]]]) + >>> m(input) + tensor([[[ 0.0000, 0.0000, -1.0491, -0.7152, -0.0749, 0.8530, 0.0000, + 0.0000], + [ 0.0000, 0.0000, -1.3287, 1.8966, 0.1466, -0.2771, 0.0000, + 0.0000]]]) + >>> m = nn.ZeroPad1d(2) + >>> input = core.randn(1, 2, 3) + >>> input + tensor([[[ 1.6616, 1.4523, -1.1255], + [-3.6372, 0.1182, -1.8652]]]) + >>> m(input) + tensor([[[ 0.0000, 0.0000, 1.6616, 1.4523, -1.1255, 0.0000, 0.0000], + [ 0.0000, 0.0000, -3.6372, 0.1182, -1.8652, 0.0000, 0.0000]]]) + >>> # using different paddings for different sides + >>> m = nn.ZeroPad1d((3, 1)) + >>> m(input) + tensor([[[ 0.0000, 0.0000, 0.0000, 1.6616, 1.4523, -1.1255, 0.0000], + [ 0.0000, 0.0000, 0.0000, -3.6372, 0.1182, -1.8652, 0.0000]]]) + """ + + padding: Tuple[int, int] + + def __init__(self, padding: _size_2_t) -> None: + super().__init__(padding, 0.) + + def extra_repr(self) -> str: + return f'{self.padding}' + + +class ZeroPad2d(ConstantPad2d): + r"""Pads the input tensor boundaries with zero. + + For `N`-dimensional padding, use :func:`core.nn.functional.pad()`. + + Args: + padding (int, tuple): the size of the padding. If is `int`, uses the same + padding in all boundaries. If a 4-`tuple`, uses (:math:`\text{padding\_left}`, + :math:`\text{padding\_right}`, :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`) + + Shape: + - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`. + - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where + + :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}` + + :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}` + + Examples:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = nn.ZeroPad2d(2) + >>> input = core.randn(1, 1, 3, 3) + >>> input + tensor([[[[-0.1678, -0.4418, 1.9466], + [ 0.9604, -0.4219, -0.5241], + [-0.9162, -0.5436, -0.6446]]]]) + >>> m(input) + tensor([[[[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, -0.1678, -0.4418, 1.9466, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.9604, -0.4219, -0.5241, 0.0000, 0.0000], + [ 0.0000, 0.0000, -0.9162, -0.5436, -0.6446, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]]) + >>> # using different paddings for different sides + >>> m = nn.ZeroPad2d((1, 1, 2, 0)) + >>> m(input) + tensor([[[[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, -0.1678, -0.4418, 1.9466, 0.0000], + [ 0.0000, 0.9604, -0.4219, -0.5241, 0.0000], + [ 0.0000, -0.9162, -0.5436, -0.6446, 0.0000]]]]) + """ + + padding: Tuple[int, int, int, int] + + def __init__(self, padding: _size_4_t) -> None: + super().__init__(padding, 0.) + + def extra_repr(self) -> str: + return f'{self.padding}' + + +class ZeroPad3d(ConstantPad3d): + r"""Pads the input tensor boundaries with zero. + + For `N`-dimensional padding, use :func:`core.nn.functional.pad()`. + + Args: + padding (int, tuple): the size of the padding. If is `int`, uses the same + padding in all boundaries. If a 6-`tuple`, uses + (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`, + :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`, + :math:`\text{padding\_front}`, :math:`\text{padding\_back}`) + + Shape: + - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`. + - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or + :math:`(C, D_{out}, H_{out}, W_{out})`, where + + :math:`D_{out} = D_{in} + \text{padding\_front} + \text{padding\_back}` + + :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}` + + :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}` + + Examples:: + + >>> m = nn.ZeroPad3d(3) + >>> input = core.randn(16, 3, 10, 20, 30) + >>> output = m(input) + >>> # using different paddings for different sides + >>> m = nn.ZeroPad3d((3, 3, 6, 6, 0, 1)) + >>> output = m(input) + """ + + padding: Tuple[int, int, int, int, int, int] + + def __init__(self, padding: _size_6_t) -> None: + super().__init__(padding, 0.) + + def extra_repr(self) -> str: + return f'{self.padding}' diff --git a/mindnlp/core/nn/modules/pixelshuffle.py b/mindnlp/core/nn/modules/pixelshuffle.py new file mode 100644 index 000000000..59fccc1aa --- /dev/null +++ b/mindnlp/core/nn/modules/pixelshuffle.py @@ -0,0 +1,115 @@ +"""pixel shuffle""" +from mindnlp.core import Tensor +from .module import Module +from .. import functional as F + + +__all__ = ['PixelShuffle', 'PixelUnshuffle'] + +class PixelShuffle(Module): + r"""Rearrange elements in a tensor according to an upscaling factor. + + Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)` + to a tensor of shape :math:`(*, C, H \times r, W \times r)`, where r is an upscale factor. + + This is useful for implementing efficient sub-pixel convolution + with a stride of :math:`1/r`. + + See the paper: + `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network`_ + by Shi et al. (2016) for more details. + + Args: + upscale_factor (int): factor to increase spatial resolution by + + Shape: + - Input: :math:`(*, C_{in}, H_{in}, W_{in})`, where * is zero or more batch dimensions + - Output: :math:`(*, C_{out}, H_{out}, W_{out})`, where + + .. math:: + C_{out} = C_{in} \div \text{upscale\_factor}^2 + + .. math:: + H_{out} = H_{in} \times \text{upscale\_factor} + + .. math:: + W_{out} = W_{in} \times \text{upscale\_factor} + + Examples:: + + >>> pixel_shuffle = nn.PixelShuffle(3) + >>> input = core.randn(1, 9, 4, 4) + >>> output = pixel_shuffle(input) + >>> print(output.size()) + core.Size([1, 1, 12, 12]) + + .. _Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network: + https://arxiv.org/abs/1609.05158 + """ + + __constants__ = ['upscale_factor'] + upscale_factor: int + + def __init__(self, upscale_factor: int) -> None: + super().__init__() + self.upscale_factor = upscale_factor + + def forward(self, input: Tensor) -> Tensor: + return F.pixel_shuffle(input, self.upscale_factor) + + def extra_repr(self) -> str: + return f'upscale_factor={self.upscale_factor}' + + + +class PixelUnshuffle(Module): + r"""Reverse the PixelShuffle operation. + + Reverses the :class:`~core.nn.PixelShuffle` operation by rearranging elements + in a tensor of shape :math:`(*, C, H \times r, W \times r)` to a tensor of shape + :math:`(*, C \times r^2, H, W)`, where r is a downscale factor. + + See the paper: + `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network`_ + by Shi et al. (2016) for more details. + + Args: + downscale_factor (int): factor to decrease spatial resolution by + + Shape: + - Input: :math:`(*, C_{in}, H_{in}, W_{in})`, where * is zero or more batch dimensions + - Output: :math:`(*, C_{out}, H_{out}, W_{out})`, where + + .. math:: + C_{out} = C_{in} \times \text{downscale\_factor}^2 + + .. math:: + H_{out} = H_{in} \div \text{downscale\_factor} + + .. math:: + W_{out} = W_{in} \div \text{downscale\_factor} + + Examples:: + + >>> pixel_unshuffle = nn.PixelUnshuffle(3) + >>> input = core.randn(1, 1, 12, 12) + >>> output = pixel_unshuffle(input) + >>> print(output.size()) + core.Size([1, 9, 4, 4]) + + .. _Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network: + https://arxiv.org/abs/1609.05158 + """ + + __constants__ = ['downscale_factor'] + downscale_factor: int + + def __init__(self, downscale_factor: int) -> None: + super().__init__() + self.downscale_factor = downscale_factor + + def forward(self, input: Tensor) -> Tensor: + return F.pixel_unshuffle(input, self.downscale_factor) + + def extra_repr(self) -> str: + return f'downscale_factor={self.downscale_factor}' diff --git a/mindnlp/core/nn/modules/pooling.py b/mindnlp/core/nn/modules/pooling.py new file mode 100644 index 000000000..91265372a --- /dev/null +++ b/mindnlp/core/nn/modules/pooling.py @@ -0,0 +1,585 @@ +"""pooling""" +# pylint: disable=unused-import +from typing import Optional +from mindnlp.core import Tensor + +from .module import Module +from ._utils import _single +from ..common_types import (_size_any_t, _size_1_t, _size_2_t, _size_3_t, + _ratio_3_t, _ratio_2_t, _size_any_opt_t, _size_2_opt_t, _size_3_opt_t) +from .. import functional as F + +class _MaxPoolNd(Module): + __constants__ = ['kernel_size', 'stride', 'padding', 'dilation', + 'return_indices', 'ceil_mode'] + return_indices: bool + ceil_mode: bool + + def __init__(self, kernel_size: _size_any_t, stride: Optional[_size_any_t] = None, + padding: _size_any_t = 0, dilation: _size_any_t = 1, + return_indices: bool = False, ceil_mode: bool = False) -> None: + super().__init__() + self.kernel_size = kernel_size + self.stride = stride if (stride is not None) else kernel_size + self.padding = padding + self.dilation = dilation + self.return_indices = return_indices + self.ceil_mode = ceil_mode + + def extra_repr(self) -> str: + return 'kernel_size={kernel_size}, stride={stride}, padding={padding}' \ + ', dilation={dilation}, ceil_mode={ceil_mode}'.format(**self.__dict__) + + +class MaxPool1d(_MaxPoolNd): + r"""Applies a 1D max pooling over an input signal composed of several input planes. + + In the simplest case, the output value of the layer with input size :math:`(N, C, L)` + and output :math:`(N, C, L_{out})` can be precisely described as: + + .. math:: + out(N_i, C_j, k) = \max_{m=0, \ldots, \text{kernel\_size} - 1} + input(N_i, C_j, stride \times k + m) + + If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides + for :attr:`padding` number of points. :attr:`dilation` is the stride between the elements within the + sliding window. This `link`_ has a nice visualization of the pooling parameters. + + Note: + When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding + or the input. Sliding windows that would start in the right padded region are ignored. + + Args: + kernel_size: The size of the sliding window, must be > 0. + stride: The stride of the sliding window, must be > 0. Default value is :attr:`kernel_size`. + padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2. + dilation: The stride between elements within a sliding window, must be > 0. + return_indices: If ``True``, will return the argmax along with the max values. + Useful for :class:`core.nn.MaxUnpool1d` later + ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This + ensures that every element in the input tensor is covered by a sliding window. + + Shape: + - Input: :math:`(N, C, L_{in})` or :math:`(C, L_{in})`. + - Output: :math:`(N, C, L_{out})` or :math:`(C, L_{out})`, where + + .. math:: + L_{out} = \left\lfloor \frac{L_{in} + 2 \times \text{padding} - \text{dilation} + \times (\text{kernel\_size} - 1) - 1}{\text{stride}} + 1\right\rfloor + + Examples:: + + >>> # pool of size=3, stride=2 + >>> m = nn.MaxPool1d(3, stride=2) + >>> input = core.randn(20, 16, 50) + >>> output = m(input) + + .. _link: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + """ + + kernel_size: _size_1_t + stride: _size_1_t + padding: _size_1_t + dilation: _size_1_t + + def forward(self, input: Tensor): + return F.max_pool1d(input, self.kernel_size, self.stride, + self.padding, self.dilation, ceil_mode=self.ceil_mode, + return_indices=self.return_indices) + + +class MaxPool2d(_MaxPoolNd): + r"""Applies a 2D max pooling over an input signal composed of several input planes. + + In the simplest case, the output value of the layer with input size :math:`(N, C, H, W)`, + output :math:`(N, C, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kH, kW)` + can be precisely described as: + + .. math:: + \begin{aligned} + out(N_i, C_j, h, w) ={} & \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\ + & \text{input}(N_i, C_j, \text{stride[0]} \times h + m, + \text{stride[1]} \times w + n) + \end{aligned} + + If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides + for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points. + It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. + + Note: + When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding + or the input. Sliding windows that would start in the right padded region are ignored. + + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: + + - a single ``int`` -- in which case the same value is used for the height and width dimension + - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension, + and the second `int` for the width dimension + + Args: + kernel_size: the size of the window to take a max over + stride: the stride of the window. Default value is :attr:`kernel_size` + padding: Implicit negative infinity padding to be added on both sides + dilation: a parameter that controls the stride of elements in the window + return_indices: if ``True``, will return the max indices along with the outputs. + Useful for :class:`core.nn.MaxUnpool2d` later + ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape + + Shape: + - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})` + - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where + + .. math:: + H_{out} = \left\lfloor\frac{H_{in} + 2 * \text{padding[0]} - \text{dilation[0]} + \times (\text{kernel\_size[0]} - 1) - 1}{\text{stride[0]}} + 1\right\rfloor + + .. math:: + W_{out} = \left\lfloor\frac{W_{in} + 2 * \text{padding[1]} - \text{dilation[1]} + \times (\text{kernel\_size[1]} - 1) - 1}{\text{stride[1]}} + 1\right\rfloor + + Examples:: + + >>> # pool of square window of size=3, stride=2 + >>> m = nn.MaxPool2d(3, stride=2) + >>> # pool of non-square window + >>> m = nn.MaxPool2d((3, 2), stride=(2, 1)) + >>> input = core.randn(20, 16, 50, 32) + >>> output = m(input) + + .. _link: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + """ + + kernel_size: _size_2_t + stride: _size_2_t + padding: _size_2_t + dilation: _size_2_t + + def forward(self, input: Tensor): + return F.max_pool2d(input, self.kernel_size, self.stride, + self.padding, self.dilation, ceil_mode=self.ceil_mode, + return_indices=self.return_indices) + + +class MaxPool3d(_MaxPoolNd): + r"""Applies a 3D max pooling over an input signal composed of several input planes. + + In the simplest case, the output value of the layer with input size :math:`(N, C, D, H, W)`, + output :math:`(N, C, D_{out}, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kD, kH, kW)` + can be precisely described as: + + .. math:: + \begin{aligned} + \text{out}(N_i, C_j, d, h, w) ={} & \max_{k=0, \ldots, kD-1} \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\ + & \text{input}(N_i, C_j, \text{stride[0]} \times d + k, + \text{stride[1]} \times h + m, \text{stride[2]} \times w + n) + \end{aligned} + + If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides + for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points. + It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. + + Note: + When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding + or the input. Sliding windows that would start in the right padded region are ignored. + + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: + + - a single ``int`` -- in which case the same value is used for the depth, height and width dimension + - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension, + the second `int` for the height dimension and the third `int` for the width dimension + + Args: + kernel_size: the size of the window to take a max over + stride: the stride of the window. Default value is :attr:`kernel_size` + padding: Implicit negative infinity padding to be added on all three sides + dilation: a parameter that controls the stride of elements in the window + return_indices: if ``True``, will return the max indices along with the outputs. + Useful for :class:`core.nn.MaxUnpool3d` later + ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape + + Shape: + - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`. + - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or :math:`(C, D_{out}, H_{out}, W_{out})`, where + + .. math:: + D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] \times + (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor + + .. math:: + H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] \times + (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor + + .. math:: + W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{dilation}[2] \times + (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor + + Examples:: + + >>> # pool of square window of size=3, stride=2 + >>> m = nn.MaxPool3d(3, stride=2) + >>> # pool of non-square window + >>> m = nn.MaxPool3d((3, 2, 2), stride=(2, 1, 2)) + >>> input = core.randn(20, 16, 50, 44, 31) + >>> output = m(input) + + .. _link: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + """ # noqa: E501 + + kernel_size: _size_3_t + stride: _size_3_t + padding: _size_3_t + dilation: _size_3_t + + def forward(self, input: Tensor): + return F.max_pool3d(input, self.kernel_size, self.stride, + self.padding, self.dilation, ceil_mode=self.ceil_mode, + return_indices=self.return_indices) + + +class _AdaptiveAvgPoolNd(Module): + __constants__ = ['output_size'] + + def __init__(self, output_size: _size_any_opt_t) -> None: + super().__init__() + self.output_size = output_size + + def extra_repr(self) -> str: + return f'output_size={self.output_size}' + + +class AdaptiveAvgPool2d(_AdaptiveAvgPoolNd): + r"""Applies a 2D adaptive average pooling over an input signal composed of several input planes. + + The output is of size H x W, for any input size. + The number of output features is equal to the number of input planes. + + Args: + output_size: the target output size of the image of the form H x W. + Can be a tuple (H, W) or a single H for a square image H x H. + H and W can be either a ``int``, or ``None`` which means the size will + be the same as that of the input. + + Shape: + - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`. + - Output: :math:`(N, C, S_{0}, S_{1})` or :math:`(C, S_{0}, S_{1})`, where + :math:`S=\text{output\_size}`. + + Examples: + >>> # target output size of 5x7 + >>> m = nn.AdaptiveAvgPool2d((5, 7)) + >>> input = core.randn(1, 64, 8, 9) + >>> output = m(input) + >>> # target output size of 7x7 (square) + >>> m = nn.AdaptiveAvgPool2d(7) + >>> input = core.randn(1, 64, 10, 9) + >>> output = m(input) + >>> # target output size of 10x7 + >>> m = nn.AdaptiveAvgPool2d((None, 7)) + >>> input = core.randn(1, 64, 10, 9) + >>> output = m(input) + + """ + + output_size: _size_2_opt_t + + def forward(self, input: Tensor) -> Tensor: + return ops.adaptive_avg_pool2d(input, self.output_size) + +class AdaptiveAvgPool1d(_AdaptiveAvgPoolNd): + r"""Applies a 1D adaptive average pooling over an input signal composed of several input planes. + + The output size is :math:`L_{out}`, for any input size. + The number of output features is equal to the number of input planes. + + Args: + output_size: the target output size :math:`L_{out}`. + + Shape: + - Input: :math:`(N, C, L_{in})` or :math:`(C, L_{in})`. + - Output: :math:`(N, C, L_{out})` or :math:`(C, L_{out})`, where + :math:`L_{out}=\text{output\_size}`. + + Examples: + >>> # target output size of 5 + >>> m = nn.AdaptiveAvgPool1d(5) + >>> input = core.randn(1, 64, 8) + >>> output = m(input) + + """ + + output_size: _size_1_t + + def forward(self, input: Tensor) -> Tensor: + # Add a dimension to make it 2D + input_2d = input.unsqueeze(2) + + # Perform adaptive average pooling + output_2d = ops.adaptive_avg_pool2d(input_2d, (self.output_size, 1)) + + # Remove the added dimension to make it back to 1D + output_1d = output_2d.squeeze(2) + + return output_1d + +class _AvgPoolNd(Module): + __constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode', 'count_include_pad'] + + def extra_repr(self) -> str: + return f'kernel_size={self.kernel_size}, stride={self.stride}, padding={self.padding}' + + +class AvgPool1d(_AvgPoolNd): + r"""Applies a 1D average pooling over an input signal composed of several input planes. + + In the simplest case, the output value of the layer with input size :math:`(N, C, L)`, + output :math:`(N, C, L_{out})` and :attr:`kernel_size` :math:`k` + can be precisely described as: + + .. math:: + + \text{out}(N_i, C_j, l) = \frac{1}{k} \sum_{m=0}^{k-1} + \text{input}(N_i, C_j, \text{stride} \times l + m) + + If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides + for :attr:`padding` number of points. + + Note: + When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding + or the input. Sliding windows that would start in the right padded region are ignored. + + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding` can each be + an ``int`` or a one-element tuple. + + Args: + kernel_size: the size of the window + stride: the stride of the window. Default value is :attr:`kernel_size` + padding: implicit zero padding to be added on both sides + ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape + count_include_pad: when True, will include the zero-padding in the averaging calculation + + Shape: + - Input: :math:`(N, C, L_{in})` or :math:`(C, L_{in})`. + - Output: :math:`(N, C, L_{out})` or :math:`(C, L_{out})`, where + + .. math:: + L_{out} = \left\lfloor \frac{L_{in} + + 2 \times \text{padding} - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor + + Per the note above, if ``ceil_mode`` is True and :math:`(L_{out} - 1) \times \text{stride} \geq L_{in} + + \text{padding}`, we skip the last window as it would start in the right padded region, resulting in + :math:`L_{out}` being reduced by one. + + Examples:: + + >>> # pool with window of size=3, stride=2 + >>> m = nn.AvgPool1d(3, stride=2) + >>> m(core.tensor([[[1., 2, 3, 4, 5, 6, 7]]])) + tensor([[[2., 4., 6.]]]) + """ + + kernel_size: _size_1_t + stride: _size_1_t + padding: _size_1_t + ceil_mode: bool + count_include_pad: bool + + def __init__(self, kernel_size: _size_1_t, stride: _size_1_t = None, padding: _size_1_t = 0, ceil_mode: bool = False, + count_include_pad: bool = True) -> None: + super().__init__() + self.kernel_size = _single(kernel_size) + self.stride = _single(stride if stride is not None else kernel_size) + self.padding = _single(padding) + self.ceil_mode = ceil_mode + self.count_include_pad = count_include_pad + + def forward(self, input: Tensor) -> Tensor: + return F.avg_pool1d( + input, self.kernel_size[0], self.stride[0], self.padding[0], self.ceil_mode, + self.count_include_pad) + + +class AvgPool2d(_AvgPoolNd): + r"""Applies a 2D average pooling over an input signal composed of several input planes. + + In the simplest case, the output value of the layer with input size :math:`(N, C, H, W)`, + output :math:`(N, C, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kH, kW)` + can be precisely described as: + + .. math:: + + out(N_i, C_j, h, w) = \frac{1}{kH * kW} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1} + input(N_i, C_j, stride[0] \times h + m, stride[1] \times w + n) + + If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides + for :attr:`padding` number of points. + + Note: + When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding + or the input. Sliding windows that would start in the right padded region are ignored. + + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding` can either be: + + - a single ``int`` -- in which case the same value is used for the height and width dimension + - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension, + and the second `int` for the width dimension + + Args: + kernel_size: the size of the window + stride: the stride of the window. Default value is :attr:`kernel_size` + padding: implicit zero padding to be added on both sides + ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape + count_include_pad: when True, will include the zero-padding in the averaging calculation + divisor_override: if specified, it will be used as divisor, otherwise size of the pooling region will be used. + + + Shape: + - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`. + - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where + + .. math:: + H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] - + \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor + + .. math:: + W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] - + \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor + + Per the note above, if ``ceil_mode`` is True and :math:`(H_{out} - 1)\times \text{stride}[0]\geq H_{in} + + \text{padding}[0]`, we skip the last window as it would start in the bottom padded region, + resulting in :math:`H_{out}` being reduced by one. + + The same applies for :math:`W_{out}`. + + Examples:: + + >>> # pool of square window of size=3, stride=2 + >>> m = nn.AvgPool2d(3, stride=2) + >>> # pool of non-square window + >>> m = nn.AvgPool2d((3, 2), stride=(2, 1)) + >>> input = core.randn(20, 16, 50, 32) + >>> output = m(input) + """ + + __constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode', 'count_include_pad', 'divisor_override'] + + kernel_size: _size_2_t + stride: _size_2_t + padding: _size_2_t + ceil_mode: bool + count_include_pad: bool + + def __init__(self, kernel_size: _size_2_t, stride: Optional[_size_2_t] = None, padding: _size_2_t = 0, + ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> None: + super().__init__() + self.kernel_size = kernel_size + self.stride = stride if (stride is not None) else kernel_size + self.padding = padding + self.ceil_mode = ceil_mode + self.count_include_pad = count_include_pad + self.divisor_override = divisor_override if divisor_override is not None else 0 + + def forward(self, input: Tensor) -> Tensor: + return F.avg_pool2d(input, self.kernel_size, self.stride, + self.padding, self.ceil_mode, self.count_include_pad, self.divisor_override) + + + +class AvgPool3d(_AvgPoolNd): + r"""Applies a 3D average pooling over an input signal composed of several input planes. + + In the simplest case, the output value of the layer with input size :math:`(N, C, D, H, W)`, + output :math:`(N, C, D_{out}, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kD, kH, kW)` + can be precisely described as: + + .. math:: + \begin{aligned} + \text{out}(N_i, C_j, d, h, w) ={} & \sum_{k=0}^{kD-1} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1} \\ + & \frac{\text{input}(N_i, C_j, \text{stride}[0] \times d + k, + \text{stride}[1] \times h + m, \text{stride}[2] \times w + n)} + {kD \times kH \times kW} + \end{aligned} + + If :attr:`padding` is non-zero, then the input is implicitly zero-padded on all three sides + for :attr:`padding` number of points. + + Note: + When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding + or the input. Sliding windows that would start in the right padded region are ignored. + + The parameters :attr:`kernel_size`, :attr:`stride` can either be: + + - a single ``int`` -- in which case the same value is used for the depth, height and width dimension + - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension, + the second `int` for the height dimension and the third `int` for the width dimension + + Args: + kernel_size: the size of the window + stride: the stride of the window. Default value is :attr:`kernel_size` + padding: implicit zero padding to be added on all three sides + ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape + count_include_pad: when True, will include the zero-padding in the averaging calculation + divisor_override: if specified, it will be used as divisor, otherwise :attr:`kernel_size` will be used + + Shape: + - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`. + - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or + :math:`(C, D_{out}, H_{out}, W_{out})`, where + + .. math:: + D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - + \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor + + .. math:: + H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - + \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor + + .. math:: + W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - + \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor + + Per the note above, if ``ceil_mode`` is True and :math:`(D_{out} - 1)\times \text{stride}[0]\geq D_{in} + + \text{padding}[0]`, we skip the last window as it would start in the padded region, + resulting in :math:`D_{out}` being reduced by one. + + The same applies for :math:`W_{out}` and :math:`H_{out}`. + + Examples:: + + >>> # pool of square window of size=3, stride=2 + >>> m = nn.AvgPool3d(3, stride=2) + >>> # pool of non-square window + >>> m = nn.AvgPool3d((3, 2, 2), stride=(2, 1, 2)) + >>> input = core.randn(20, 16, 50, 44, 31) + >>> output = m(input) + """ + + __constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode', 'count_include_pad', 'divisor_override'] + + kernel_size: _size_3_t + stride: _size_3_t + padding: _size_3_t + ceil_mode: bool + count_include_pad: bool + + def __init__(self, kernel_size: _size_3_t, stride: Optional[_size_3_t] = None, padding: _size_3_t = 0, + ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> None: + super().__init__() + self.kernel_size = kernel_size + self.stride = stride if (stride is not None) else kernel_size + self.padding = padding + self.ceil_mode = ceil_mode + self.count_include_pad = count_include_pad + self.divisor_override = divisor_override + + def forward(self, input: Tensor) -> Tensor: + return F.avg_pool3d(input, self.kernel_size, self.stride, + self.padding, self.ceil_mode, self.count_include_pad, self.divisor_override) + + def __setstate__(self, d): + super().__setstate__(d) + self.__dict__.setdefault('padding', 0) + self.__dict__.setdefault('ceil_mode', False) + self.__dict__.setdefault('count_include_pad', True) diff --git a/mindnlp/core/nn/modules/rnn.py b/mindnlp/core/nn/modules/rnn.py new file mode 100644 index 000000000..5c01c05f1 --- /dev/null +++ b/mindnlp/core/nn/modules/rnn.py @@ -0,0 +1,779 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""RNN operators module, include RNN, GRU.""" +import math +import warnings + +from ..parameter import Parameter +from .module import Module +from .dropout import Dropout +from ... import ops +from .. import init + + +__all__ = ['LSTM', 'GRU', 'RNN'] + + +def _init_state(shape, dtype, is_lstm): + hx = ops.zeros(*shape, dtype=dtype) + cx = ops.zeros(*shape, dtype=dtype) + if is_lstm: + return (hx, cx) + return hx + + +def sequence_mask(lengths, maxlen): + """generate mask matrix by seq_length""" + range_vector = ops.arange(start=0, end=maxlen, step=1, dtype=lengths.dtype) + result = range_vector < lengths.view(lengths.shape + (1,)) + return result.astype(core.int32) + + +def select_by_mask(inputs, mask): + """mask hiddens by mask matrix""" + return mask.view(mask.shape + (1,)).swapaxes(0, 1) \ + .expand_as(inputs).astype(core.bool_) * inputs + + +def get_hidden(output, seq_length): + """get hidden state by seq_length""" + batch_index = ops.arange(start=0, end=seq_length.shape[0], step=1, dtype=seq_length.dtype) + indices = ops.cat((seq_length.view(-1, 1) - 1, batch_index.view(-1, 1)), 1) + return ops.gather_nd(output, indices) + + +class _DynamicRNNBase(Module): + '''Dynamic RNN module to compute RNN cell by timesteps''' + + def __init__(self, mode): + super().__init__() + if mode == "RNN_RELU": + cell = _rnn_relu_cell + elif mode == "RNN_TANH": + cell = _rnn_tanh_cell + elif mode == "LSTM": + cell = _lstm_cell + elif mode == "GRU": + cell = _gru_cell + else: + raise ValueError("Unrecognized RNN mode: " + mode) + self.cell = cell + self.is_lstm = mode == "LSTM" + + def recurrent(self, x, h_0, w_ih, w_hh, b_ih, b_hh): + '''recurrent steps without sequence length''' + time_step = x.shape[0] + outputs = [] + t = 0 + h = h_0 + while t < time_step: + x_t = x[t:t + 1:1] + x_t = x_t.squeeze(0) + h = self.cell(x_t, h, w_ih, w_hh, b_ih, b_hh) + if self.is_lstm: + outputs.append(h[0]) + else: + outputs.append(h) + t += 1 + outputs = ops.stack(outputs) + return outputs, h + + def variable_recurrent(self, x, h, seq_length, w_ih, w_hh, b_ih, b_hh): + '''recurrent steps with sequence length''' + time_step = x.shape[0] + h_t = h + if self.is_lstm: + hidden_size = h[0].shape[-1] + zero_output = ops.zeros_like(h_t[0]) + else: + hidden_size = h.shape[-1] + zero_output = ops.zeros_like(h_t) + seq_length = seq_length.to(core.float32) + seq_length = ops.broadcast_to(seq_length, (hidden_size, -1)) + seq_length = seq_length.to(core.int32) + seq_length = ops.transpose(seq_length, 1, 0) + + outputs = [] + state_t = h_t + t = 0 + while t < time_step: + x_t = x[t:t + 1:1] + x_t = x_t.squeeze(0) + h_t = self.cell(x_t, state_t, w_ih, w_hh, b_ih, b_hh) + seq_cond = seq_length > t + if self.is_lstm: + state_t_0 = ops.where(seq_cond, h_t[0], state_t[0]) + state_t_1 = ops.where(seq_cond, h_t[1], state_t[1]) + output = ops.where(seq_cond, h_t[0], zero_output) + state_t = (state_t_0, state_t_1) + else: + state_t = ops.where(seq_cond, h_t, state_t) + output = ops.where(seq_cond, h_t, zero_output) + outputs.append(output) + t += 1 + outputs = ops.stack(outputs) + return outputs, state_t + + def forward(self, x, h, seq_length, w_ih, w_hh, b_ih, b_hh): + x_dtype = x.dtype + w_ih = w_ih.astype(x_dtype) + w_hh = w_hh.astype(x_dtype) + if b_ih is not None: + b_ih = b_ih.astype(x_dtype) + b_hh = b_hh.astype(x_dtype) + if seq_length is None: + return self.recurrent(x, h, w_ih, w_hh, b_ih, b_hh) + return self.variable_recurrent(x, h, seq_length, w_ih, w_hh, b_ih, b_hh) + + +class _DynamicRNNRelu(_DynamicRNNBase): + '''Dynamic RNN module with Relu activation''' + + def __init__(self): + mode = 'RNN_RELU' + super().__init__(mode) + + +class _DynamicRNNTanh(_DynamicRNNBase): + '''Dynamic RNN module with Tanh activation''' + + def __init__(self): + mode = 'RNN_TANH' + super().__init__(mode) + + +class _DynamicGRUCPUGPU(Module): + '''Dynamic GRU module on CPU and GPU''' + + def __init__(self): + super().__init__() + + def forward(self, x, h_0, seq_length, w_ih, w_hh, b_ih, b_hh): + '''_DynamicGRUCPUGPU''' + gate_size, input_size = w_ih.shape + hidden_size = gate_size // 3 + if self.is_gpu: + if b_ih is None: + weights = ops.concat(( + w_ih.view(-1, 1, 1), + w_hh.view(-1, 1, 1) + )) + bias = False + else: + bias = True + weights = ops.concat(( + w_ih.view(-1, 1, 1), + w_hh.view(-1, 1, 1), + b_ih.view(-1, 1, 1), + b_hh.view(-1, 1, 1) + )) + _gru = _get_cache_prim(CudnnGRU)(input_size, hidden_size, 1, bias, False, 0.0) + output, h_n, _, _ = _gru( + x, + h_0.view(1, *h_0.shape), + weights.astype(x.dtype) + ) + if seq_length is not None: + h_n = get_hidden(output, seq_length) + mask = sequence_mask(seq_length, x.shape[0]) + output = select_by_mask(output, mask) + else: + output, h_n = _DynamicRNNBase('GRU')(x, h_0, seq_length, w_ih, w_hh, b_ih, b_hh) + + return output, h_n + + +class _DynamicGRUAscend(Module): + '''Dynamic GRU module on Ascend''' + + def __init__(self): + super().__init__() + self.gru = DynamicGRUV2(gate_order='rzh') + + def forward(self, x, h_0, seq_length, w_ih, w_hh, b_ih, b_hh): + '''Dynamic GRU module on Ascend''' + if b_ih is None: + b_ih = ops.zeros(w_ih.shape[0], dtype=w_ih.dtype) + b_hh = ops.zeros(w_ih.shape[0], dtype=w_ih.dtype) + outputs, _, _, _, _, _ = self.gru(x.to(self.dtype), \ + ops.transpose(w_ih, 1, 0), \ + ops.transpose(w_hh, 1, 0), \ + b_ih, \ + b_hh, \ + None, h_0) + if seq_length is not None: + h = get_hidden(outputs, seq_length) + mask = sequence_mask(seq_length, x.shape[0]) + outputs = select_by_mask(outputs, mask) + else: + h = outputs[-1] + return outputs, h + + +class _DynamicLSTMCPUGPU(Module): + '''Dynamic LSTM module on CPU and GPU''' + + def __init__(self): + super().__init__() + + def forward(self, x, h_0, seq_length, w_ih, w_hh, b_ih, b_hh): + '''Dynamic LSTM module on CPU and GPU''' + gate_size, input_size = w_ih.shape + hidden_size = gate_size // 4 + if seq_length is not None: + output, (h_n, c_n) = _DynamicRNNBase('LSTM')(x, h_0, seq_length, w_ih, w_hh, b_ih, b_hh) + else: + if b_ih is None: + weights = ops.concat(( + w_ih.view(-1, 1, 1), + w_hh.view(-1, 1, 1) + )) + has_bias = False + else: + has_bias = True + if self.is_gpu: + weights = ops.concat(( + w_ih.view(-1, 1, 1), + w_hh.view(-1, 1, 1), + b_ih.view(-1, 1, 1), + b_hh.view(-1, 1, 1) + )) + else: + bias = b_ih + b_hh + weights = ops.concat(( + w_ih.view(-1, 1, 1), + w_hh.view(-1, 1, 1), + bias.view(-1, 1, 1) + )) + _lstm = _get_cache_prim(LSTMOP)(input_size, hidden_size, 1, has_bias, False, 0.0) + output, h_n, c_n, _, _ = _lstm( + x, + h_0[0].unsqueeze(0), + h_0[1].unsqueeze(0), + weights.astype(x.dtype) + ) + return output, (h_n, c_n) + + +class _DynamicLSTMAscend(Module): + '''Dynamic LSTM module on Ascend''' + + def __init__(self): + super().__init__() + self.lstm = DynamicRNN() + + def forward(self, x, h_0, seq_length, w_ih, w_hh, b_ih, b_hh): + '''Dynamic LSTM module on Ascend''' + w_ih_i, w_ih_f, w_ih_g, w_ih_o = ops.chunk(w_ih, 4, 0) + w_hh_i, w_hh_f, w_hh_g, w_hh_o = ops.chunk(w_hh, 4, 0) + w_ih = ops.cat((w_ih_i, w_ih_g, w_ih_f, w_ih_o), 0) + w_hh = ops.cat((w_hh_i, w_hh_g, w_hh_f, w_hh_o), 0) + weight = ops.cat((w_ih, w_hh), 1) + if b_ih is None: + bias = ops.zeros(w_ih.shape[0], dtype=w_ih.dtype) + else: + b_ih_i, b_ih_f, b_ih_g, b_ih_o = ops.chunk(b_ih, 4, 0) + b_hh_i, b_hh_f, b_hh_g, b_hh_o = ops.chunk(b_hh, 4, 0) + bias = ops.cat((b_ih_i + b_hh_i, \ + b_ih_g + b_hh_g, \ + b_ih_f + b_hh_f, \ + b_ih_o + b_hh_o), 0) + + outputs, h, c, _, _, _, _, _ = self.lstm(x.to(core.float16), \ + ops.transpose(weight, 1, 0).to(core.float16), \ + bias.to(core.float16), None, \ + h_0[0].unsqueeze(0).to(core.float16), \ + h_0[1].unsqueeze(0).to(core.float16)) + if seq_length is not None: + h = get_hidden(h, seq_length) + c = get_hidden(c, seq_length) + mask = sequence_mask(seq_length, x.shape[0]) + outputs = select_by_mask(outputs, mask) + else: + h = h[-1] + c = c[-1] + return outputs, (h, c) + + +class _RNNBase(Module): + '''Basic class for RNN operators''' + + def __init__(self, mode, input_size, hidden_size, num_layers=1, bias=True, + batch_first=False, dropout=0., bidirectional=False, dtype=None): + factory_kwargs = {'dtype': dtype} + super().__init__() + + if not 0 <= dropout < 1: + raise ValueError(f"For '{self.cls_name}', the 'dropout' must be a number in range [0, 1) " + f"representing the probability of an element being zeroed, but got {dropout}.") + + if dropout > 0 and num_layers == 1: + warnings.warn("dropout option adds dropout after all but last " + "recurrent layer, so non-zero dropout expects " + "num_layers greater than 1, but got dropout={} and " + "num_layers={}".format(dropout, num_layers)) + + if mode == "LSTM": + gate_size = 4 * hidden_size + self.rnn = _DynamicLSTMAscend() if is_ascend else _DynamicLSTMCPUGPU() + elif mode == "GRU": + if is_ascend and hidden_size % 16 != 0: + raise ValueError(f"GRU on ascend do not support hidden size that is not divisible by 16, " + f"but get hidden size {hidden_size}, please reset the argument.") + gate_size = 3 * hidden_size + self.rnn = _DynamicGRUAscend() if is_ascend else _DynamicGRUCPUGPU() + elif mode == "RNN_TANH": + gate_size = hidden_size + self.rnn = _DynamicRNNTanh() + elif mode == "RNN_RELU": + gate_size = hidden_size + self.rnn = _DynamicRNNRelu() + else: + raise ValueError(f"For '{self.cls_name}', the 'mode' must be in ['RNN_RELU', 'RNN_TANH', 'LSTM', 'GRU'], " + f"but got {mode}.") + + self.reverse = ReverseV2([0]) + self.reverse_sequence = ReverseSequence(0, 1) + self.hidden_size = hidden_size + self.batch_first = batch_first + self.num_layers = num_layers + self.dropout = dropout + self.dropout_op = Dropout(p=float(dropout)) + self.bidirectional = bidirectional + self.bias = bias + num_directions = 2 if bidirectional else 1 + self.is_lstm = mode == "LSTM" + + self.w_ih_list = [] + self.w_hh_list = [] + self.b_ih_list = [] + self.b_hh_list = [] + stdv = 1 / math.sqrt(self.hidden_size) + for layer in range(num_layers): + for direction in range(num_directions): + layer_input_size = input_size if layer == 0 else hidden_size * num_directions + suffix = '_reverse' if direction == 1 else '' + + w_ih = Parameter(ops.empty((gate_size, layer_input_size), **factory_kwargs)) + w_hh = Parameter(ops.empty((gate_size, hidden_size), **factory_kwargs)) + self.w_ih_list.append(w_ih) + self.w_hh_list.append(w_hh) + if bias: + b_ih = Parameter(ops.empty(gate_size, **factory_kwargs)) + # Second bias vector included for CuDNN compatibility. Only one + # bias vector is needed in standard definition. + b_hh = Parameter(ops.empty(gate_size, **factory_kwargs)) + self.b_ih_list.append(b_ih) + self.b_hh_list.append(b_hh) + + if bias: + layer_params = (w_ih, w_hh, b_ih, b_hh) + else: + layer_params = (w_ih, w_hh) + + suffix = '_reverse' if direction == 1 else '' + param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}'] + if bias: + param_names += ['bias_ih_l{}{}', 'bias_hh_l{}{}'] + param_names = [x.format(layer, suffix) for x in param_names] + + for name, param in zip(param_names, layer_params): + setattr(self, name, param) + self.reset_parameters() + + def reset_parameters(self) -> None: + stdv = 1.0 / math.sqrt(self.hidden_size) if self.hidden_size > 0 else 0 + for weight in self.parameters(): + init.uniform_(weight, -stdv, stdv) + + def _stacked_bi_dynamic_rnn(self, x, h, seq_length): + """stacked bidirectional dynamic_rnn""" + pre_layer = x + h_n = () + c_n = () + output = 0 + for i in range(self.num_layers): + offset = i * 2 + if self.bias: + w_f_ih, w_f_hh, b_f_ih, b_f_hh = \ + self.w_ih_list[offset], self.w_hh_list[offset], \ + self.b_ih_list[offset], self.b_hh_list[offset] + w_b_ih, w_b_hh, b_b_ih, b_b_hh = \ + self.w_ih_list[offset + 1], self.w_hh_list[offset + 1], \ + self.b_ih_list[offset + 1], self.b_hh_list[offset + 1] + else: + w_f_ih, w_f_hh = self.w_ih_list[offset], self.w_hh_list[offset] + w_b_ih, w_b_hh = self.w_ih_list[offset + 1], self.w_hh_list[offset + 1] + b_f_ih, b_f_hh, b_b_ih, b_b_hh = None, None, None, None + if self.is_lstm: + h_f_i = (h[0][offset], h[1][offset]) + h_b_i = (h[0][offset + 1], h[1][offset + 1]) + else: + h_f_i = h[offset] + h_b_i = h[offset + 1] + if seq_length is None: + x_b = self.reverse(pre_layer) + else: + x_b = self.reverse_sequence(pre_layer, seq_length) + output_f, h_t_f = self.rnn(pre_layer, h_f_i, seq_length, w_f_ih, w_f_hh, b_f_ih, b_f_hh) + output_b, h_t_b = self.rnn(x_b, h_b_i, seq_length, w_b_ih, w_b_hh, b_b_ih, b_b_hh) + if seq_length is None: + output_b = self.reverse(output_b) + else: + output_b = self.reverse_sequence(output_b, seq_length) + output = ops.cat((output_f, output_b), 2) + pre_layer = self.dropout_op(output) if (self.dropout != 0 and i < self.num_layers - 1) else output + if self.is_lstm: + h_n += (h_t_f[0], h_t_b[0],) + c_n += (h_t_f[1], h_t_b[1],) + else: + h_n += (h_t_f, h_t_b,) + if self.is_lstm: + h_n = ops.cat(h_n) + c_n = ops.cat(c_n) + h0_shape = h[0].shape + h1_shape = h[1].shape + h_n = h_n.view(h0_shape) + c_n = c_n.view(h1_shape) + return output, (h_n.view(h0_shape), c_n.view(h1_shape)) + h_n = ops.cat(h_n) + return output, h_n.view(h.shape) + + def _stacked_dynamic_rnn(self, x, h, seq_length): + """stacked mutil_layer dynamic_rnn""" + pre_layer = x + h_n = () + c_n = () + output = 0 + for i in range(self.num_layers): + if self.bias: + w_ih, w_hh, b_ih, b_hh = self.w_ih_list[i], self.w_hh_list[i], self.b_ih_list[i], self.b_hh_list[i] + else: + w_ih, w_hh = self.w_ih_list[i], self.w_hh_list[i] + b_ih, b_hh = None, None + if self.is_lstm: + h_i = (h[0][i], h[1][i]) + else: + h_i = h[i] + output, h_t = self.rnn(pre_layer, h_i, seq_length, w_ih, w_hh, b_ih, b_hh) + pre_layer = self.dropout_op(output) if (self.dropout != 0 and i < self.num_layers - 1) else output + if self.is_lstm: + h_n += (h_t[0],) + c_n += (h_t[1],) + else: + h_n += (h_t,) + if self.is_lstm: + h_n = ops.cat(h_n) + c_n = ops.cat(c_n) + h0_shape = h[0].shape + h1_shape = h[1].shape + h_n = h_n.view(h0_shape) + c_n = c_n.view(h1_shape) + return output, (h_n.view(h0_shape), c_n.view(h1_shape)) + h_n = ops.cat(h_n) + return output, h_n.view(h.shape) + + def forward(self, x, hx=None, seq_length=None): + '''Defines the RNN like operators performed''' + max_batch_size = x.shape[0] if self.batch_first else x.shape[1] + num_directions = 2 if self.bidirectional else 1 + x_dtype = x.dtype + if hx is None: + hx = _init_state((self.num_layers * num_directions, max_batch_size, self.hidden_size), \ + x_dtype, self.is_lstm) + if self.batch_first: + x = ops.permute(x, (1, 0, 2)) + if self.bidirectional: + x_n, hx_n = self._stacked_bi_dynamic_rnn(x, hx, seq_length) + else: + x_n, hx_n = self._stacked_dynamic_rnn(x, hx, seq_length) + if self.batch_first: + x_n = ops.permute(x_n, (1, 0, 2)) + if not self.is_lstm: + return x_n.astype(x_dtype), hx_n.astype(x_dtype) + return x_n.astype(x_dtype), (hx_n[0].astype(x_dtype), hx_n[1].astype(x_dtype)) + + +class RNN(_RNNBase): + r""" + Stacked Elman RNN layers, applying RNN layer with :math:`\tanh` or :math:`\text{ReLU}` non-linearity to the input. + + For each element in the input sequence, each layer computes the following function: + + .. math:: + h_t = activation(W_{ih} x_t + b_{ih} + W_{hh} h_{(t-1)} + b_{hh}) + + Here :math:`h_t` is the hidden state at time `t`, :math:`x_t` is + the input at time `t`, and :math:`h_{(t-1)}` is the hidden state of the + previous layer at time :math:`t-1` or the initial hidden state at time `0`. + :math:`W_{ih}` is the learnable input-hidden weights, and :math:`b_{ih}` is the learnable input-hidden bias. + :math:`W_{hh}` is the learnable hidden-hidden weights, and :math:`b_{hh}` is the learnable hidden-hidden bias. + + Args: + input_size (int): Number of features of input. + hidden_size (int): Number of features of hidden layer. + num_layers (int): Number of layers of stacked RNN. Default: ``1`` . + nonlinearity (str): The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'``. + bias (bool): Whether the cell has bias :math:`b_{ih}` and :math:`b_{hh}`. Default: ``True`` . + batch_first (bool): Specifies whether the first dimension of input `x` is batch_size. Default: ``False`` . + dropout (float): If not 0.0, append `Dropout` layer on the outputs of each + RNN layer except the last layer. Default ``0.0`` . The range of dropout is [0.0, 1.0). + bidirectional (bool): Specifies whether it is a bidirectional RNN, + num_directions=2 if bidirectional=True otherwise 1. Default: ``False`` . + dtype (:class:`core.dtype`): Dtype of Parameters. Default: ``mstype.float32`` . + + Inputs: + - **x** (Tensor) - Tensor of data type core.float32 or core.float16 and + shape :math:`(seq\_len, batch\_size, input\_size)` or :math:`(batch\_size, seq\_len, input\_size)` . + - **hx** (Tensor) - Tensor of data type core.float32 or core.float16 and + shape :math:`(num\_directions * num\_layers, batch\_size, hidden\_size)` . + - **seq_length** (Tensor) - The length of each sequence in an input batch. + Tensor of shape :math:`(batch\_size)` . Default: ``None`` . + This input indicates the real sequence length before padding to avoid padded elements + have been used to compute hidden state and affect the final output. It is recommended to + use this input when `x` has padding elements. + + Outputs: + Tuple, a tuple contains (`output`, `hx_n`). + + - **output** (Tensor) - Tensor of shape :math:`(seq\_len, batch\_size, num\_directions * hidden\_size)` or + :math:`(batch\_size, seq\_len, num\_directions * hidden\_size)` . + - **hx_n** (Tensor) - Tensor of shape :math:`(num\_directions * num\_layers, batch\_size, hidden\_size)` . + + Raises: + TypeError: If `input_size`, `hidden_size` or `num_layers` is not an int. + TypeError: If `bias`, `batch_first` or `bidirectional` is not a bool. + TypeError: If `dropout` is not a float. + ValueError: If `dropout` is not in range [0.0, 1.0). + ValueError: If `nonlinearity` is not in ['tanh', 'relu']. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore as ms + >>> import numpy as np + >>> net = ms.nn.RNN(10, 16, 2, bias=True, batch_first=True, bidirectional=False) + >>> x = ms.Tensor(np.ones([3, 5, 10]).astype(np.float32)) + >>> h0 = ms.Tensor(np.ones([1 * 2, 3, 16]).astype(np.float32)) + >>> output, hn = net(x, h0) + >>> print(output.shape) + (3, 5, 16) + """ + + def __init__(self, *args, **kwargs): + if 'nonlinearity' in kwargs: + if kwargs['nonlinearity'] == 'tanh': + mode = 'RNN_TANH' + elif kwargs['nonlinearity'] == 'relu': + mode = 'RNN_RELU' + else: + raise ValueError(f"For '{self.cls_name}', the 'nonlinearity' must be in ['tanh', 'relu'], " + f"but got {kwargs['nonlinearity']}.") + del kwargs['nonlinearity'] + else: + mode = 'RNN_TANH' + + super(RNN, self).__init__(mode, *args, **kwargs) + + +class GRU(_RNNBase): + r""" + Stacked GRU (Gated Recurrent Unit) layers. + + Apply GRU layer to the input. + + There are two gates in a GRU model. One is update gate and the other is reset gate. + Denote two consecutive time nodes as :math:`t-1` and :math:`t`. + Given an input :math:`x_t` at time :math:`t`, a hidden state :math:`h_{t-1}`, the update and reset gate at + time :math:`t` is computed using a gating mechanism. Update gate :math:`z_t` is designed to protect the cell + from perturbation by irrelevant inputs and past hidden state. Reset gate :math:`r_t` determines how much + information should be reset from old hidden state. New memory state :math:`n_t` is + calculated with the current input, on which the reset gate will be applied. Finally, current hidden state + :math:`h_{t}` is computed with the calculated update grate and new memory state. The complete + formulation is as follows: + + .. math:: + \begin{array}{ll} + r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\ + z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\ + n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\ + h_t = (1 - z_t) * n_t + z_t * h_{(t-1)} + \end{array} + + Here :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product. :math:`W, b` + are learnable weights between the output and the input in the formula. For instance, + :math:`W_{ir}, b_{ir}` are the weight and bias used to transform from input :math:`x` to :math:`r`. + Details can be found in paper + `Learning Phrase Representations using RNN Encoder–Decoder for Statistical Machine Translation + `_. + + Note: + When using GRU on Ascend, the hidden size only supports multiples of 16. + + Args: + input_size (int): Number of features of input. + hidden_size (int): Number of features of hidden layer. + num_layers (int): Number of layers of stacked GRU. Default: ``1`` . + bias (bool): Whether the cell has bias :math:`b_{in}` and :math:`b_{hn}`. Default: ``True`` . + batch_first (bool): Specifies whether the first dimension of input `x` is batch_size. Default: ``False`` . + dropout (float): If not 0.0, append `Dropout` layer on the outputs of each + GRU layer except the last layer. Default ``0.0`` . The range of dropout is [0.0, 1.0). + bidirectional (bool): Specifies whether it is a bidirectional GRU, + num_directions=2 if bidirectional=True otherwise 1. Default: ``False`` . + dtype (:class:`core.dtype`): Dtype of Parameters. Default: ``mstype.float32`` . + + Inputs: + - **x** (Tensor) - Tensor of data type core.float32 or core.float16 and + shape :math:`(seq\_len, batch\_size, input\_size)` or :math:`(batch\_size, seq\_len, input\_size)`. + - **hx** (Tensor) - Tensor of data type core.float32 or core.float16 and + shape :math:`(num\_directions * num\_layers, batch\_size, hidden\_size)`. + - **seq_length** (Tensor) - The length of each sequence in an input batch. + Tensor of shape :math:`(\text{batch_size})`. Default: ``None`` . + This input indicates the real sequence length before padding to avoid padded elements + have been used to compute hidden state and affect the final output. It is recommended to + use this input when **x** has padding elements. + + Outputs: + Tuple, a tuple contains (`output`, `h_n`). + + - **output** (Tensor) - Tensor of shape :math:`(seq\_len, batch\_size, num\_directions * hidden\_size)` or + :math:`(batch\_size, seq\_len, num\_directions * hidden\_size)`. + - **hx_n** (Tensor) - Tensor of shape :math:`(num\_directions * num\_layers, batch\_size, hidden\_size)`. + + Raises: + TypeError: If `input_size`, `hidden_size` or `num_layers` is not an int. + TypeError: If `bias`, `batch_first` or `bidirectional` is not a bool. + TypeError: If `dropout` is not a float. + ValueError: If `dropout` is not in range [0.0, 1.0). + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore as ms + >>> import numpy as np + >>> net = ms.nn.GRU(10, 16, 2, bias=True, batch_first=True, bidirectional=False) + >>> x = ms.Tensor(np.ones([3, 5, 10]).astype(np.float32)) + >>> h0 = ms.Tensor(np.ones([1 * 2, 3, 16]).astype(np.float32)) + >>> output, hn = net(x, h0) + >>> print(output.shape) + (3, 5, 16) + """ + + def __init__(self, *args, **kwargs): + mode = 'GRU' + super(GRU, self).__init__(mode, *args, **kwargs) + + +class LSTM(_RNNBase): + r""" + Stacked LSTM (Long Short-Term Memory) layers. + + Apply LSTM layer to the input. + + There are two pipelines connecting two consecutive cells in a LSTM model; one is cell state pipeline + and the other is hidden state pipeline. Denote two consecutive time nodes as :math:`t-1` and :math:`t`. + Given an input :math:`x_t` at time :math:`t`, an hidden state :math:`h_{t-1}` and an cell + state :math:`c_{t-1}` of the layer at time :math:`{t-1}`, the cell state and hidden state at + time :math:`t` is computed using an gating mechanism. Input gate :math:`i_t` is designed to protect the cell + from perturbation by irrelevant inputs. Forget gate :math:`f_t` affords protection of the cell by forgetting + some information in the past, which is stored in :math:`h_{t-1}`. Output gate :math:`o_t` protects other + units from perturbation by currently irrelevant memory contents. Candidate cell state :math:`\tilde{c}_t` is + calculated with the current input, on which the input gate will be applied. Finally, current cell state + :math:`c_{t}` and hidden state :math:`h_{t}` are computed with the calculated gates and cell states. The complete + formulation is as follows. + + .. math:: + \begin{array}{ll} \\ + i_t = \sigma(W_{ix} x_t + b_{ix} + W_{ih} h_{(t-1)} + b_{ih}) \\ + f_t = \sigma(W_{fx} x_t + b_{fx} + W_{fh} h_{(t-1)} + b_{fh}) \\ + \tilde{c}_t = \tanh(W_{cx} x_t + b_{cx} + W_{ch} h_{(t-1)} + b_{ch}) \\ + o_t = \sigma(W_{ox} x_t + b_{ox} + W_{oh} h_{(t-1)} + b_{oh}) \\ + c_t = f_t * c_{(t-1)} + i_t * \tilde{c}_t \\ + h_t = o_t * \tanh(c_t) \\ + \end{array} + + Here :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product. :math:`W, b` + are learnable weights between the output and the input in the formula. For instance, + :math:`W_{ix}, b_{ix}` are the weight and bias used to transform from input :math:`x` to :math:`i`. + Details can be found in paper `LONG SHORT-TERM MEMORY + `_ and + `Long Short-Term Memory Recurrent Neural Network Architectures for Large Scale Acoustic Modeling + `_. + + LSTM hides the cycle of the whole cyclic neural network on the time step of the sequence, + and input the sequence and initial state to obtain the matrix spliced by + the hidden state of each time step and the hidden state of the last time step. + We use the hidden state of the last time step as the coding feature of the input sentence and + output it to the next layer. + + .. math:: + h_{0:n},(h_{n}, c_{n}) = LSTM(x_{0:n},(h_{0},c_{0})) + + Args: + input_size (int): Number of features of input. + hidden_size (int): Number of features of hidden layer. + num_layers (int): Number of layers of stacked LSTM . Default: ``1`` . + bias (bool): Whether the cell has bias :math:`b_{ih}` and :math:`b_{fh}`. Default: ``True`` . + batch_first (bool): Specifies whether the first dimension of input `x` is batch_size. Default: ``False`` . + dropout (float, int): If not 0, append `Dropout` layer on the outputs of each + LSTM layer except the last layer. Default ``0`` . The range of dropout is [0.0, 1.0). + bidirectional (bool): Specifies whether it is a bidirectional LSTM, + num_directions=2 if bidirectional=True otherwise 1. Default: ``False`` . + dtype (:class:`core.dtype`): Dtype of Parameters. Default: ``mstype.float32`` . + + Inputs: + - **x** (Tensor) - Tensor of data type core.float32 or core.float16 and + shape :math:`(seq\_len, batch\_size, input\_size)` or :math:`(batch\_size, seq\_len, input\_size)` . + - **hx** (tuple) - A tuple of two Tensors (h_0, c_0) both of data type core.float32 + or core.float16 and shape :math:`(num\_directions * num\_layers, batch\_size, hidden\_size)` . + - **seq_length** (Tensor) - The length of each sequence in an input batch. + Tensor of shape :math:`(batch\_size)`. Default: ``None`` . + This input indicates the real sequence length before padding to avoid padded elements + have been used to compute hidden state and affect the final output. It is recommended to + use this input when **x** has padding elements. + + Outputs: + Tuple, a tuple contains (`output`, (`h_n`, `c_n`)). + + - **output** (Tensor) - Tensor of shape :math:`(seq\_len, batch\_size, num\_directions * hidden\_size)` . + - **hx_n** (tuple) - A tuple of two Tensor (h_n, c_n) both of shape + :math:`(num\_directions * num\_layers, batch\_size, hidden\_size)` . + + Raises: + TypeError: If `input_size`, `hidden_size` or `num_layers` is not an int. + TypeError: If `bias`, `batch_first` or `bidirectional` is not a bool. + TypeError: If `dropout` is not a float. + ValueError: If `dropout` is not in range [0.0, 1.0). + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore as ms + >>> import numpy as np + >>> net = ms.nn.LSTM(10, 16, 2, bias=True, batch_first=True, bidirectional=False) + >>> x = ms.Tensor(np.ones([3, 5, 10]).astype(np.float32)) + >>> h0 = ms.Tensor(np.ones([1 * 2, 3, 16]).astype(np.float32)) + >>> c0 = ms.Tensor(np.ones([1 * 2, 3, 16]).astype(np.float32)) + >>> output, (hn, cn) = net(x, (h0, c0)) + >>> print(output.shape) + (3, 5, 16) + """ + + def __init__(self, *args, **kwargs): + mode = 'LSTM' + super(LSTM, self).__init__(mode, *args, **kwargs) diff --git a/mindnlp/core/nn/modules/rnn_cell.py b/mindnlp/core/nn/modules/rnn_cell.py new file mode 100644 index 000000000..e1f529d1e --- /dev/null +++ b/mindnlp/core/nn/modules/rnn_cell.py @@ -0,0 +1,305 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""RNN Cells module, include RNNCell, GRUCell, LSTMCell.""" +import math + +from mindnlp import core +from ..parameter import Parameter +from .module import Module +from .. import init +from .. import functional as F +from ... import ops + +__all__ = ['LSTMCell', 'GRUCell', 'RNNCell'] + +def _rnn_tanh_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh): + """RNN cell function with tanh activation""" + if b_ih is None: + igates = ops.matmul(inputs, w_ih.T) + hgates = ops.matmul(hidden, w_hh.T) + else: + igates = ops.matmul(inputs, w_ih.T) + b_ih + hgates = ops.matmul(hidden, w_hh.T) + b_hh + return ops.tanh(igates + hgates) + + +def _rnn_relu_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh): + """RNN cell function with relu activation""" + if b_ih is None: + igates = ops.matmul(inputs, w_ih.T) + hgates = ops.matmul(hidden, w_hh.T) + else: + igates = ops.matmul(inputs, w_ih.T) + b_ih + hgates = ops.matmul(hidden, w_hh.T) + b_hh + return F.relu(igates + hgates) + + +def _lstm_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh): + """LSTM cell function""" + hx, cx = hidden + if b_ih is None: + gates = ops.matmul(inputs, w_ih.T) + ops.matmul(hx, w_hh.T) + else: + gates = ops.matmul(inputs, w_ih.T) + ops.matmul(hx, w_hh.T) + b_ih + b_hh + ingate, forgetgate, cellgate, outgate = ops.chunk(gates, 4, 1) + ingate = ops.sigmoid(ingate) + forgetgate = ops.sigmoid(forgetgate) + cellgate = ops.tanh(cellgate) + outgate = ops.sigmoid(outgate) + + cy = (forgetgate * cx) + (ingate * cellgate) + hy = outgate * ops.tanh(cy) + + return hy, cy + + +def _gru_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh): + """GRU cell function""" + if b_ih is None: + gi = ops.matmul(inputs, w_ih.T) + gh = ops.matmul(hidden, w_hh.T) + else: + gi = ops.matmul(inputs, w_ih.T) + b_ih + gh = ops.matmul(hidden, w_hh.T) + b_hh + i_r, i_i, i_n = ops.chunk(gi, 3, 1) + h_r, h_i, h_n = ops.chunk(gh, 3, 1) + + resetgate = ops.sigmoid(i_r + h_r) + inputgate = ops.sigmoid(i_i + h_i) + newgate = ops.tanh(i_n + resetgate * h_n) + hy = newgate + inputgate * (hidden - newgate) + + return hy + + +class RNNCellBase(Module): + """Basic class for RNN Cells""" + def __init__(self, input_size: int, hidden_size: int, bias: bool, num_chunks: int, + dtype=None): + factory_kwargs = {'dtype': dtype} + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.bias = bias + self.weight_ih = Parameter(ops.empty((num_chunks * hidden_size, input_size), **factory_kwargs)) + self.weight_hh = Parameter(ops.empty((num_chunks * hidden_size, hidden_size), **factory_kwargs)) + if bias: + self.bias_ih = Parameter(ops.empty(num_chunks * hidden_size, **factory_kwargs)) + self.bias_hh = Parameter(ops.empty(num_chunks * hidden_size, **factory_kwargs)) + else: + self.bias_ih = None + self.bias_hh = None + self.reset_parameters() + + def reset_parameters(self) -> None: + stdv = 1.0 / math.sqrt(self.hidden_size) if self.hidden_size > 0 else 0 + for weight in self.parameters(): + init.uniform_(weight, -stdv, stdv) + + +class RNNCell(RNNCellBase): + r""" + An Elman RNN cell with tanh or ReLU non-linearity. + + .. math:: + h_t = \tanh(W_{ih} x_t + b_{ih} + W_{hh} h_{(t-1)} + b_{hh}) + + Here :math:`h_t` is the hidden state at time `t`, :math:`x_t` is + the input at time `t`, and :math:`h_{(t-1)}` is the hidden state of the + previous layer at time :math:`t-1` or the initial hidden state at time `0`. + If `nonlinearity` is `relu`, then `relu` is used instead of `tanh`. + + Args: + input_size (int): Number of features of input. + hidden_size (int): Number of features of hidden layer. + bias (bool): Whether the cell has bias :math:`b_{ih}` and :math:`b_{hh}`. Default: ``True`` . + nonlinearity (str): The non-linearity to use. Can be either ``"tanh"`` or ``"relu"`` . + Default: ``"tanh"`` . + dtype (:class:`core.dtype`): Dtype of Parameters. Default: ``mstype.float32`` . + + Inputs: + - **x** (Tensor) - Tensor of shape :math:`(batch\_size, input\_size)` . + - **hx** (Tensor) - Tensor of data type core.float32 and shape :math:`(batch\_size, hidden\_size)` . + + Outputs: + - **hx'** (Tensor) - Tensor of shape :math:`(batch\_size, hidden\_size)` . + + Raises: + TypeError: If `input_size` or `hidden_size` is not an int or not greater than 0. + TypeError: If `bias` is not a bool. + ValueError: If `nonlinearity` is not in ['tanh', 'relu']. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore as ms + >>> import numpy as np + >>> net = ms.nn.RNNCell(10, 16) + >>> x = ms.Tensor(np.ones([5, 3, 10]).astype(np.float32)) + >>> hx = ms.Tensor(np.ones([3, 16]).astype(np.float32)) + >>> output = [] + >>> for i in range(5): + ... hx = net(x[i], hx) + ... output.append(hx) + >>> print(output[0].shape) + (3, 16) + """ + _non_linearity = ['tanh', 'relu'] + + def __init__(self, input_size: int, hidden_size: int, bias: bool = True, nonlinearity: str = "tanh", + dtype=core.float32): + super().__init__(input_size, hidden_size, bias, num_chunks=1, dtype=dtype) + self.nonlinearity = nonlinearity + + def forward(self, x, hx): + if self.nonlinearity == "tanh": + ret = _rnn_tanh_cell(x, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh) + else: + ret = _rnn_relu_cell(x, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh) + return ret + + +class LSTMCell(RNNCellBase): + r""" + A LSTM (Long Short-Term Memory) cell. + + .. math:: + \begin{array}{ll} \\ + i_t = \sigma(W_{ix} x_t + b_{ix} + W_{ih} h_{(t-1)} + b_{ih}) \\ + f_t = \sigma(W_{fx} x_t + b_{fx} + W_{fh} h_{(t-1)} + b_{fh}) \\ + \tilde{c}_t = \tanh(W_{cx} x_t + b_{cx} + W_{ch} h_{(t-1)} + b_{ch}) \\ + o_t = \sigma(W_{ox} x_t + b_{ox} + W_{oh} h_{(t-1)} + b_{oh}) \\ + c_t = f_t * c_{(t-1)} + i_t * \tilde{c}_t \\ + h_t = o_t * \tanh(c_t) \\ + \end{array} + + Here :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product. :math:`W, b` + are learnable weights between the output and the input in the formula. For instance, + :math:`W_{ix}, b_{ix}` are the weight and bias used to transform from input :math:`x` to :math:`i`. + Details can be found in paper `LONG SHORT-TERM MEMORY + `_ and + `Long Short-Term Memory Recurrent Neural Network Architectures for Large Scale Acoustic Modeling + `_. + + The encapsulated LSTMCell can be simplified to the following formula: + + .. math:: + h^{'},c^{'} = LSTMCell(x, (h_0, c_0)) + + Args: + input_size (int): Number of features of input. + hidden_size (int): Number of features of hidden layer. + bias (bool): Whether the cell has bias `b_{ih}` and `b_{hh}`. Default: ``True`` . + dtype (:class:`core.dtype`): Dtype of Parameters. Default: ``mstype.float32`` . + + Inputs: + - **x** (Tensor) - Tensor of shape :math:`(batch\_size, input\_size)` . + - **hx** (tuple) - A tuple of two Tensors (h_0, c_0) both of data type core.float32 + and shape :math:`(batch\_size, hidden\_size)` . + + Outputs: + - **hx'** (Tensor) - A tuple of two Tensors (h', c') both of data shape :math:`(batch\_size, hidden\_size)` . + + Raises: + TypeError: If `input_size`, `hidden_size` is not an int. + TypeError: If `bias` is not a bool. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore as ms + >>> import numpy as np + >>> net = ms.nn.LSTMCell(10, 16) + >>> x = ms.Tensor(np.ones([5, 3, 10]).astype(np.float32)) + >>> h = ms.Tensor(np.ones([3, 16]).astype(np.float32)) + >>> c = ms.Tensor(np.ones([3, 16]).astype(np.float32)) + >>> output = [] + >>> for i in range(5): + ... hx = net(x[i], (h, c)) + ... output.append(hx) + >>> print(output[0][0].shape) + (3, 16) + """ + def __init__(self, input_size: int, hidden_size: int, bias: bool = True, + dtype=core.float32): + super().__init__(input_size, hidden_size, bias, num_chunks=4, dtype=dtype) + self.support_non_tensor_inputs = True + + def forward(self, x, hx): + return _lstm_cell(x, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh) + + +class GRUCell(RNNCellBase): + r""" + A GRU(Gated Recurrent Unit) cell. + + .. math:: + + \begin{array}{ll} + r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\ + z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\ + n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\ + h' = (1 - z) * n + z * h + \end{array} + + Here :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product. :math:`W, b` + are learnable weights between the output and the input in the formula. :math:`h` is hidden state. + :math:`r` is reset gate. :math:`z` is update gate. :math:`n` is n-th layer. For instance, + :math:`W_{ir}, b_{ir}` are the weight and bias used to transform from input :math:`x` to :math:`r`. + Details can be found in paper + `Learning Phrase Representations using RNN Encoder–Decoder for Statistical Machine Translation + `_. + + Args: + input_size (int): Number of features of input. + hidden_size (int): Number of features of hidden layer. + bias (bool): Whether the cell has bias :math:`b_{in}` and :math:`b_{hn}`. Default: ``True`` . + dtype (:class:`core.dtype`): Dtype of Parameters. Default: ``mstype.float32`` . + + Inputs: + - **x** (Tensor) - Tensor of shape :math:`(batch\_size, input\_size)` . + - **hx** (Tensor) - Tensor of data type core.float32 and shape :math:`(batch\_size, hidden\_size)` . + + Outputs: + - **hx'** (Tensor) - Tensor of shape :math:`(batch\_size, hidden\_size)` . + + Raises: + TypeError: If `input_size`, `hidden_size` is not an int. + TypeError: If `bias` is not a bool. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore as ms + >>> import numpy as np + >>> net = ms.nn.GRUCell(10, 16) + >>> x = ms.Tensor(np.ones([5, 3, 10]).astype(np.float32)) + >>> hx = ms.Tensor(np.ones([3, 16]).astype(np.float32)) + >>> output = [] + >>> for i in range(5): + ... hx = net(x[i], hx) + ... output.append(hx) + >>> print(output[0].shape) + (3, 16) + """ + def __init__(self, input_size: int, hidden_size: int, bias: bool = True, + dtype=core.float32): + super().__init__(input_size, hidden_size, bias, num_chunks=3, dtype=dtype) + + def forward(self, x, hx): + return _gru_cell(x, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh) diff --git a/mindnlp/core/nn/modules/sparse.py b/mindnlp/core/nn/modules/sparse.py new file mode 100644 index 000000000..5712d68a4 --- /dev/null +++ b/mindnlp/core/nn/modules/sparse.py @@ -0,0 +1,131 @@ +"""sparse""" +from typing import Optional +from mindnlp.core import Tensor +from ..parameter import Parameter +from .module import Module +from .. import functional as F +from .. import init +from ... import ops + + +class Embedding(Module): + r"""A simple lookup table that stores embeddings of a fixed dictionary and size. + + This module is often used to store word embeddings and retrieve them using indices. + The input to the module is a list of indices, and the output is the corresponding + word embeddings. + """ + + __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', 'max_norm', + 'norm_type', 'scale_grad_by_freq', 'sparse'] + + num_embeddings: int + embedding_dim: int + padding_idx: Optional[int] + max_norm: Optional[float] + norm_type: float + scale_grad_by_freq: bool + weight: Tensor + freeze: bool + sparse: bool + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False, + sparse: bool = False, _weight: Optional[Tensor] = None, _freeze: bool = False, + dtype=None) -> None: + factory_kwargs = {'dtype': dtype} + super().__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + if padding_idx is not None: + if padding_idx > 0: + assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' + elif padding_idx < 0: + assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' + padding_idx = self.num_embeddings + padding_idx + self.padding_idx = padding_idx + self.max_norm = max_norm + self.norm_type = norm_type + self.scale_grad_by_freq = scale_grad_by_freq + if _weight is None: + self.weight = Parameter(ops.empty((num_embeddings, embedding_dim), **factory_kwargs), + requires_grad=not _freeze) + self.reset_parameters() + else: + assert list(_weight.shape) == [num_embeddings, embedding_dim], \ + 'Shape of weight does not match num_embeddings and embedding_dim' + self.weight = Parameter(_weight, requires_grad=not _freeze) + + self.sparse = sparse + + def reset_parameters(self) -> None: + init.normal_(self.weight) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + self.weight[self.padding_idx] = 0 + + def forward(self, input: Tensor) -> Tensor: + return F.embedding( + input, self.weight, self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq) + + def extra_repr(self) -> str: + s = '{num_embeddings}, {embedding_dim}' + if self.padding_idx is not None: + s += ', padding_idx={padding_idx}' + if self.max_norm is not None: + s += ', max_norm={max_norm}' + if self.norm_type != 2: + s += ', norm_type={norm_type}' + if self.scale_grad_by_freq is not False: + s += ', scale_grad_by_freq={scale_grad_by_freq}' + if self.sparse is not False: + s += ', sparse=True' + return s.format(**self.__dict__) + + @classmethod + def from_pretrained(cls, embeddings, freeze=True, padding_idx=None, + max_norm=None, norm_type=2., scale_grad_by_freq=False, + sparse=False): + r"""Create Embedding instance from given 2-dimensional FloatTensor. + + Args: + embeddings (Tensor): FloatTensor containing weights for the Embedding. + First dimension is being passed to Embedding as ``num_embeddings``, second as ``embedding_dim``. + freeze (bool, optional): If ``True``, the tensor does not get updated in the learning process. + Equivalent to ``embedding.weight.requires_grad = False``. Default: ``True`` + padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient; + therefore, the embedding vector at :attr:`padding_idx` is not updated during training, + i.e. it remains as a fixed "pad". + max_norm (float, optional): See module initialization documentation. + norm_type (float, optional): See module initialization documentation. Default ``2``. + scale_grad_by_freq (bool, optional): See module initialization documentation. Default ``False``. + sparse (bool, optional): See module initialization documentation. + + Examples:: + + >>> # FloatTensor containing pretrained weights + >>> weight = core.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]]) + >>> embedding = nn.Embedding.from_pretrained(weight) + >>> # Get embeddings for index 1 + >>> input = core.LongTensor([1]) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> embedding(input) + tensor([[ 4.0000, 5.1000, 6.3000]]) + """ + assert embeddings.dim() == 2, \ + 'Embeddings parameter is expected to be 2-dimensional' + rows, cols = embeddings.shape + embedding = cls( + num_embeddings=rows, + embedding_dim=cols, + _weight=embeddings, + _freeze=freeze, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse) + return embedding diff --git a/mindnlp/core/nn/modules/upsampling.py b/mindnlp/core/nn/modules/upsampling.py new file mode 100644 index 000000000..700b1b5aa --- /dev/null +++ b/mindnlp/core/nn/modules/upsampling.py @@ -0,0 +1,259 @@ +"""upsample""" +from typing import Optional +from mindnlp.core import Tensor + +from .module import Module +from .. import functional as F +from ..common_types import _size_2_t, _ratio_2_t, _size_any_t, _ratio_any_t + +__all__ = ['Upsample', 'UpsamplingNearest2d', 'UpsamplingBilinear2d'] + + + +class Upsample(Module): + r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data. + + The input data is assumed to be of the form + `minibatch x channels x [optional depth] x [optional height] x width`. + Hence, for spatial inputs, we expect a 4D Tensor and for volumetric inputs, we expect a 5D Tensor. + + The algorithms available for upsampling are nearest neighbor and linear, + bilinear, bicubic and trilinear for 3D, 4D and 5D input Tensor, + respectively. + + One can either give a :attr:`scale_factor` or the target output :attr:`size` to + calculate the output size. (You cannot give both, as it is ambiguous) + + Args: + size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int], optional): + output spatial sizes + scale_factor (float or Tuple[float] or Tuple[float, float] or Tuple[float, float, float], optional): + multiplier for spatial size. Has to match input size if it is a tuple. + mode (str, optional): the upsampling algorithm: one of ``'nearest'``, + ``'linear'``, ``'bilinear'``, ``'bicubic'`` and ``'trilinear'``. + Default: ``'nearest'`` + align_corners (bool, optional): if ``True``, the corner pixels of the input + and output tensors are aligned, and thus preserving the values at + those pixels. This only has effect when :attr:`mode` is + ``'linear'``, ``'bilinear'``, ``'bicubic'``, or ``'trilinear'``. + Default: ``False`` + recompute_scale_factor (bool, optional): recompute the scale_factor for use in the + interpolation calculation. If `recompute_scale_factor` is ``True``, then + `scale_factor` must be passed in and `scale_factor` is used to compute the + output `size`. The computed output `size` will be used to infer new scales for + the interpolation. Note that when `scale_factor` is floating-point, it may differ + from the recomputed `scale_factor` due to rounding and precision issues. + If `recompute_scale_factor` is ``False``, then `size` or `scale_factor` will + be used directly for interpolation. + + Shape: + - Input: :math:`(N, C, W_{in})`, :math:`(N, C, H_{in}, W_{in})` or :math:`(N, C, D_{in}, H_{in}, W_{in})` + - Output: :math:`(N, C, W_{out})`, :math:`(N, C, H_{out}, W_{out})` + or :math:`(N, C, D_{out}, H_{out}, W_{out})`, where + + .. math:: + D_{out} = \left\lfloor D_{in} \times \text{scale\_factor} \right\rfloor + + .. math:: + H_{out} = \left\lfloor H_{in} \times \text{scale\_factor} \right\rfloor + + .. math:: + W_{out} = \left\lfloor W_{in} \times \text{scale\_factor} \right\rfloor + + .. warning:: + With ``align_corners = True``, the linearly interpolating modes + (`linear`, `bilinear`, `bicubic`, and `trilinear`) don't proportionally + align the output and input pixels, and thus the output values can depend + on the input size. This was the default behavior for these modes up to + version 0.3.1. Since then, the default behavior is + ``align_corners = False``. See below for concrete examples on how this + affects the outputs. + + .. note:: + If you want downsampling/general resizing, you should use :func:`~nn.functional.interpolate`. + + Examples:: + + >>> input = core.arange(1, 5, dtype=core.float32).view(1, 1, 2, 2) + >>> input + tensor([[[[1., 2.], + [3., 4.]]]]) + + >>> m = nn.Upsample(scale_factor=2, mode='nearest') + >>> m(input) + tensor([[[[1., 1., 2., 2.], + [1., 1., 2., 2.], + [3., 3., 4., 4.], + [3., 3., 4., 4.]]]]) + + >>> # xdoctest: +IGNORE_WANT("other tests seem to modify printing styles") + >>> m = nn.Upsample(scale_factor=2, mode='bilinear') # align_corners=False + >>> m(input) + tensor([[[[1.0000, 1.2500, 1.7500, 2.0000], + [1.5000, 1.7500, 2.2500, 2.5000], + [2.5000, 2.7500, 3.2500, 3.5000], + [3.0000, 3.2500, 3.7500, 4.0000]]]]) + + >>> m = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + >>> m(input) + tensor([[[[1.0000, 1.3333, 1.6667, 2.0000], + [1.6667, 2.0000, 2.3333, 2.6667], + [2.3333, 2.6667, 3.0000, 3.3333], + [3.0000, 3.3333, 3.6667, 4.0000]]]]) + + >>> # Try scaling the same data in a larger tensor + >>> input_3x3 = core.zeros(3, 3).view(1, 1, 3, 3) + >>> input_3x3[:, :, :2, :2].copy_(input) + tensor([[[[1., 2.], + [3., 4.]]]]) + >>> input_3x3 + tensor([[[[1., 2., 0.], + [3., 4., 0.], + [0., 0., 0.]]]]) + + >>> # xdoctest: +IGNORE_WANT("seems to fail when other tests are run in the same session") + >>> m = nn.Upsample(scale_factor=2, mode='bilinear') # align_corners=False + >>> # Notice that values in top left corner are the same with the small input (except at boundary) + >>> m(input_3x3) + tensor([[[[1.0000, 1.2500, 1.7500, 1.5000, 0.5000, 0.0000], + [1.5000, 1.7500, 2.2500, 1.8750, 0.6250, 0.0000], + [2.5000, 2.7500, 3.2500, 2.6250, 0.8750, 0.0000], + [2.2500, 2.4375, 2.8125, 2.2500, 0.7500, 0.0000], + [0.7500, 0.8125, 0.9375, 0.7500, 0.2500, 0.0000], + [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]]) + + >>> m = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + >>> # Notice that values in top left corner are now changed + >>> m(input_3x3) + tensor([[[[1.0000, 1.4000, 1.8000, 1.6000, 0.8000, 0.0000], + [1.8000, 2.2000, 2.6000, 2.2400, 1.1200, 0.0000], + [2.6000, 3.0000, 3.4000, 2.8800, 1.4400, 0.0000], + [2.4000, 2.7200, 3.0400, 2.5600, 1.2800, 0.0000], + [1.2000, 1.3600, 1.5200, 1.2800, 0.6400, 0.0000], + [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]]) + """ + __constants__ = ['size', 'scale_factor', 'mode', 'align_corners', 'name', 'recompute_scale_factor'] + name: str + size: Optional[_size_any_t] + scale_factor: Optional[_ratio_any_t] + mode: str + align_corners: Optional[bool] + recompute_scale_factor: Optional[bool] + + def __init__(self, size: Optional[_size_any_t] = None, scale_factor: Optional[_ratio_any_t] = None, + mode: str = 'nearest', align_corners: Optional[bool] = None, + recompute_scale_factor: Optional[bool] = None) -> None: + super().__init__() + self.name = type(self).__name__ + self.size = size + if isinstance(scale_factor, tuple): + self.scale_factor = tuple(float(factor) for factor in scale_factor) + else: + self.scale_factor = float(scale_factor) if scale_factor else None + self.mode = mode + self.align_corners = align_corners + self.recompute_scale_factor = recompute_scale_factor + + def forward(self, input: Tensor) -> Tensor: + return F.interpolate(input, self.size, self.scale_factor, self.mode, self.align_corners, + recompute_scale_factor=self.recompute_scale_factor) + + def extra_repr(self) -> str: + if self.scale_factor is not None: + info = 'scale_factor=' + repr(self.scale_factor) + else: + info = 'size=' + repr(self.size) + info += ', mode=' + repr(self.mode) + return info + + +class UpsamplingNearest2d(Upsample): + r"""Applies a 2D nearest neighbor upsampling to an input signal composed of several input + channels. + + To specify the scale, it takes either the :attr:`size` or the :attr:`scale_factor` + as it's constructor argument. + + When :attr:`size` is given, it is the output size of the image `(h, w)`. + + Args: + size (int or Tuple[int, int], optional): output spatial sizes + scale_factor (float or Tuple[float, float], optional): multiplier for + spatial size. + + .. warning:: + This class is deprecated in favor of :func:`~nn.functional.interpolate`. + + Shape: + - Input: :math:`(N, C, H_{in}, W_{in})` + - Output: :math:`(N, C, H_{out}, W_{out})` where + + .. math:: + H_{out} = \left\lfloor H_{in} \times \text{scale\_factor} \right\rfloor + + .. math:: + W_{out} = \left\lfloor W_{in} \times \text{scale\_factor} \right\rfloor + + Examples:: + + >>> input = core.arange(1, 5, dtype=core.float32).view(1, 1, 2, 2) + >>> input + tensor([[[[1., 2.], + [3., 4.]]]]) + + >>> m = nn.UpsamplingNearest2d(scale_factor=2) + >>> m(input) + tensor([[[[1., 1., 2., 2.], + [1., 1., 2., 2.], + [3., 3., 4., 4.], + [3., 3., 4., 4.]]]]) + """ + def __init__(self, size: Optional[_size_2_t] = None, scale_factor: Optional[_ratio_2_t] = None) -> None: + super().__init__(size, scale_factor, mode='nearest') + + +class UpsamplingBilinear2d(Upsample): + r"""Applies a 2D bilinear upsampling to an input signal composed of several input + channels. + + To specify the scale, it takes either the :attr:`size` or the :attr:`scale_factor` + as it's constructor argument. + + When :attr:`size` is given, it is the output size of the image `(h, w)`. + + Args: + size (int or Tuple[int, int], optional): output spatial sizes + scale_factor (float or Tuple[float, float], optional): multiplier for + spatial size. + + .. warning:: + This class is deprecated in favor of :func:`~nn.functional.interpolate`. It is + equivalent to ``nn.functional.interpolate(..., mode='bilinear', align_corners=True)``. + + Shape: + - Input: :math:`(N, C, H_{in}, W_{in})` + - Output: :math:`(N, C, H_{out}, W_{out})` where + + .. math:: + H_{out} = \left\lfloor H_{in} \times \text{scale\_factor} \right\rfloor + + .. math:: + W_{out} = \left\lfloor W_{in} \times \text{scale\_factor} \right\rfloor + + Examples:: + + >>> input = core.arange(1, 5, dtype=core.float32).view(1, 1, 2, 2) + >>> input + tensor([[[[1., 2.], + [3., 4.]]]]) + + >>> # xdoctest: +IGNORE_WANT("do other tests modify the global state?") + >>> m = nn.UpsamplingBilinear2d(scale_factor=2) + >>> m(input) + tensor([[[[1.0000, 1.3333, 1.6667, 2.0000], + [1.6667, 2.0000, 2.3333, 2.6667], + [2.3333, 2.6667, 3.0000, 3.3333], + [3.0000, 3.3333, 3.6667, 4.0000]]]]) + """ + def __init__(self, size: Optional[_size_2_t] = None, scale_factor: Optional[_ratio_2_t] = None) -> None: + super().__init__(size, scale_factor, mode='bilinear', align_corners=True) diff --git a/mindnlp/core/nn/modules/utils.py b/mindnlp/core/nn/modules/utils.py new file mode 100644 index 000000000..c3f14433d --- /dev/null +++ b/mindnlp/core/nn/modules/utils.py @@ -0,0 +1,15 @@ +import collections +from itertools import repeat + +def _ntuple(n, name="parse"): + def parse(x): + if isinstance(x, (list, tuple)) and len(x) == 1: + x = x[0] + if isinstance(x, collections.abc.Iterable): + return tuple(x) + return tuple(repeat(x, n)) + + parse.__name__ = name + return parse + +_pair = _ntuple(2, "_pair") diff --git a/mindnlp/core/nn/parallel/__init__.py b/mindnlp/core/nn/parallel/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/mindnlp/core/nn/parallel/distributed.py b/mindnlp/core/nn/parallel/distributed.py new file mode 100644 index 000000000..aefe06666 --- /dev/null +++ b/mindnlp/core/nn/parallel/distributed.py @@ -0,0 +1,4 @@ +from ..modules import Module + +class DistributedDataParallel(Module): + pass \ No newline at end of file diff --git a/mindnlp/core/nn/parameter.py b/mindnlp/core/nn/parameter.py new file mode 100644 index 000000000..1c89e1954 --- /dev/null +++ b/mindnlp/core/nn/parameter.py @@ -0,0 +1,82 @@ +"""new Parameter""" +import uuid +import copy +from mindspore import Tensor +from mindspore._c_expression import ParamInfo # pylint: disable=no-name-in-module +from mindspore.common._stub_tensor import StubTensor + +class Parameter(Tensor): + grad = None + requires_grad = False + + def __init__(self, input_data=None, requires_grad=True, **kwargs): + super().__init__(input_data) + self.meta = False + self.param_info = ParamInfo() + self.param_info.name = str(uuid.uuid4()) + self.param_info.parameter_shape = self._shape + self.param_info.requires_grad = requires_grad + self._requires_grad = requires_grad + if self._requires_grad: + self.retain_grad() + + def __deepcopy__(self, memodict): + new_obj = Parameter(self) + return new_obj + + def clone(self): + return copy.deepcopy(self) + + def __parameter__(self): # only for O2 + """For parse check.""" + + @property + def name(self): # only for O2 + """ + Get the name of the parameter. + + Examples: + >>> from mindspore import Tensor, Parameter + >>> import numpy as np + >>> x = Parameter(Tensor(np.array([1, 2], dtype=np.float32)), name="param") + >>> x.name = "param1" + >>> x.name + 'param1' + """ + return self.param_info.name + + @property + def data(self): + return Tensor(self) + + @data.setter + def data(self, new_value): + if isinstance(new_value, StubTensor): + new_value = new_value.stub.get_value() + self.assign_value(new_value) + + @property + def requires_grad(self): + return self._requires_grad + + @requires_grad.setter + def requires_grad(self, value): + if not isinstance(value, bool): + raise TypeError("The 'requires_grad' attribute of parameter must be set as bool.") + self.param_info.requires_grad = value + self._requires_grad = value + if value: + if not hasattr(self, 'handle'): + self.retain_grad() + else: + if hasattr(self, 'handle'): + self.handle.remove() + delattr(self, 'handle') + + +class UninitializedParameter(Parameter): + def __init__(self, input_data=None, requires_grad=True): + super().__init__(input_data, requires_grad) + +def is_lazy(param): + return False diff --git a/mindnlp/core/nn/utils/__init__.py b/mindnlp/core/nn/utils/__init__.py new file mode 100644 index 000000000..719e70363 --- /dev/null +++ b/mindnlp/core/nn/utils/__init__.py @@ -0,0 +1,4 @@ +"""utils""" +from . import parametrizations +from .weight_norm import * +from .clip_grad import * diff --git a/mindnlp/core/nn/utils/clip_grad.py b/mindnlp/core/nn/utils/clip_grad.py new file mode 100644 index 000000000..709ea8bca --- /dev/null +++ b/mindnlp/core/nn/utils/clip_grad.py @@ -0,0 +1,122 @@ +"""clip grad""" +# mypy: allow-untyped-defs +import functools +from typing import Union, Iterable, Optional +from typing_extensions import deprecated + +from mindnlp import core +from ... import ops +from ...autograd import no_grad + +_tensor_or_tensors = Union[core.Tensor, Iterable[core.Tensor]] + +__all__ = ['clip_grad_norm_', 'clip_grad_norm', 'clip_grad_value_'] + +inf = float('inf') + +def _no_grad(func): + """ + This wrapper is needed to avoid a circular import when using @no_grad on the exposed functions + clip_grad_norm_ and clip_grad_value_ themselves. + """ + def _no_grad_wrapper(*args, **kwargs): + with no_grad(): + return func(*args, **kwargs) + functools.update_wrapper(_no_grad_wrapper, func) + return _no_grad_wrapper + + +@_no_grad +def clip_grad_norm_( + parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0, + error_if_nonfinite: bool = False, foreach: Optional[bool] = None) -> core.Tensor: + r"""Clip the gradient norm of an iterable of parameters. + + The norm is computed over all gradients together, as if they were + concatenated into a single vector. Gradients are modified in-place. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float): max norm of the gradients + norm_type (float): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + error_if_nonfinite (bool): if True, an error is thrown if the total + norm of the gradients from :attr:`parameters` is ``nan``, + ``inf``, or ``-inf``. Default: False (will switch to True in the future) + foreach (bool): use the faster foreach-based implementation. + If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently + fall back to the slow implementation for other device types. + Default: ``None`` + + Returns: + Total norm of the parameter gradients (viewed as a single vector). + """ + if isinstance(parameters, core.Tensor): + parameters = [parameters] + grads = [p.grad for p in parameters if p.grad is not None] + max_norm = float(max_norm) + norm_type = float(norm_type) + if len(grads) == 0: + return core.tensor(0.) + if norm_type == inf: + norms = [g.abs().max() for g in grads] + total_norm = norms[0] if len(norms) == 1 else ops.max(ops.stack(norms)) + else: + total_norm = ops.norm(ops.stack([ops.norm(g, norm_type) for g in grads]), norm_type) + if error_if_nonfinite and ops.logical_or(total_norm.isnan(), total_norm.isinf()): + raise RuntimeError( + f'The total norm of order {norm_type} for gradients from ' + '`parameters` is non-finite, so it cannot be clipped. To disable ' + 'this error and scale the gradients by the non-finite norm anyway, ' + 'set `error_if_nonfinite=False`') + clip_coef = max_norm / (total_norm + 1e-6) + # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so + # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization + # when the gradients do not reside in CPU memory. + clip_coef_clamped = ops.clamp(clip_coef, max=1.0) + for g in grads: + ops.assign(g, ops.mul(g, clip_coef_clamped)) + return total_norm + + + +@deprecated( + "`nn.utils.clip_grad_norm` is now deprecated " + "in favor of `nn.utils.clip_grad_norm_`.", + category=FutureWarning, +) +def clip_grad_norm( + parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2., + error_if_nonfinite: bool = False, foreach: Optional[bool] = None) -> core.Tensor: + r"""Clip the gradient norm of an iterable of parameters. + + .. warning:: + This method is now deprecated in favor of + :func:`nn.utils.clip_grad_norm_`. + """ + return clip_grad_norm_(parameters, max_norm, norm_type, error_if_nonfinite, foreach) + + + + +@_no_grad +def clip_grad_value_(gradients: _tensor_or_tensors, clip_value: float, foreach: Optional[bool] = None) -> None: + r"""Clip the gradients of an iterable of parameters at specified value. + + Gradients are modified in-place. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + clip_value (float): maximum allowed value of the gradients. + The gradients are clipped in the range + :math:`\left[\text{-clip\_value}, \text{clip\_value}\right]` + foreach (bool): use the faster foreach-based implementation + If ``None``, use the foreach implementation for CUDA and CPU native tensors and + silently fall back to the slow implementation for other device types. + Default: ``None`` + """ + clip_value = float(clip_value) + for grad in gradients: + ops.assign(grad, ops.clamp(grad, -clip_value, clip_value)) diff --git a/mindnlp/core/nn/utils/parametrizations.py b/mindnlp/core/nn/utils/parametrizations.py new file mode 100644 index 000000000..c36ad466e --- /dev/null +++ b/mindnlp/core/nn/utils/parametrizations.py @@ -0,0 +1,628 @@ +"""parametrizations""" +# mypy: allow-untyped-defs +from enum import auto, Enum +from typing import Optional + +from mindnlp import core +from mindnlp.core import Tensor +from .. import functional as F +from ...nn.modules import Module +from ... import ops, nn +from ...autograd import no_grad +from . import parametrize +from .weight_norm import _weight_norm, norm_except_dim + +__all__ = ["orthogonal", "spectral_norm", "weight_norm"] + + +def _is_orthogonal(Q, eps=None): + n, k = Q.size(-2), Q.size(-1) + Id = ops.eye(k, dtype=Q.dtype) + # A reasonable eps, but not too large + eps = 10.0 * n * float(ops.finfo(Q.dtype).eps) + return ops.allclose(Q.mH @ Q, Id, atol=eps) + + +def _make_orthogonal(A): + """Assume that A is a tall matrix. + + Compute the Q factor s.t. A = QR (A may be complex) and diag(R) is real and non-negative. + """ + X, tau = ops.geqrf(A) + Q = ops.linalg.householder_product(X, tau) + # The diagonal of X is the diagonal of R (which is always real) so we normalise by its signs + Q *= X.diagonal(dim1=-2, dim2=-1).sgn().unsqueeze(-2) + return Q + + +class _OrthMaps(Enum): + matrix_exp = auto() + cayley = auto() + householder = auto() + + +class _Orthogonal(Module): + base: Tensor + + def __init__( + self, weight, orthogonal_map: _OrthMaps, *, use_trivialization=True + ) -> None: + super().__init__() + + # Note [Householder complex] + # For complex tensors, it is not possible to compute the tensor `tau` necessary for + # linalg.householder_product from the reflectors. + # To see this, note that the reflectors have a shape like: + # 0 0 0 + # * 0 0 + # * * 0 + # which, for complex matrices, give n(n-1) (real) parameters. Now, you need n^2 parameters + # to parametrize the unitary matrices. Saving tau on its own does not work either, because + # not every combination of `(A, tau)` gives a unitary matrix, meaning that if we optimise + # them as independent tensors we would not maintain the constraint + # An equivalent reasoning holds for rectangular matrices + if weight.is_complex() and orthogonal_map == _OrthMaps.householder: + raise ValueError( + "The householder parametrization does not support complex tensors." + ) + + self.shape = weight.shape + self.orthogonal_map = orthogonal_map + if use_trivialization: + self.register_buffer("base", None) + + def forward(self, X: core.Tensor) -> core.Tensor: + n, k = X.size(-2), X.size(-1) + transposed = n < k + if transposed: + X = X.mT + n, k = k, n + # Here n > k and X is a tall matrix + if self.orthogonal_map in (_OrthMaps.matrix_exp, _OrthMaps.cayley): + # We just need n x k - k(k-1)/2 parameters + X = X.tril() + if n != k: + # Embed into a square matrix + X = ops.cat( + [X, X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1 + ) + A = X - X.mH + # A is skew-symmetric (or skew-hermitian) + if self.orthogonal_map == _OrthMaps.matrix_exp: + Q = ops.matrix_exp(A) + elif self.orthogonal_map == _OrthMaps.cayley: + # Computes the Cayley retraction (I+A/2)(I-A/2)^{-1} + Id = ops.eye(n, dtype=A.dtype) + Q = ops.linalg.solve( + ops.add(Id, A, alpha=-0.5), ops.add(Id, A, alpha=0.5) + ) + # Q is now orthogonal (or unitary) of size (..., n, n) + if n != k: + Q = Q[..., :k] + # Q is now the size of the X (albeit perhaps transposed) + else: + # X is real here, as we do not support householder with complex numbers + A = X.tril(diagonal=-1) + tau = 2.0 / (1.0 + (A * A).sum(dim=-2)) + Q = ops.linalg.householder_product(A, tau) + # The diagonal of X is 1's and -1's + # We do not want to differentiate through this or update the diagonal of X hence the casting + Q = Q * X.diagonal(dim1=-2, dim2=-1).int().unsqueeze(-2) + + if hasattr(self, "base"): + Q = self.base @ Q + if transposed: + Q = Q.mT + return Q # type: ignore[possibly-undefined] + + @no_grad() + def right_inverse(self, Q: core.Tensor) -> core.Tensor: + if Q.shape != self.shape: + raise ValueError( + f"Expected a matrix or batch of matrices of shape {self.shape}. " + f"Got a tensor of shape {Q.shape}." + ) + + Q_init = Q + n, k = Q.size(-2), Q.size(-1) + transpose = n < k + if transpose: + Q = Q.mT + n, k = k, n + + # We always make sure to always copy Q in every path + if not hasattr(self, "base"): + # Note [right_inverse expm cayley] + # If we do not have use_trivialization=True, we just implement the inverse of the forward + # map for the Householder. To see why, think that for the Cayley map, + # we would need to find the matrix X \in R^{n x k} such that: + # Y = core.cat([X.tril(), X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1) + # A = Y - Y.mH + # cayley(A)[:, :k] + # gives the original tensor. It is not clear how to do this. + # Perhaps via some algebraic manipulation involving the QR like that of + # Corollary 2.2 in Edelman, Arias and Smith? + if self.orthogonal_map in (_OrthMaps.cayley, _OrthMaps.matrix_exp): + raise NotImplementedError( + "It is not possible to assign to the matrix exponential " + "or the Cayley parametrizations when use_trivialization=False." + ) + + # If parametrization == _OrthMaps.householder, make Q orthogonal via the QR decomposition. + # Here Q is always real because we do not support householder and complex matrices. + # See note [Householder complex] + A, tau = ops.geqrf(Q) + # We want to have a decomposition X = QR with diag(R) > 0, as otherwise we could + # decompose an orthogonal matrix Q as Q = (-Q)@(-Id), which is a valid QR decomposition + # The diagonal of Q is the diagonal of R from the qr decomposition + A.diagonal(dim1=-2, dim2=-1).sign_() + # Equality with zero is ok because LAPACK returns exactly zero when it does not want + # to use a particular reflection + A.diagonal(dim1=-2, dim2=-1)[tau == 0.0] *= -1 + return A.mT if transpose else A + else: + if n == k: + # We check whether Q is orthogonal + if not _is_orthogonal(Q): + Q = _make_orthogonal(Q) + else: # Is orthogonal + Q = Q.clone() + else: + # Complete Q into a full n x n orthogonal matrix + N = ops.randn( + *(Q.size()[:-2] + (n, n - k)), dtype=Q.dtype + ) + Q = ops.cat([Q, N], dim=-1) + Q = _make_orthogonal(Q) + self.base = Q + + # It is necessary to return the -Id, as we use the diagonal for the + # Householder parametrization. Using -Id makes: + # householder(core.zeros(m,n)) == core.eye(m,n) + # Poor man's version of eye_like + neg_Id = ops.zeros_like(Q_init) + neg_Id = neg_Id.diagonal(dim1=-2, dim2=-1).fill(-1.0) + return neg_Id + + + +def orthogonal( + module: Module, + name: str = "weight", + orthogonal_map: Optional[str] = None, + *, + use_trivialization: bool = True, +) -> Module: + r"""Apply an orthogonal or unitary parametrization to a matrix or a batch of matrices. + + Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, the parametrized + matrix :math:`Q \in \mathbb{K}^{m \times n}` is **orthogonal** as + + .. math:: + + \begin{align*} + Q^{\text{H}}Q &= \mathrm{I}_n \mathrlap{\qquad \text{if }m \geq n}\\ + QQ^{\text{H}} &= \mathrm{I}_m \mathrlap{\qquad \text{if }m < n} + \end{align*} + + where :math:`Q^{\text{H}}` is the conjugate transpose when :math:`Q` is complex + and the transpose when :math:`Q` is real-valued, and + :math:`\mathrm{I}_n` is the `n`-dimensional identity matrix. + In plain words, :math:`Q` will have orthonormal columns whenever :math:`m \geq n` + and orthonormal rows otherwise. + + If the tensor has more than two dimensions, we consider it as a batch of matrices of shape `(..., m, n)`. + + The matrix :math:`Q` may be parametrized via three different ``orthogonal_map`` in terms of the original tensor: + + - ``"matrix_exp"``/``"cayley"``: + the :func:`~core.matrix_exp` :math:`Q = \exp(A)` and the `Cayley map`_ + :math:`Q = (\mathrm{I}_n + A/2)(\mathrm{I}_n - A/2)^{-1}` are applied to a skew-symmetric + :math:`A` to give an orthogonal matrix. + - ``"householder"``: computes a product of Householder reflectors + (:func:`~core.linalg.householder_product`). + + ``"matrix_exp"``/``"cayley"`` often make the parametrized weight converge faster than + ``"householder"``, but they are slower to compute for very thin or very wide matrices. + + If ``use_trivialization=True`` (default), the parametrization implements the "Dynamic Trivialization Framework", + where an extra matrix :math:`B \in \mathbb{K}^{n \times n}` is stored under + ``module.parametrizations.weight[0].base``. This helps the + convergence of the parametrized layer at the expense of some extra memory use. + See `Trivializations for Gradient-Based Optimization on Manifolds`_ . + + Initial value of :math:`Q`: + If the original tensor is not parametrized and ``use_trivialization=True`` (default), the initial value + of :math:`Q` is that of the original tensor if it is orthogonal (or unitary in the complex case) + and it is orthogonalized via the QR decomposition otherwise (see :func:`core.linalg.qr`). + Same happens when it is not parametrized and ``orthogonal_map="householder"`` even when ``use_trivialization=False``. + Otherwise, the initial value is the result of the composition of all the registered + parametrizations applied to the original tensor. + + .. note:: + This function is implemented using the parametrization functionality + in :func:`~core.nn.utils.parametrize.register_parametrization`. + + + .. _`Cayley map`: https://en.wikipedia.org/wiki/Cayley_transform#Matrix_map + .. _`Trivializations for Gradient-Based Optimization on Manifolds`: https://arxiv.org/abs/1909.09501 + + Args: + module (nn.Module): module on which to register the parametrization. + name (str, optional): name of the tensor to make orthogonal. Default: ``"weight"``. + orthogonal_map (str, optional): One of the following: ``"matrix_exp"``, ``"cayley"``, ``"householder"``. + Default: ``"matrix_exp"`` if the matrix is square or complex, ``"householder"`` otherwise. + use_trivialization (bool, optional): whether to use the dynamic trivialization framework. + Default: ``True``. + + Returns: + The original module with an orthogonal parametrization registered to the specified + weight + + Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK) + >>> orth_linear = orthogonal(nn.Linear(20, 40)) + >>> orth_linear + ParametrizedLinear( + in_features=20, out_features=40, bias=True + (parametrizations): ModuleDict( + (weight): ParametrizationList( + (0): _Orthogonal() + ) + ) + ) + >>> # xdoctest: +IGNORE_WANT + >>> Q = orth_linear.weight + >>> core.dist(Q.T @ Q, core.eye(20)) + tensor(4.9332e-07) + """ + weight = getattr(module, name, None) + if not isinstance(weight, Tensor): + raise ValueError( + f"Module '{module}' has no parameter or buffer with name '{name}'" + ) + + # We could implement this for 1-dim tensors as the maps on the sphere + # but I believe it'd bite more people than it'd help + if weight.ndim < 2: + raise ValueError( + "Expected a matrix or batch of matrices. " + f"Got a tensor of {weight.ndim} dimensions." + ) + + if orthogonal_map is None: + orthogonal_map = ( + "matrix_exp" + if weight.size(-2) == weight.size(-1) or weight.is_complex() + else "householder" + ) + + orth_enum = getattr(_OrthMaps, orthogonal_map, None) + if orth_enum is None: + raise ValueError( + 'orthogonal_map has to be one of "matrix_exp", "cayley", "householder". ' + f"Got: {orthogonal_map}" + ) + orth = _Orthogonal(weight, orth_enum, use_trivialization=use_trivialization) + parametrize.register_parametrization(module, name, orth, unsafe=True) + return module + + + +class _WeightNorm(Module): + def __init__( + self, + dim: Optional[int] = 0, + ) -> None: + super().__init__() + if dim is None: + dim = -1 + self.dim = dim + + def forward(self, weight_g, weight_v): + return _weight_norm(weight_v, weight_g, self.dim) + + def right_inverse(self, weight): + weight_g = norm_except_dim(weight, 2, self.dim) + weight_v = weight + + return weight_g, weight_v + + + +def weight_norm(module: Module, name: str = "weight", dim: int = 0): + r"""Apply weight normalization to a parameter in the given module. + + .. math:: + \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|} + + Weight normalization is a reparameterization that decouples the magnitude + of a weight tensor from its direction. This replaces the parameter specified + by :attr:`name` with two parameters: one specifying the magnitude + and one specifying the direction. + + By default, with ``dim=0``, the norm is computed independently per output + channel/plane. To compute a norm over the entire weight tensor, use + ``dim=None``. + + See https://arxiv.org/abs/1602.07868 + + Args: + module (Module): containing module + name (str, optional): name of weight parameter + dim (int, optional): dimension over which to compute the norm + + Returns: + The original module with the weight norm hook + + Example:: + + >>> m = weight_norm(nn.Linear(20, 40), name='weight') + >>> m + ParametrizedLinear( + in_features=20, out_features=40, bias=True + (parametrizations): ModuleDict( + (weight): ParametrizationList( + (0): _WeightNorm() + ) + ) + ) + >>> m.parametrizations.weight.original0.size() + core.Size([40, 1]) + >>> m.parametrizations.weight.original1.size() + core.Size([40, 20]) + + """ + _weight_norm = _WeightNorm(dim) + parametrize.register_parametrization(module, name, _weight_norm, unsafe=True) + + def _weight_norm_compat_hook( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + g_key = f"{prefix}{name}_g" + v_key = f"{prefix}{name}_v" + if g_key in state_dict and v_key in state_dict: + original0 = state_dict.pop(g_key) + original1 = state_dict.pop(v_key) + state_dict[f"{prefix}parametrizations.{name}.original0"] = original0 + state_dict[f"{prefix}parametrizations.{name}.original1"] = original1 + + module._register_load_state_dict_pre_hook(_weight_norm_compat_hook) + return module + + + +class _SpectralNorm(Module): + def __init__( + self, + weight: core.Tensor, + n_power_iterations: int = 1, + dim: int = 0, + eps: float = 1e-12, + ) -> None: + super().__init__() + ndim = weight.ndim + if dim >= ndim or dim < -ndim: + raise IndexError( + "Dimension out of range (expected to be in range of " + f"[-{ndim}, {ndim - 1}] but got {dim})" + ) + + if n_power_iterations <= 0: + raise ValueError( + "Expected n_power_iterations to be positive, but " + f"got n_power_iterations={n_power_iterations}" + ) + self.dim = dim if dim >= 0 else dim + ndim + self.eps = eps + if ndim > 1: + # For ndim == 1 we do not need to approximate anything (see _SpectralNorm.forward) + self.n_power_iterations = n_power_iterations + weight_mat = self._reshape_weight_to_matrix(weight) + h, w = weight_mat.size() + + u = weight_mat.new_empty(h).normal_(0, 1) + v = weight_mat.new_empty(w).normal_(0, 1) + self.register_buffer("_u", F.normalize(u, dim=0, eps=self.eps)) + self.register_buffer("_v", F.normalize(v, dim=0, eps=self.eps)) + + # Start with u, v initialized to some reasonable values by performing a number + # of iterations of the power method + self._power_method(weight_mat, 15) + + def _reshape_weight_to_matrix(self, weight: core.Tensor) -> core.Tensor: + # Precondition + assert weight.ndim > 1 + + if self.dim != 0: + # permute dim to front + weight = weight.permute( + self.dim, *(d for d in range(weight.dim()) if d != self.dim) + ) + + return weight.flatten(1) + + @no_grad() + def _power_method(self, weight_mat: core.Tensor, n_power_iterations: int) -> None: + # See original note at torch/nn/utils/spectral_norm.py + # NB: If `do_power_iteration` is set, the `u` and `v` vectors are + # updated in power iteration **in-place**. This is very important + # because in `DataParallel` forward, the vectors (being buffers) are + # broadcast from the parallelized module to each module replica, + # which is a new module object created on the fly. And each replica + # runs its own spectral norm power iteration. So simply assigning + # the updated vectors to the module this function runs on will cause + # the update to be lost forever. And the next time the parallelized + # module is replicated, the same randomly initialized vectors are + # broadcast and used! + # + # Therefore, to make the change propagate back, we rely on two + # important behaviors (also enforced via tests): + # 1. `DataParallel` doesn't clone storage if the broadcast tensor + # is already on correct device; and it makes sure that the + # parallelized module is already on `device[0]`. + # 2. If the out tensor in `out=` kwarg has correct shape, it will + # just fill in the values. + # Therefore, since the same power iteration is performed on all + # devices, simply updating the tensors in-place will make sure that + # the module replica on `device[0]` will update the _u vector on the + # parallelized module (by shared storage). + # + # However, after we update `u` and `v` in-place, we need to **clone** + # them before using them to normalize the weight. This is to support + # backproping through two forward passes, e.g., the common pattern in + # GAN training: loss = D(real) - D(fake). Otherwise, engine will + # complain that variables needed to do backward for the first forward + # (i.e., the `u` and `v` vectors) are changed in the second forward. + + # Precondition + assert weight_mat.ndim > 1 + + for _ in range(n_power_iterations): + # Spectral norm of weight equals to `u^T W v`, where `u` and `v` + # are the first left and right singular vectors. + # This power iteration produces approximations of `u` and `v`. + self._u = F.normalize( + ops.mv(weight_mat, self._v), # pylint: disable=access-member-before-definition + dim=0, + eps=self.eps, + ) + self._v = F.normalize( + ops.mv(weight_mat.H, self._u), # type: ignore[has-type] + dim=0, + eps=self.eps, + ) + + def forward(self, weight: core.Tensor) -> core.Tensor: + if weight.ndim == 1: + # Faster and more exact path, no need to approximate anything + return F.normalize(weight, dim=0, eps=self.eps) + else: + weight_mat = self._reshape_weight_to_matrix(weight) + if self.training: + self._power_method(weight_mat, self.n_power_iterations) + # See above on why we need to clone + u = self._u.copy() + v = self._v.copy() + # The proper way of computing this should be through F.bilinear, but + # it seems to have some efficiency issues: + # https://github.com/pytorch/pytorch/issues/58093 + sigma = ops.vdot(u, ops.mv(weight_mat, v)) + return weight / sigma + + def right_inverse(self, value: core.Tensor) -> core.Tensor: + # we may want to assert here that the passed value already + # satisfies constraints + return value + + + +def spectral_norm( + module: Module, + name: str = "weight", + n_power_iterations: int = 1, + eps: float = 1e-12, + dim: Optional[int] = None, +) -> Module: + r"""Apply spectral normalization to a parameter in the given module. + + .. math:: + \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})}, + \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2} + + When applied on a vector, it simplifies to + + .. math:: + \mathbf{x}_{SN} = \dfrac{\mathbf{x}}{\|\mathbf{x}\|_2} + + Spectral normalization stabilizes the training of discriminators (critics) + in Generative Adversarial Networks (GANs) by reducing the Lipschitz constant + of the model. :math:`\sigma` is approximated performing one iteration of the + `power method`_ every time the weight is accessed. If the dimension of the + weight tensor is greater than 2, it is reshaped to 2D in power iteration + method to get spectral norm. + + + See `Spectral Normalization for Generative Adversarial Networks`_ . + + .. _`power method`: https://en.wikipedia.org/wiki/Power_iteration + .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957 + + .. note:: + This function is implemented using the parametrization functionality + in :func:`~core.nn.utils.parametrize.register_parametrization`. It is a + reimplementation of :func:`core.nn.utils.spectral_norm`. + + .. note:: + When this constraint is registered, the singular vectors associated to the largest + singular value are estimated rather than sampled at random. These are then updated + performing :attr:`n_power_iterations` of the `power method`_ whenever the tensor + is accessed with the module on `training` mode. + + .. note:: + If the `_SpectralNorm` module, i.e., `module.parametrization.weight[idx]`, + is in training mode on removal, it will perform another power iteration. + If you'd like to avoid this iteration, set the module to eval mode + before its removal. + + Args: + module (nn.Module): containing module + name (str, optional): name of weight parameter. Default: ``"weight"``. + n_power_iterations (int, optional): number of power iterations to + calculate spectral norm. Default: ``1``. + eps (float, optional): epsilon for numerical stability in + calculating norms. Default: ``1e-12``. + dim (int, optional): dimension corresponding to number of outputs. + Default: ``0``, except for modules that are instances of + ConvTranspose{1,2,3}d, when it is ``1`` + + Returns: + The original module with a new parametrization registered to the specified + weight + + Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> snm = spectral_norm(nn.Linear(20, 40)) + >>> snm + ParametrizedLinear( + in_features=20, out_features=40, bias=True + (parametrizations): ModuleDict( + (weight): ParametrizationList( + (0): _SpectralNorm() + ) + ) + ) + >>> core.linalg.matrix_norm(snm.weight, 2) + tensor(1.0081, grad_fn=) + """ + weight = getattr(module, name, None) + if not isinstance(weight, Tensor): + raise ValueError( + f"Module '{module}' has no parameter or buffer with name '{name}'" + ) + + if dim is None: + if isinstance( + module, + ( + nn.ConvTranspose1d, + nn.ConvTranspose2d, + nn.ConvTranspose3d, + ), + ): + dim = 1 + else: + dim = 0 + parametrize.register_parametrization( + module, name, _SpectralNorm(weight, n_power_iterations, dim, eps) + ) + return module diff --git a/mindnlp/core/nn/utils/parametrize.py b/mindnlp/core/nn/utils/parametrize.py new file mode 100644 index 000000000..4b5a44a72 --- /dev/null +++ b/mindnlp/core/nn/utils/parametrize.py @@ -0,0 +1,791 @@ +"""parametrize""" +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import collections +import copyreg +from contextlib import contextmanager +from copy import deepcopy +from typing import Dict, Optional, Sequence, Tuple, Union + +from mindnlp import core +from mindnlp.core import Tensor +from ...nn.modules.container import Module, ModuleDict, ModuleList +from ...nn.parameter import Parameter +from ...autograd import no_grad + + +__all__ = [ + "cached", + "ParametrizationList", + "register_parametrization", + "is_parametrized", + "remove_parametrizations", + "type_before_parametrizations", + "transfer_parametrizations_and_params", +] + +_cache_enabled = 0 +_cache: Dict[Tuple[int, str], Optional[Tensor]] = {} + + +@contextmanager +def cached(): + r"""Context manager that enables the caching system within parametrizations registered with :func:`register_parametrization`. + + The value of the parametrized objects is computed and cached the first time + they are required when this context manager is active. The cached values are + discarded when leaving the context manager. + + This is useful when using a parametrized parameter more than once in the forward pass. + An example of this is when parametrizing the recurrent kernel of an RNN or when + sharing weights. + + The simplest way to activate the cache is by wrapping the forward pass of the neural network + + .. code-block:: python + + from mindnlp import core.nn.utils.parametrize as P + ... + with P.cached(): + output = model(inputs) + + in training and evaluation. One may also wrap the parts of the modules that use + several times the parametrized tensors. For example, the loop of an RNN with a + parametrized recurrent kernel: + + .. code-block:: python + + with P.cached(): + for x in xs: + out_rnn = self.rnn_cell(x, out_rnn) + """ + global _cache + global _cache_enabled + _cache_enabled += 1 + try: + yield + finally: + _cache_enabled -= 1 + if not _cache_enabled: + _cache = {} + + +def _register_parameter_or_buffer(module, name, X): + if isinstance(X, Parameter): + module.register_parameter(name, X) + else: + module.register_buffer(name, X) + + +def _maybe_set(dest: Tensor, src: Tensor) -> None: + dest.assign_value(src) # type: ignore[call-overload] + + +class ParametrizationList(ModuleList): + r"""A sequential container that holds and manages the original parameters or buffers of a parametrized :class:`core.nn.Module`. + + It is the type of ``module.parametrizations[tensor_name]`` when ``module[tensor_name]`` + has been parametrized with :func:`register_parametrization`. + + If the first registered parametrization has a ``right_inverse`` that returns one tensor or + does not have a ``right_inverse`` (in which case we assume that ``right_inverse`` is the identity), + it will hold the tensor under the name ``original``. + If it has a ``right_inverse`` that returns more than one tensor, these will be registered as + ``original0``, ``original1``, ... + + .. warning:: + This class is used internally by :func:`register_parametrization`. It is documented + here for completeness. It shall not be instantiated by the user. + + Args: + modules (sequence): sequence of modules representing the parametrizations + original (Parameter or Tensor): parameter or buffer that is parametrized + unsafe (bool): a boolean flag that denotes whether the parametrization + may change the dtype and shape of the tensor. Default: `False` + Warning: the parametrization is not checked for consistency upon registration. + Enable this flag at your own risk. + """ + + original: Tensor + unsafe: bool + + def __init__( + self, + modules: Sequence[Module], + original: Union[Tensor, Parameter], + unsafe: bool = False, + ) -> None: + # We require this because we need to treat differently the first parametrization + # This should never throw, unless this class is used from the outside + if len(modules) == 0: + raise ValueError("ParametrizationList requires one or more modules.") + + super().__init__(modules) + self.unsafe = unsafe + + # In plain words: + # module.weight must keep its dtype and shape. + # Furthermore, if there is no right_inverse or the right_inverse returns a tensor, + # this should be of the same dtype as the original tensor + # + # We check that the following invariants hold: + # X = module.weight + # Y = param.right_inverse(X) + # assert isinstance(Y, Tensor) or + # (isinstance(Y, collections.abc.Sequence) and all(isinstance(t, Tensor) for t in Y)) + # Z = param(Y) if isinstance(Y, Tensor) else param(*Y) + # # Consistency checks + # assert X.dtype == Z.dtype and X.shape == Z.shape + # # If it has one input, this allows to be able to use set_ to be able to + # # move data to/from the original tensor without changing its id (which is what the + # # optimizer uses to track parameters) + # if isinstance(Y, Tensor) + # assert X.dtype == Y.dtype + # Below we use original = X, new = Y + + original_shape = original.shape + original_dtype = original.dtype + + # Compute new + with no_grad(): + new = original + for module in reversed(self): # type: ignore[call-overload] + if hasattr(module, "right_inverse"): + try: + new = module.right_inverse(new) + except NotImplementedError: + pass + # else, or if it throws, we assume that right_inverse is the identity + + if not isinstance(new, Tensor) and not isinstance( + new, collections.abc.Sequence + ): + raise ValueError( + "'right_inverse' must return a Tensor or a Sequence of tensors (list, tuple...). " + f"Got {type(new).__name__}" + ) + + # Set the number of original tensors + self.is_tensor = isinstance(new, Tensor) + self.ntensors = 1 if self.is_tensor else len(new) + + # Register the tensor(s) + if self.is_tensor: + if original.dtype != new.dtype: + raise ValueError( + "When `right_inverse` outputs one tensor, it may not change the dtype.\n" + f"original.dtype: {original.dtype}\n" + f"right_inverse(original).dtype: {new.dtype}" + ) + # Set the original to original so that the user does not need to re-register the parameter + # manually in the optimiser + with no_grad(): + _maybe_set(original, new) + _register_parameter_or_buffer(self, "original", original) + else: + for i, originali in enumerate(new): + if not isinstance(originali, Tensor): + raise ValueError( + "'right_inverse' must return a Tensor or a Sequence of tensors " + "(list, tuple...). " + f"Got element {i} of the sequence with type {type(originali).__name__}." + ) + + # If the original tensor was a Parameter that required grad, we expect the user to + # add the new parameters to the optimizer after registering the parametrization + # (this is documented) + if isinstance(original, Parameter): + originali = Parameter(originali, original.requires_grad) + originali.requires_grad = original.requires_grad + _register_parameter_or_buffer(self, f"original{i}", originali) + + if not self.unsafe: + # Consistency checks: + # Since f : A -> B, right_inverse : B -> A, Z and original should live in B + # Z = forward(right_inverse(original)) + Z = self() + if not isinstance(Z, Tensor): + raise ValueError( + f"A parametrization must return a tensor. Got {type(Z).__name__}." + ) + if Z.dtype != original_dtype: + raise ValueError( + "Registering a parametrization may not change the dtype of the tensor, unless `unsafe` flag is enabled.\n" + f"unparametrized dtype: {original_dtype}\n" + f"parametrized dtype: {Z.dtype}" + ) + if Z.shape != original_shape: + raise ValueError( + "Registering a parametrization may not change the shape of the tensor, unless `unsafe` flag is enabled.\n" + f"unparametrized shape: {original_shape}\n" + f"parametrized shape: {Z.shape}" + ) + + def right_inverse(self, value: Tensor) -> None: + r"""Call the ``right_inverse`` methods of the parametrizations in the inverse registration order. + + Then, it stores the result in ``self.original`` if ``right_inverse`` outputs one tensor + or in ``self.original0``, ``self.original1``, ... if it outputs several. + + Args: + value (Tensor): Value to which initialize the module + """ + # All the exceptions in this function should almost never throw. + # They could throw if, for example, right_inverse function returns a different + # dtype when given a different input, which should most likely be caused by a + # bug in the user's code + + with no_grad(): + # See https://github.com/pytorch/pytorch/issues/53103 + for module in reversed(self): # type: ignore[call-overload] + if hasattr(module, "right_inverse"): + value = module.right_inverse(value) + else: + raise RuntimeError( + f"parametrization {type(module).__name__} does not implement " + "right_inverse." + ) + if self.is_tensor: + # These exceptions should only throw when a right_inverse function does not + # return the same dtype for every input, which should most likely be caused by a bug + if not isinstance(value, Tensor): + raise ValueError( + f"`right_inverse` should return a tensor. Got {type(value).__name__}" + ) + if value.dtype != self.original.dtype: + raise ValueError( + f"The tensor returned by `right_inverse` has dtype {value.dtype} " + f"while `original` has dtype {self.original.dtype}" + ) + # We know that the result is going to have the same dtype + _maybe_set(self.original, value) + else: + if not isinstance(value, collections.abc.Sequence): + raise ValueError( + "'right_inverse' must return a sequence of tensors. " + f"Got {type(value).__name__}." + ) + if len(value) != self.ntensors: + raise ValueError( + "'right_inverse' must return a sequence of tensors of length " + f"{self.ntensors}. Got a sequence of length {len(value)}." + ) + for i, tensor in enumerate(value): + original_i = getattr(self, f"original{i}") + if not isinstance(tensor, Tensor): + raise ValueError( + f"`right_inverse` must return a sequence of tensors. " + f"Got element {i} of type {type(tensor).__name__}" + ) + if original_i.dtype != tensor.dtype: + raise ValueError( + f"Tensor {i} returned by `right_inverse` has dtype {tensor.dtype} " + f"while `original{i}` has dtype {original_i.dtype}" + ) + _maybe_set(original_i, tensor) + + def forward(self) -> Tensor: + # Unpack the originals for the first parametrization + if self.is_tensor: + x = self[0](self.original) + else: + originals = (getattr(self, f"original{i}") for i in range(self.ntensors)) + x = self[0](*originals) + # It's not possible to call self[1:] here, so we have to be a bit more cryptic + # Also we want to skip all non-integer keys + curr_idx = 1 + while hasattr(self, str(curr_idx)): + x = self[curr_idx](x) + curr_idx += 1 + return x + + +def _inject_new_class(module: Module) -> None: + r"""Set up a module to be parametrized. + + This works by substituting the class of the module by a class + that extends it to be able to inject a property + + Args: + module (nn.Module): module into which to inject the property + """ + cls = module.__class__ + + def default_deepcopy(self, memo): + # Just emulate a standard deepcopy procedure when __deepcopy__ doesn't exist in the current class. + obj = memo.get(id(self), None) + if obj is not None: + return obj + replica = self.__new__(self.__class__) + memo[id(self)] = replica + replica.__dict__ = deepcopy(self.__dict__, memo) + # Also save all slots if they exist. + slots_to_save = copyreg._slotnames(self.__class__) # type: ignore[attr-defined] + for slot in slots_to_save: + if hasattr(self, slot): + setattr(replica, slot, deepcopy(getattr(self, slot), memo)) + return replica + + def getstate(self): + raise RuntimeError( + "Serialization of parametrized modules is only " + "supported through state_dict(). See:\n" + "https://pycore.org/tutorials/beginner/saving_loading_models.html" + "#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training" + ) + + dct = {"__getstate__": getstate} + # We don't allow serialization of parametrized modules but should still allow deepcopying. + # Default 'deepcopy' function invokes __deepcopy__ method instead of __getstate__ when it exists. + if not hasattr(cls, "__deepcopy__"): + dct["__deepcopy__"] = default_deepcopy # type: ignore[assignment] + + param_cls = type( + f"Parametrized{cls.__name__}", + (cls,), + dct, + ) + + module.__class__ = param_cls + + +def _inject_property(module: Module, tensor_name: str) -> None: + r"""Injects a property into module[tensor_name]. + + It assumes that the class in the module has already been modified from its + original one using _inject_new_class and that the tensor under :attr:`tensor_name` + has already been moved out + + Args: + module (nn.Module): module into which to inject the property + tensor_name (str): name of the name of the property to create + """ + # We check the precondition. + # This should never fire if register_parametrization is correctly implemented + assert not hasattr(module, tensor_name) + + def get_cached_parametrization(parametrization) -> Tensor: + global _cache # pylint: disable=global-variable-not-assigned + key = (id(module), tensor_name) + tensor = _cache.get(key) + if tensor is None: + tensor = Parameter(parametrization()) + _cache[key] = tensor + return tensor + + def get_parametrized(self) -> Tensor: + parametrization = self.parametrizations[tensor_name] + if _cache_enabled: + return get_cached_parametrization(parametrization) + else: + # If caching is not active, this function just evaluates the parametrization + return Parameter(parametrization()) + + def set_original(self, value: Tensor) -> None: + self.parametrizations[tensor_name].right_inverse(value) + setattr(module.__class__, tensor_name, property(get_parametrized, set_original)) + + +def register_parametrization( + module: Module, + tensor_name: str, + parametrization: Module, + *, + unsafe: bool = False, +) -> Module: + r"""Register a parametrization to a tensor in a module. + + Assume that ``tensor_name="weight"`` for simplicity. When accessing ``module.weight``, + the module will return the parametrized version ``parametrization(module.weight)``. + If the original tensor requires a gradient, the backward pass will differentiate + through :attr:`parametrization`, and the optimizer will update the tensor accordingly. + + The first time that a module registers a parametrization, this function will add an attribute + ``parametrizations`` to the module of type :class:`~ParametrizationList`. + + The list of parametrizations on the tensor ``weight`` will be accessible under + ``module.parametrizations.weight``. + + The original tensor will be accessible under + ``module.parametrizations.weight.original``. + + Parametrizations may be concatenated by registering several parametrizations + on the same attribute. + + The training mode of a registered parametrization is updated on registration + to match the training mode of the host module + + Parametrized parameters and buffers have an inbuilt caching system that can be activated + using the context manager :func:`cached`. + + A :attr:`parametrization` may optionally implement a method with signature + + .. code-block:: python + + def right_inverse(self, X: Tensor) -> Union[Tensor, Sequence[Tensor]] + + This method is called on the unparametrized tensor when the first parametrization + is registered to compute the initial value of the original tensor. + If this method is not implemented, the original tensor will be just the unparametrized tensor. + + If all the parametrizations registered on a tensor implement `right_inverse` it is possible + to initialize a parametrized tensor by assigning to it, as shown in the example below. + + It is possible for the first parametrization to depend on several inputs. + This may be implemented returning a tuple of tensors from ``right_inverse`` + (see the example implementation of a ``RankOne`` parametrization below). + + In this case, the unconstrained tensors are also located under ``module.parametrizations.weight`` + with names ``original0``, ``original1``,... + + .. note:: + + If unsafe=False (default) both the forward and right_inverse methods will be called + once to perform a number of consistency checks. + If unsafe=True, then right_inverse will be called if the tensor is not parametrized, + and nothing will be called otherwise. + + .. note:: + + In most situations, ``right_inverse`` will be a function such that + ``forward(right_inverse(X)) == X`` (see + `right inverse `_). + Sometimes, when the parametrization is not surjective, it may be reasonable + to relax this. + + .. warning:: + + If a parametrization depends on several inputs, :func:`~register_parametrization` + will register a number of new parameters. If such parametrization is registered + after the optimizer is created, these new parameters will need to be added manually + to the optimizer. See :meth:`core.Optimizer.add_param_group`. + + Args: + module (nn.Module): module on which to register the parametrization + tensor_name (str): name of the parameter or buffer on which to register + the parametrization + parametrization (nn.Module): the parametrization to register + Keyword args: + unsafe (bool): a boolean flag that denotes whether the parametrization + may change the dtype and shape of the tensor. Default: `False` + Warning: the parametrization is not checked for consistency upon registration. + Enable this flag at your own risk. + + Raises: + ValueError: if the module does not have a parameter or a buffer named :attr:`tensor_name` + + Examples: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK) + >>> from mindnlp import core + >>> from mindnlp import core.nn as nn + >>> from mindnlp import core.nn.utils.parametrize as P + >>> + >>> class Symmetric(nn.Module): + >>> def forward(self, X): + >>> return X.triu() + X.triu(1).T # Return a symmetric matrix + >>> + >>> def right_inverse(self, A): + >>> return A.triu() + >>> + >>> m = nn.Linear(5, 5) + >>> P.register_parametrization(m, "weight", Symmetric()) + >>> print(core.allclose(m.weight, m.weight.T)) # m.weight is now symmetric + True + >>> A = core.rand(5, 5) + >>> A = A + A.T # A is now symmetric + >>> m.weight = A # Initialize the weight to be the symmetric matrix A + >>> print(core.allclose(m.weight, A)) + True + + >>> class RankOne(nn.Module): + >>> def forward(self, x, y): + >>> # Form a rank 1 matrix multiplying two vectors + >>> return x.unsqueeze(-1) @ y.unsqueeze(-2) + >>> + >>> def right_inverse(self, Z): + >>> # Project Z onto the rank 1 matrices + >>> U, S, Vh = core.linalg.svd(Z, full_matrices=False) + >>> # Return rescaled singular vectors + >>> s0_sqrt = S[0].sqrt().unsqueeze(-1) + >>> return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt + >>> + >>> linear_rank_one = P.register_parametrization(nn.Linear(4, 4), "weight", RankOne()) + >>> print(core.linalg.matrix_rank(linear_rank_one.weight).item()) + 1 + + """ + parametrization.train(module.training) + if is_parametrized(module, tensor_name): + # Correctness checks. + # If A is the space of tensors with shape and dtype equal to module.weight + # we check that parametrization.forward and parametrization.right_inverse are + # functions from A to A + if not unsafe: + Y = getattr(module, tensor_name) + X = parametrization(Y) + if not isinstance(X, Tensor): + raise ValueError( + f"A parametrization must return a tensor. Got {type(X).__name__}." + ) + if X.dtype != Y.dtype: + raise ValueError( + "Registering a parametrization may not change the dtype of the tensor, unless the `unsafe` flag is enabled.\n" + f"module.{tensor_name}.dtype: {Y.dtype}\n" + f"parametrization(module.{tensor_name}).dtype: {X.dtype}" + ) + if X.shape != Y.shape: + raise ValueError( + "Registering a parametrization may not change the shape of the tensor, unless the `unsafe` flag is enabled.\n" + f"module.{tensor_name}.shape: {Y.shape}\n" + f"parametrization(module.{tensor_name}).shape: {X.shape}" + ) + if hasattr(parametrization, "right_inverse"): + try: + Z = parametrization.right_inverse(X) # type: ignore[operator] + except NotImplementedError: + pass + else: + if not isinstance(Z, Tensor): + raise ValueError( + f"parametrization.right_inverse must return a tensor. Got: {type(Z).__name__}" + ) + if Z.dtype != Y.dtype: + raise ValueError( + "The tensor returned by parametrization.right_inverse must have the same dtype " + f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n" + f"module.{tensor_name}.dtype: {Y.dtype}\n" + f"returned dtype: {Z.dtype}" + ) + if Z.shape != Y.shape: + raise ValueError( + "The tensor returned by parametrization.right_inverse must have the same shape " + f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n" + f"module.{tensor_name}.shape: {Y.shape}\n" + f"returned shape: {Z.shape}" + ) + # else right_inverse is assumed to be the identity + + # add the new parametrization to the parametrization list + assert isinstance(module.parametrizations, ModuleDict) # Make mypy happy + module.parametrizations[tensor_name].append(parametrization) + # If unsafe was True in previous parametrization, keep it enabled + module.parametrizations[tensor_name].unsafe |= unsafe # type: ignore[index, union-attr] + elif tensor_name in module._buffers or tensor_name in module._parameters: + # Set the parametrization mechanism + # Fetch the original buffer or parameter + original = getattr(module, tensor_name) + # We create this early to check for possible errors + parametrizations = ParametrizationList( + [parametrization], original, unsafe=unsafe + ) + # Delete the previous parameter or buffer + delattr(module, tensor_name) + # If this is the first parametrization registered on the module, + # we prepare the module to inject the property + if not is_parametrized(module): + # Change the class + _inject_new_class(module) + # Inject a ``ModuleDict`` into the instance under module.parametrizations + module.parametrizations = ModuleDict() + # Add a property into the class + _inject_property(module, tensor_name) + # Add a ParametrizationList + assert isinstance(module.parametrizations, ModuleDict) # Make mypy happy + module.parametrizations[tensor_name] = parametrizations + else: + raise ValueError( + f"Module '{module}' does not have a parameter, a buffer, or a " + f"parametrized element with name '{tensor_name}'" + ) + return module + + +def is_parametrized(module: Module, tensor_name: Optional[str] = None) -> bool: + r"""Determine if a module has a parametrization. + + Args: + module (nn.Module): module to query + tensor_name (str, optional): name of the parameter in the module + Default: ``None`` + Returns: + ``True`` if :attr:`module` has a parametrization for the parameter named :attr:`tensor_name`, + or if it has any parametrization when :attr:`tensor_name` is ``None``; + otherwise ``False`` + """ + parametrizations = getattr(module, "parametrizations", None) + if parametrizations is None or not isinstance(parametrizations, ModuleDict): + return False + if tensor_name is None: + # Check that there is at least one parametrized buffer or Parameter + return len(parametrizations) > 0 + else: + return tensor_name in parametrizations + + +def remove_parametrizations( + module: Module, + tensor_name: str, + leave_parametrized: bool = True, +) -> Module: + r"""Remove the parametrizations on a tensor in a module. + + - If ``leave_parametrized=True``, ``module[tensor_name]`` will be set to + its current output. In this case, the parametrization shall not change the ``dtype`` + of the tensor. + - If ``leave_parametrized=False``, ``module[tensor_name]`` will be set to + the unparametrised tensor in ``module.parametrizations[tensor_name].original``. + This is only possible when the parametrization depends on just one tensor. + + Args: + module (nn.Module): module from which remove the parametrization + tensor_name (str): name of the parametrization to be removed + leave_parametrized (bool, optional): leave the attribute :attr:`tensor_name` parametrized. + Default: ``True`` + + Returns: + Module: module + + Raises: + ValueError: if ``module[tensor_name]`` is not parametrized + ValueError: if ``leave_parametrized=False`` and the parametrization depends on several tensors + """ + if not is_parametrized(module, tensor_name): + raise ValueError( + f"Module {module} does not have a parametrization on {tensor_name}" + ) + + # Fetch the original tensor + assert isinstance(module.parametrizations, ModuleDict) # Make mypy happy + parametrizations = module.parametrizations[tensor_name] + if parametrizations.is_tensor: + original = parametrizations.original + if leave_parametrized: + with no_grad(): + t = getattr(module, tensor_name) + # We know they have the same dtype because we have checked this when registering the + # parametrizations. As such, we can use set_ + # We do this so that the parameter does not to change the id() + # This way the user does not need to update the optimizer + with no_grad(): + if type(original) is core.Tensor: + _maybe_set(original, t) + else: + try: + _maybe_set(original, t) + except RuntimeError as e: + # TODO: Fix this for tensor subclasses that are parameters: + # RuntimeError: set_storage is not allowed on a Tensor created from .data or .detach(). + raise RuntimeError( + "Calling remove_parametrizations() with leave_parametrized=True " + "for a parameter that is an instance of a tensor subclass requires " + "set_() to be implemented correctly for the tensor subclass." + "Alternatively, one can opt into the swap_tensors path" + "Either set leave_parametrized=False or provide a working implementation" + "for set_() in the tensor subclass or set " + "core.__future__.set_swap_module_params_on_conversion(True)." + ) from e + else: + if leave_parametrized: + # We cannot use no_grad because we need to know whether one or more + # original tensors required grad + t = getattr(module, tensor_name) + # We'll have to trust the user to add it to the optimizer + original = Parameter(t) if t.requires_grad else t + else: + raise ValueError( + "Cannot leave unparametrized (`leave_parametrized=False`) a tensor " + "that is parametrized in terms of a sequence of tensors." + ) + + # Delete the property that manages the parametrization + delattr(module.__class__, tensor_name) + # Delete the ParametrizationList + del module.parametrizations[tensor_name] + + # Restore the parameter / buffer into the main class + _register_parameter_or_buffer(module, tensor_name, original) + + # Roll back the parametrized class if no other buffer or parameter + # is currently parametrized in this class + if not is_parametrized(module): + delattr(module, "parametrizations") + # Restore class + orig_cls = module.__class__.__bases__[0] + module.__class__ = orig_cls + return module + + +def type_before_parametrizations(module: Module) -> type: + r"""Return the module type before parametrizations were applied and if not, then it returns the module type. + + Args: + module (nn.Module): module to get type of + """ + if is_parametrized(module): + return module.__class__.__bases__[0] + else: + return type(module) + + +def transfer_parametrizations_and_params( + from_module: Module, + to_module: Module, + tensor_name: Optional[str] = None, +) -> Module: + r"""Transfer parametrizations and the parameters they parametrize from :attr:`from_module` to :attr:`to_module`. + + If :attr:`tensor_name` is specified, only transfers the specified parameter, otherwise + transfers all parametrized parameters. If those parameters do not exist in to_module, it will create them. + Does nothing if from_module is not parametrized. + + Args: + from_module (nn.Module): module to transfer from + to_module (nn.Module): module to transfer to + tensor_name (str, optional): parameter to transfer + + Returns: + Module: to_module + """ + if is_parametrized(from_module): + assert isinstance(from_module.parametrizations, ModuleDict) # for mypy + + # get list of all params or the single param to transfer + parameters_to_transfer: Union[list, ModuleDict] = ( + from_module.parametrizations if tensor_name is None else [tensor_name] + ) + + assert hasattr(parameters_to_transfer, "__iter__") # for mypy + for parameter_name in parameters_to_transfer: + # initialize the to-be-transferred param in to_module if it doesn't exist already + if not hasattr(to_module, parameter_name): + setattr( + to_module, + parameter_name, + Parameter(getattr(from_module, parameter_name)), + ) + + # apply the params's parametrizations to to_module + for param_func in from_module.parametrizations[parameter_name]: + register_parametrization(to_module, parameter_name, param_func) + assert isinstance(to_module.parametrizations, ModuleDict) # for mypy + + # make values match, original values can be stored in either original or + # original0, original1..., need to check both cases + if hasattr(from_module.parametrizations[parameter_name], "original"): + to_module.parametrizations[ + parameter_name + ].original = from_module.parametrizations[parameter_name].original + else: + num = 0 + orig_num = "original" + str(num) + # loop through each original# until all values have been set + while hasattr(from_module.parametrizations[parameter_name], orig_num): + setattr( + to_module.parametrizations[parameter_name], + orig_num, + getattr(from_module.parametrizations[parameter_name], orig_num), + ) + num = num + 1 + orig_num = "original" + str(num) + + return to_module diff --git a/mindnlp/core/nn/utils/weight_norm.py b/mindnlp/core/nn/utils/weight_norm.py new file mode 100644 index 000000000..7d6a9ffaf --- /dev/null +++ b/mindnlp/core/nn/utils/weight_norm.py @@ -0,0 +1,200 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +r"""Weight Normalization from https://arxiv.org/abs/1602.07868.""" +from typing import Any, TypeVar +from ..parameter import Parameter +from ..modules import Module +from ... import ops + +__all__ = ['WeightNorm', 'weight_norm', 'remove_weight_norm'] + +def norm_except_dim(weight_v, pows, dim): + r""" + calculte g/||weight_v|| * weight_v method + """ + if dim == -1: + return ops.norm(weight_v, pows) + if dim == 0: + w_shape_v = weight_v.shape[0] # avoid macOS error + output_size = (w_shape_v,) + (1,) * (weight_v.ndim - 1) + return ops.norm(weight_v.view((w_shape_v, -1)), pows, 1).view(output_size) + if dim == (weight_v.ndim - 1): + output_size = (1,) * (weight_v.ndim - 1) + (weight_v.shape[weight_v.ndim - 1],) + return ops.norm(weight_v.view((-1, weight_v.shape[weight_v.ndim - 1])), pows, 0).view(output_size) + return norm_except_dim(weight_v.swapaxes(0, dim), pows, dim).swapaxes(0, dim) + +def _weight_norm(weight_v, weight_g, dim): + r""" + calculte weight_g/||weight_v|| * weight_v method + """ + return weight_v * (weight_g / norm_except_dim(weight_v, 2, dim)) + + +class WeightNorm: + + r""" + The 'WeightNorm' class implements weight normalization for neural network modules. It provides methods to compute normalized weights, apply weight normalization to a cell, wrap a function, and remove + weight bias from a cell. The class also contains an initializer for the name and dimension of the weight parameters, as well as a method to compute the weight using the normalized parameters. Additionally, it + includes a method to remove the weight bias and a wrapper function for transposing cell_id to cell. + """ + name: str + dim: int + + def __init__(self, name: str, dim: int) -> None: + if dim is None: + dim = -1 + self.name = name + self.dim = dim + + # TODO Make return type more specific + def compute_weight(self, module: Module) -> Any: + g = getattr(module, self.name + '_g') + v = getattr(module, self.name + '_v') + return Parameter(_weight_norm(v, g, self.dim)) + + @staticmethod + def apply(module, name: str, dim: int) -> 'WeightNorm': + for k, hook in module._forward_pre_hooks.items(): + if isinstance(hook, WeightNorm) and hook.name == name: + raise RuntimeError("Cannot register two weight_norm hooks on " + "the same parameter {}".format(name)) + + if dim is None: + dim = -1 + + fn = WeightNorm(name, dim) + + weight = getattr(module, name) + # if isinstance(weight, UninitializedParameter): + # raise ValueError( + # 'The module passed to `WeightNorm` can\'t have uninitialized parameters. ' + # 'Make sure to run the dummy forward before applying weight normalization') + # remove w from parameter list + del module._parameters[name] + + # add g and v as new parameters and express w as g/||v|| * v + module.register_parameter(name + '_g', Parameter(norm_except_dim(weight, 2, dim))) + module.register_parameter(name + '_v', Parameter(weight)) + setattr(module, name, fn.compute_weight(module)) + + # recompute weight before every forward() + module.register_forward_pre_hook(fn) + + return fn + + def wrapper_func(self, cell, func): + r""" + wrapper_func where used to transpose cell_id to cell + """ + def new_func(_, inputs): + nonlocal cell + return func(cell, inputs) + return new_func + + def remove(self, module: Module) -> None: + weight = self.compute_weight(module) + delattr(module, self.name) + del module._parameters[self.name + '_g'] + del module._parameters[self.name + '_v'] + setattr(module, self.name, weight) + + def __call__(self, module: Module, inputs: Any) -> None: + setattr(module, self.name, self.compute_weight(module)) + + +T_module = TypeVar('T_module', bound=Module) + +def weight_norm(module: T_module, name: str = 'weight', dim: int = 0) -> T_module: + r"""Apply weight normalization to a parameter in the given module. + + .. math:: + \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|} + + Weight normalization is a reparameterization that decouples the magnitude + of a weight tensor from its direction. This replaces the parameter specified + by :attr:`name` (e.g. ``'weight'``) with two parameters: one specifying the magnitude + (e.g. ``'weight_g'``) and one specifying the direction (e.g. ``'weight_v'``). + Weight normalization is implemented via a hook that recomputes the weight + tensor from the magnitude and direction before every :meth:`~Module.forward` + call. + + By default, with ``dim=0``, the norm is computed independently per output + channel/plane. To compute a norm over the entire weight tensor, use + ``dim=None``. + + See https://arxiv.org/abs/1602.07868 + + .. warning:: + + This function is deprecated. Use :func:`core.nn.utils.parametrizations.weight_norm` + which uses the modern parametrization API. The new ``weight_norm`` is compatible + with ``state_dict`` generated from old ``weight_norm``. + + Migration guide: + + * The magnitude (``weight_g``) and direction (``weight_v``) are now expressed + as ``parametrizations.weight.original0`` and ``parametrizations.weight.original1`` + respectively. If this is bothering you, please comment on + https://github.com/pytorch/pytorch/issues/102999 + + * To remove the weight normalization reparametrization, use + :func:`core.nn.utils.parametrize.remove_parametrizations`. + + * The weight is no longer recomputed once at module forward; instead, it will + be recomputed on every access. To restore the old behavior, use + :func:`core.nn.utils.parametrize.cached` before invoking the module + in question. + + Args: + module (Module): containing module + name (str, optional): name of weight parameter + dim (int, optional): dimension over which to compute the norm + + Returns: + The original module with the weight norm hook + + Example:: + + >>> m = weight_norm(nn.Linear(20, 40), name='weight') + >>> m + Linear(in_features=20, out_features=40, bias=True) + >>> m.weight_g.size() + core.Size([40, 1]) + >>> m.weight_v.size() + core.Size([40, 20]) + + """ + WeightNorm.apply(module, name, dim) + return module + +def remove_weight_norm(module: T_module, name: str = 'weight') -> T_module: + r"""Removes the weight normalization reparameterization from a module. + + Args: + module (Module): containing module + name (str, optional): name of weight parameter + + Example: + >>> m = weight_norm(nn.Linear(20, 40)) + >>> remove_weight_norm(m) + """ + for k, hook in module._forward_pre_hooks.items(): + if isinstance(hook, WeightNorm) and hook.name == name: + hook.remove(module) + del module._forward_pre_hooks[k] + return module + + raise ValueError("weight_norm of '{}' not found in {}" + .format(name, module)) diff --git a/mindnlp/core/npu/__init__.py b/mindnlp/core/npu/__init__.py new file mode 100644 index 000000000..c5ded5b1f --- /dev/null +++ b/mindnlp/core/npu/__init__.py @@ -0,0 +1,46 @@ +from typing import Any + +import mindspore +from mindspore import get_rng_state, set_rng_state, manual_seed +from mindspore.hal import * + +from mindnlp import core + +FloatTensor = core.FloatTensor +HalfTensor = core.FloatTensor +BFloat16Tensor = core.BFloat16Tensor + +def set_compile_mode(*args, **kwargs): + pass + +def manual_seed_all(seed: int): + manual_seed(seed) + +def current_device(): + return core.device('npu', 0) + +def is_available(): + return mindspore.get_context('device_target') == 'Ascend' + +def set_device(device): + pass + +def _lazy_call(callable, **kwargs): + callable() + +class device: + r"""Context-manager that changes the selected device. + + Args: + device (core.device or int): device index to select. It's a no-op if + this argument is a negative integer or ``None``. + """ + + def __init__(self, device: Any): + self.prev_idx = -1 + + def __enter__(self): + self.prev_idx = -1 + + def __exit__(self, type: Any, value: Any, traceback: Any): + return False diff --git a/mindnlp/core/npu/amp/__init__.py b/mindnlp/core/npu/amp/__init__.py new file mode 100644 index 000000000..de8bea8a6 --- /dev/null +++ b/mindnlp/core/npu/amp/__init__.py @@ -0,0 +1,8 @@ +from .autocast_mode import autocast, custom_bwd, custom_fwd + + +__all__ = [ + "autocast", + "custom_bwd", + "custom_fwd", +] \ No newline at end of file diff --git a/mindnlp/core/npu/amp/autocast_mode.py b/mindnlp/core/npu/amp/autocast_mode.py new file mode 100644 index 000000000..cd22ce019 --- /dev/null +++ b/mindnlp/core/npu/amp/autocast_mode.py @@ -0,0 +1,90 @@ +# mypy: allow-untyped-defs +import functools +from typing import Any +from typing_extensions import deprecated + +from mindnlp import core + + +__all__ = ["autocast", "custom_fwd", "custom_bwd"] + + +class autocast(core.amp.autocast_mode.autocast): + r"""See :class:`core.autocast`. + + ``core.cuda.amp.autocast(args...)`` is deprecated. Please use ``core.amp.autocast("cuda", args...)`` instead. + """ + + @deprecated( + "`core.cuda.amp.autocast(args...)` is deprecated. " + "Please use `core.amp.autocast('cuda', args...)` instead.", + category=FutureWarning, + ) + def __init__( + self, + enabled: bool = True, + dtype: core.dtype = core.float16, + cache_enabled: bool = True, + ): + if core._jit_internal.is_scripting(): + self._enabled = enabled + self.device = "cuda" + self.fast_dtype = dtype + return + super().__init__( + "cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled + ) + + def __enter__(self): + if core._jit_internal.is_scripting(): + return self + return super().__enter__() + + # TODO: discuss a unified TorchScript-friendly API for autocast + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): # type: ignore[override] + if core._jit_internal.is_scripting(): + return + return super().__exit__(exc_type, exc_val, exc_tb) + + def __call__(self, func): + if core._jit_internal.is_scripting(): + return func + return super().__call__(func) + + +# Preserved only for BC reasons +@deprecated( + "`core.cuda.amp.autocast_mode._cast(value, dtype)` is deprecated. " + "Please use `core.amp.autocast_mode._cast(value, 'cuda', dtype)` instead.", + category=FutureWarning, +) +def _cast(value, dtype): + return core.amp.autocast_mode._cast(value, "cuda", dtype) + + +@deprecated( + "`core.cuda.amp.custom_fwd(args...)` is deprecated. " + "Please use `core.amp.custom_fwd(args..., device_type='cuda')` instead.", + category=FutureWarning, +) +def custom_fwd(fwd=None, *, cast_inputs=None): + """ + ``core.cuda.amp.custom_fwd(args...)`` is deprecated. Please use + ``core.amp.custom_fwd(args..., device_type='cuda')`` instead. + """ + return functools.partial(core.amp.custom_fwd, device_type="cuda")( + fwd=fwd, cast_inputs=cast_inputs + ) + + +@deprecated( + "`core.cuda.amp.custom_bwd(args...)` is deprecated. " + "Please use `core.amp.custom_bwd(args..., device_type='cuda')` instead.", + category=FutureWarning, +) +def custom_bwd(bwd): + """ + ``core.cuda.amp.custom_bwd(args...)`` is deprecated. Please use + ``core.amp.custom_bwd(args..., device_type='cuda')`` instead. + """ + return functools.partial(core.amp.custom_bwd, device_type="cuda")(bwd) diff --git a/mindnlp/core/onnx/__init__.py b/mindnlp/core/onnx/__init__.py new file mode 100644 index 000000000..a157231b9 --- /dev/null +++ b/mindnlp/core/onnx/__init__.py @@ -0,0 +1 @@ +from .utils import register_custom_op_symbolic \ No newline at end of file diff --git a/mindnlp/core/onnx/symbolic_helper.py b/mindnlp/core/onnx/symbolic_helper.py new file mode 100644 index 000000000..2f5c28f74 --- /dev/null +++ b/mindnlp/core/onnx/symbolic_helper.py @@ -0,0 +1,102 @@ +import inspect +import functools + +from typing import Any, Callable, Literal, NoReturn, TypeVar as _TypeVar +from typing_extensions import Concatenate as _Concatenate, ParamSpec as _ParamSpec + +_T = _TypeVar("_T") +_U = _TypeVar("_U") +_P = _ParamSpec("_P") + +_ValueDescriptor = Literal[ + "v", + "i", + "is", + "f", + "fs", + "b", + "s", + "t", + "none", +] + +def parse_args( + *arg_descriptors: _ValueDescriptor, +) -> Callable[[Callable[_Concatenate[_U, _P], _T]], Callable[_Concatenate[_U, _P], _T]]: + """A decorator which converts args from torch._C.Value to built-in types. + + For example: + + ``` + @parse_args('v', 'i', 'fs') + foo(g, a, b, c): + assert isinstance(a, torch._C.Value) + assert isinstance(b, int) + assert isinstance(c, list) + assert isinstance(c[0], float) + ``` + + Args: + arg_descriptors: list of str, where each element is + a string that specifies the type to convert to. Valid descriptors: + "v": no conversion, keep torch._C.Value. + "i": int + "is": list of int + "f": float + "fs": list of float + "b": bool + "s": str + "t": torch.Tensor + "none": the variable is unused + """ + + def decorator( + fn: Callable[_Concatenate[_U, _P], _T], + ) -> Callable[_Concatenate[_U, _P], _T]: + fn._arg_descriptors = arg_descriptors # type: ignore[attr-defined] + + @functools.wraps(fn) + def wrapper(g: _U, *args: _P.args, **kwargs: _P.kwargs) -> _T: + # some args may be optional, so the length may be smaller + FILE_BUG_MSG = ( + "If you believe this is not due to custom symbolic implementation within your code or " + "an external library, please file an issue at " + "https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml to report this bug." + ) + assert len(arg_descriptors) >= len(args), ( + f"A mismatch between the number of arguments ({len(args)}) and " + f"their descriptors ({len(arg_descriptors)}) was found at symbolic function '{fn.__name__}'. " + f"{FILE_BUG_MSG}" + ) + + try: + sig = inspect.signature(fn) + arg_names = list(sig.parameters.keys())[1:] + fn_name = fn.__name__ + except Exception: + # FIXME(justinchuby): Avoid catching Exception. + # Catch a more specific exception instead. + arg_names = [None] * len(args) # type: ignore[list-item] + fn_name = None + args = [ + _parse_arg(arg, arg_desc, arg_name, fn_name) # type: ignore[method-assign] + for arg, arg_desc, arg_name in zip(args, arg_descriptors, arg_names) + ] + # only support _outputs in kwargs + assert len(kwargs) <= 1, ( + f"Symbolic function {fn.__name__}'s '**kwargs' can contain a single " + f"key/value entry. " + f"{FILE_BUG_MSG}" + ) + + if len(kwargs) == 1: + assert "_outputs" in kwargs, ( + f"Symbolic function {fn.__name__}'s '**kwargs' can only contain " + f"'_outputs' key at '**kwargs'. " + f"{FILE_BUG_MSG}" + ) + return fn(g, *args, **kwargs) + + return wrapper + + return decorator diff --git a/mindnlp/core/onnx/symbolic_opset11.py b/mindnlp/core/onnx/symbolic_opset11.py new file mode 100644 index 000000000..e69de29bb diff --git a/mindnlp/core/onnx/utils.py b/mindnlp/core/onnx/utils.py new file mode 100644 index 000000000..578038e0e --- /dev/null +++ b/mindnlp/core/onnx/utils.py @@ -0,0 +1,25 @@ +from typing import Any, Callable, cast + +def register_custom_op_symbolic( + symbolic_name: str, + symbolic_fn: Callable, + opset_version: int, +): + """Registers a symbolic function for a custom operator. + + When the user registers symbolic for custom/contrib ops, + it is highly recommended to add shape inference for that operator via setType API, + otherwise the exported graph may have incorrect shape inference in some extreme cases. + An example of setType is `test_aten_embedding_2` in `test_operators.py`. + + See "Custom Operators" in the module documentation for an example usage. + + Args: + symbolic_name (str): The name of the custom operator in "::" + format. + symbolic_fn (Callable): A function that takes in the ONNX graph and + the input arguments to the current operator, and returns new + operator nodes to add to the graph. + opset_version (int): The ONNX opset version in which to register. + """ + pass diff --git a/mindnlp/core/ops/__init__.py b/mindnlp/core/ops/__init__.py new file mode 100644 index 000000000..5f14d9d90 --- /dev/null +++ b/mindnlp/core/ops/__init__.py @@ -0,0 +1,36 @@ +"""core ops like torch funcional api""" +from . import array, blas, comparison, pointwise, creation, random, reduction, other, \ + tensor, _inner, optim +from .array import * +from .blas import * +from .comparison import * +from .pointwise import * +from .creation import * +from .random import * +from .reduction import * +from .other import * +from .tensor import * +# from .fft_op import * +# from .spectral import * +from ._inner import * +from .optim import * + +def load_library(lib_path): + raise ImportError('not support import any ops for now.') + +aten = None + +__all__ = [] +__all__.extend(_inner.__all__) +__all__.extend(array.__all__) +__all__.extend(blas.__all__) +__all__.extend(comparison.__all__) +__all__.extend(creation.__all__) +# __all__.extend(fft_op.__all__) +__all__.extend(pointwise.__all__) +__all__.extend(random.__all__) +__all__.extend(reduction.__all__) +# __all__.extend(spectral.__all__) +__all__.extend(tensor.__all__) +__all__.extend(other.__all__) +__all__.extend(optim.__all__) diff --git a/mindnlp/core/ops/_inner.py b/mindnlp/core/ops/_inner.py new file mode 100644 index 000000000..0b292ee8d --- /dev/null +++ b/mindnlp/core/ops/_inner.py @@ -0,0 +1,20 @@ +"""inner ops""" +import mindspore +from mindspore import ops +from ..configs import use_pyboost + +def cast(input, dtype): + return ops.cast(input, dtype) + +def assign(input, other): + return ops.assign(input, other) + +def call_ms_func(func_name, *args, **kwargs): + out = kwargs.pop('out', None) + if out is None: + return func_name(*args, **kwargs) + else: + tmp = func_name(*args, **kwargs) + return out.copy_(tmp) + +__all__ = ['cast', 'assign'] diff --git a/mindnlp/core/ops/array.py b/mindnlp/core/ops/array.py new file mode 100644 index 000000000..ba54bf936 --- /dev/null +++ b/mindnlp/core/ops/array.py @@ -0,0 +1,730 @@ +"""array op""" +import numbers +import numpy as np +import mindspore +from mindspore import ops +from mindspore.ops._primitive_cache import _get_cache_prim +from mindspore.ops.operations._grad_ops import StridedSliceGrad + +from ..configs import use_pyboost, ON_ORANGE_PI +from .other import broadcast_tensors +from ._inner import call_ms_func + +# adjoint + +# argwhere +def argwhere(input): + if use_pyboost(): + return mindspore.mint.nonzero(input) + return ops.argwhere(input) + +# cat +has_cat = hasattr(mindspore.mint, 'cat') +def cat(tensors, dim=0, *, out=None): + if use_pyboost() and has_cat: + return call_ms_func(mindspore.mint.cat, tensors, dim, out=out) + return call_ms_func(ops.cat, tensors, dim, out=out) + +# concat +has_concat = hasattr(mindspore.mint, 'concat') +def concat(tensors, dim=0, *, out=None): + return cat(tensors, dim, out=out) + +# concatenate +def concatenate(tensors, dim=0, out=None): + return cat(tensors, dim, out=out) + +# conj +def conj(input): + return ops.conj(input) + +# chunk +has_chunk = hasattr(mindspore.mint, 'chunk') +def chunk(input, chunks, dim=0): + if use_pyboost() and has_chunk: + return mindspore.mint.chunk(input, chunks, dim) + return ops.chunk(input, chunks, dim) + +# dsplit + + +# column_stack + + +# dstack + + +# gather +has_gather = hasattr(mindspore.mint, 'gather') +def gather(input, dim, index): + if use_pyboost() and has_gather: + return mindspore.mint.gather(input, dim, index) + index = ops.where(index < input.shape[dim], index, index - input.shape[dim]) + return ops.gather_elements(input, dim, index) + +def gather_nd(input, indices): + return ops.gather_nd(input, indices) + +def tf_gather(input, indices, axis, batch_dims=0): + return ops.gather(input, indices, axis, batch_dims) + +# hsplit + + +# hstack +def hstack(tensors): + return ops.hstack(tensors) + + +# index_fill +def index_fill(input, dim, index, value): + return ops.index_fill(input, dim, index, value) + +# index_add +def index_add(input, dim, index, source, *, alpha=1): + if use_pyboost(): + return mindspore.mint.index_add(input, dim, index, source, alpha=alpha) + return ops.index_add(input, index, source, dim) + +def inplace_index_add(input, dim, index, source): + _inplace = _get_cache_prim(ops.InplaceIndexAdd)(dim) + return _inplace(input, index, source) + +# index_copy + + +# index_reduce + + +# index_select +has_index_select = hasattr(mindspore.mint, 'index_select') +def index_select(input, dim, index, *, out=None): + if use_pyboost() and has_index_select: + return call_ms_func(mindspore.mint.index_select, input, dim, index, out=out) + return call_ms_func(ops.index_select, input, dim, index, out=out) + +# masked_select +has_masked_select = hasattr(mindspore.mint, 'masked_select') +def masked_select(input, mask, *, out=None): + if use_pyboost() and has_masked_select: + return call_ms_func(mindspore.mint.masked_select, input, mask, out=out) + return call_ms_func(ops.masked_select, input, mask, out=out) + +# movedim + + +# moveaxis + + +# narrow +has_narrow = hasattr(mindspore.mint, 'narrow') +def narrow(input, dim, start, length): + if use_pyboost() and has_narrow: + return mindspore.mint.narrow(input, dim, start, length) + return ops.narrow(input, dim, start, length) + +# narrow_copy + + +# nonzero +has_nonzero = hasattr(mindspore.mint, 'nonzero') +def nonzero(input, *, as_tuple=False): + if use_pyboost() and has_nonzero: + return mindspore.mint.nonzero(input, as_tuple=as_tuple) + _nonzero = _get_cache_prim(ops.NonZero)() + out = _nonzero(input) + if as_tuple: + if 0 in out.shape: + return (out, out) + return unbind(out, 1) + return out + +# permute +has_permute = hasattr(mindspore.mint, 'permute') +def permute(input, dims): + if use_pyboost() and has_permute: + return mindspore.mint.permute(input, dims) + return ops.permute(input, dims) + +# reshape +has_reshape = hasattr(mindspore.mint, 'reshape') +def reshape(input, shape): + if use_pyboost() and has_reshape: + return mindspore.mint.reshape(input, shape) + return ops.reshape(input, shape) + +def view(input, *shape): + # if use_pyboost(): + # return mindspore.ops.auto_generate.gen_ops_prim.view_op(input, shape) + return reshape(input, shape) + +# row_stack + +# select +has_select = hasattr(mindspore.mint, 'select') +def select(input, dim, index): + if use_pyboost() and has_select: + return mindspore.mint.select(input, dim, index) + slices = () + for _ in range(dim): + slices += (slice(None, None, None),) + slices += (index,) + return input[slices] + +# scatter +has_scatter = hasattr(mindspore.mint, 'scatter') +def scatter(input, dim, index, src): + if use_pyboost() and has_scatter: + return mindspore.mint.scatter(input, dim, index, src) + if not isinstance(src, mindspore.Tensor): + src = ops.full(index.shape, src, dtype=input.dtype) + return ops.tensor_scatter_elements(input, index, src, dim) + +def tf_scatter_nd_update(input, indices, updates): + return ops.scatter_nd_update(input, indices, updates) + +def tf_scatter_nd(indices, updates, shape): + return ops.scatter_nd(indices, updates, shape) + +# diagonal_scatter + + +# select_scatter + + +# slice_scatter + + +# scatter_add +has_scatter_add = hasattr(mindspore.mint, 'scatter_add') +def scatter_add(input, dim, index, src): + if use_pyboost() and has_scatter_add: + return mindspore.mint.scatter_add(input, dim, index, src) + return ops.tensor_scatter_elements(input, index, src, dim, 'add') + +# scatter_reduce + + +# scatter_nd_update +def scatter_nd_update(input, indices, update): + return ops.scatter_nd_update(input, indices, update) + + +def scatter_update(input, indices, updates): + return ops.scatter_update(input, indices, updates) + +# split +has_split = hasattr(mindspore.mint, 'split') +def split(tensor, split_size_or_sections, dim=0): + # FIXME: mint.split accuracy issue + if use_pyboost() and has_split: + return mindspore.mint.split(tensor, split_size_or_sections, dim) + return ops.split(tensor, split_size_or_sections, dim) + +# squeeze +has_squeeze = hasattr(mindspore.mint, 'squeeze') +def squeeze(input, dim=None): + if use_pyboost() and has_squeeze: + return mindspore.mint.squeeze(input, dim) + return ops.squeeze(input, dim) + +# stack +has_stack = hasattr(mindspore.mint, 'stack') +def stack(tensors, dim=0, *, out=None): + if use_pyboost() and has_stack: + return call_ms_func(mindspore.mint.stack, tensors, dim, out=out) + return call_ms_func(ops.stack, tensors, dim, out=out) + +# swapaxes +has_swapaxes = hasattr(mindspore.mint, 'swapaxes') +def swapaxes(input, dim0, dim1): + return transpose(input, dim0, dim1) + +# swapdims +def swapdims(input, dim0, dim1): + return transpose(input, dim0, dim1) + +# take +def take(input, index): + input = input.view(-1) + index_shape = index.shape + index = index.view(-1) + if ON_ORANGE_PI: + return tf_gather(input, index, 0).view(index_shape) + return gather(input, 0, index).view(index_shape) + +# take_along_dim + + +# tensor_split +def tensor_split(input, indices_or_sections, dim=0): + if isinstance(indices_or_sections, mindspore.Tensor): + indices_or_sections = indices_or_sections.tolist() + else: + indices_or_sections = tuple([get_item(t) for t in indices_or_sections]) + return ops.tensor_split(input, indices_or_sections, dim) + +# tile +has_tile = hasattr(mindspore.mint, 'tile') +def tile(input, dims): + if use_pyboost() and has_tile: + return mindspore.mint.tile(input, dims) + return ops.tile(input, dims) + +# transpose +has_transpose = hasattr(mindspore.mint, 'transpose') +def transpose(input, dim0, dim1): + if use_pyboost() and has_transpose: + return mindspore.mint.transpose(input, dim0, dim1) + ranks = list(range(input.ndim)) + rank0 = ranks[dim0] + rank1 = ranks[dim1] + ranks[dim0] = rank1 + ranks[dim1] = rank0 + return permute(input, tuple(ranks)) + +def t(input): + assert input.ndim <= 2, 'Expects input to be <= 2-D tensor and transposes dimensions 0 and 1.' + if input.ndim == 1: + return input + return transpose(input, 0, 1) + +# unbind +has_unbind = hasattr(mindspore.mint, 'unbind') +def unbind(input, dim=0): + if use_pyboost() and has_unbind: + return mindspore.mint.unbind(input, dim) + return ops.unbind(input, dim) + +# unravel_index + +# unsqueeze +has_unsqueeze = hasattr(mindspore.mint, 'unsqueeze') +def unsqueeze(input, dim=None): + if use_pyboost() and has_unsqueeze: + return mindspore.mint.unsqueeze(input, dim) + return ops.expand_dims(input, dim) + +# vsplit + +# vstack +def vstack(input): + return ops.vstack(input) + +_SLICE_ERROR = ( + 'only integers, slices (`:`), ellipsis (`...`), ' + 'newaxis (`None`) and integer or boolean arrays are valid indices' +) + +# where +def where(condition, *args, out=None): + if len(args) == 0: + return nonzero(condition, as_tuple=True) + assert len(args) == 2 + input, other = args + output = mindspore.mint.where(condition, input, other) + if out is not None: + out.assign_value(output) + return output + +def _as_index(idx, need_scalar=True): + """Helper function to parse idx as an index. + """ + if isinstance(idx, numbers.Integral): + return idx, True + + idx = mindspore.Tensor(idx) + if need_scalar and idx.ndim not in (None, 0): + raise IndexError(_SLICE_ERROR + ', got {!r}'.format(idx)) + + if idx.ndim == 0: + return idx.item(), True + return idx, False + + +def cumprod(x, axis=0, exclusive=False, reverse=False): + x = np.array(x) + if reverse: + x = np.flip(x, axis=axis) + + if exclusive: + shifted_x = np.ones_like(x) + if axis == 0: + shifted_x[1:] = x[:-1] + else: + shifted_x[:, 1:] = x[:, :-1] + result = np.cumprod(shifted_x, axis=axis) + else: + result = np.cumprod(x, axis=axis) + + if reverse: + result = np.flip(result, axis=axis) + + return result + +def moveaxis(a, source, destination): + """Raises ValueError if source, destination not in (-ndim(a), ndim(a)).""" + if not source and not destination: + return a + + if isinstance(source, int): + source = (source,) + if isinstance(destination, int): + destination = (destination,) + if len(source) != len(destination): + raise ValueError('The lengths of source and destination must equal') + + a_rank = a.ndim + + def _correct_axis(axis, rank): + if axis < 0: + return axis + rank + return axis + + source = tuple(_correct_axis(axis, a_rank) for axis in source) + destination = tuple(_correct_axis(axis, a_rank) for axis in destination) + + if a.ndim is not None: + perm = [i for i in range(a_rank) if i not in source] + for dest, src in sorted(zip(destination, source)): + assert dest <= len(perm) + perm.insert(dest, src) + else: + r = ops.range(0, a_rank, 1) + + def _remove_indices(a, b): + """Remove indices (`b`) from `a`.""" + items = ops.unstack( + ops.sort(ops.stack(b)) + ) + + i = 0 + result = [] + + for item in items: + result.append(a[i:item]) + i = item + 1 + + result.append(a[i:]) + + return ops.concat(result, 0) + + minus_sources = _remove_indices(r, source) + minus_dest = _remove_indices(r, destination) + + perm = ops.scatter_nd( + ops.expand_dims(minus_dest, 1), minus_sources, [a_rank] + ) + perm = ops.tensor_scatter_update( + perm, ops.expand_dims(destination, 1), source + ) + a = ops.transpose(a, tuple(perm)) + + return a + +def _slice_helper(tensor, slice_spec, do_update=False, updates=None): + """Helper function for __getitem__ and _with_index_update_helper. + """ + begin, end, strides = [], [], [] + new_axis_mask, shrink_axis_mask = 0, 0 + begin_mask, end_mask = 0, 0 + ellipsis_mask = 0 + advanced_indices = [] + shrink_indices = [] + for index, s in enumerate(slice_spec): + if isinstance(s, slice): + if s.start is not None: + begin.append(s.start) + else: + begin.append(0) + begin_mask |= (1 << index) + if s.stop is not None: + end.append(s.stop) + else: + end.append(0) + end_mask |= (1 << index) + if s.step is not None: + strides.append(s.step) + else: + strides.append(1) + elif s is Ellipsis: + begin.append(0) + end.append(0) + strides.append(1) + ellipsis_mask |= (1 << index) + elif s is None: + # begin.append(0) + # end.append(0) + # strides.append(1) + new_axis_mask |= (1 << index) + else: + s, is_scalar = _as_index(s, False) + if is_scalar: + begin.append(s) + end.append(s + 1) + strides.append(1) + shrink_axis_mask |= (1 << index) + shrink_indices.append(index) + else: + begin.append(0) + end.append(0) + strides.append(1) + begin_mask |= (1 << index) + end_mask |= (1 << index) + advanced_indices.append((index, s, ellipsis_mask != 0)) + + if do_update and not advanced_indices: + return strided_slice_update( + tensor, + begin, + end, + strides, + updates, + begin_mask=begin_mask, + end_mask=end_mask, + shrink_axis_mask=shrink_axis_mask, + new_axis_mask=new_axis_mask, + ellipsis_mask=ellipsis_mask, + ) + else: + if updates is not None: + original_tensor = tensor + tensor = ops.strided_slice( + tensor, + begin, + end, + strides, + begin_mask=begin_mask, + end_mask=end_mask, + shrink_axis_mask=shrink_axis_mask, + new_axis_mask=new_axis_mask, + ellipsis_mask=ellipsis_mask, + ) + + if not advanced_indices: + return tensor + + advanced_indices_map = {} + for index, data, had_ellipsis in advanced_indices: + if had_ellipsis: + num_shrink = len([x for x in shrink_indices if x > index]) + dim = index - len(slice_spec) + num_shrink + else: + num_shrink = len([x for x in shrink_indices if x < index]) + dim = index - num_shrink + advanced_indices_map[dim] = data + dims = sorted(advanced_indices_map.keys()) + dims_contiguous = True + if len(dims) > 1: + if dims[0] < 0 and dims[-1] >= 0: # not all same sign + dims_contiguous = False + else: + for i in range(len(dims) - 1): + if dims[i] + 1 != dims[i + 1]: + dims_contiguous = False + break + indices = [advanced_indices_map[x] for x in dims] + indices = broadcast_tensors(*indices) + stacked_indices = ops.stack(indices, axis=-1) + # Skip the contiguous-dims optimization for update because there is no + # tf.*scatter* op that supports the `axis` argument. + if not dims_contiguous or updates is not None: + if range(len(dims)) != dims: + tensor = moveaxis(tensor, dims, range(len(dims))) + tensor_shape_prefix = mindspore.Tensor(tensor.shape[: len(dims)]) + stacked_indices = where( + stacked_indices < 0, + stacked_indices + tensor_shape_prefix, + stacked_indices, + ) + if updates is None: + return ops.gather_nd(tensor, stacked_indices) + else: + # We only need to move-axis `updates` in the contiguous case becausce + # only in this case the result dimensions of advanced indexing are in + # the middle of `updates`. In the non-contiguous case, those dimensions + # are always at the front. + if dims_contiguous: + batch_size = stacked_indices.ndim - 1 + batch_start = dims[0] + if batch_start < 0: + batch_start += len(dims) - batch_size + + def range_(start, length): + return range(start, start + length) + + updates = moveaxis( + updates, range_(batch_start, batch_size), range(batch_size) + ) + tensor = ops.tensor_scatter_update(tensor, stacked_indices, updates) + if range(len(dims)) != dims: + tensor = moveaxis(tensor, range(len(dims)), dims) + return strided_slice_update( + original_tensor, + begin, + end, + strides, + tensor, + begin_mask=begin_mask, + end_mask=end_mask, + shrink_axis_mask=shrink_axis_mask, + new_axis_mask=new_axis_mask, + ellipsis_mask=ellipsis_mask, + ) + + # Note that gather_nd does not support gathering from inside the array. + # To avoid shuffling data back and forth, we transform the indices and + # do a gather instead. + rank = tensor.ndim + dims = [(x + rank if x < 0 else x) for x in dims] + shape_tensor = tensor.shape + dim_sizes = np.take_along_axis(np.array(shape_tensor), np.array(dims), axis=0) + if len(dims) == 1: + stacked_indices = indices[0] + stacked_indices = ops.cast(stacked_indices, mindspore.int32) + stacked_indices = where( + stacked_indices < 0, stacked_indices + mindspore.Tensor(dim_sizes), stacked_indices + ) + axis = dims[0] + if len(dims) > 1: + index_scaling = cumprod(dim_sizes, reverse=True, exclusive=True) + + def _tensordot(a, b): + # TODO(b/168657656): This function should be replaced by + # tensordot(axis=1) once MatMul has int32 XLA kernel. + b = ops.broadcast_to(b, a.shape) + return ops.sum(a * b, dim=-1) + + stacked_indices = _tensordot(stacked_indices, mindspore.Tensor(index_scaling)) + flat_shape = shape_tensor[:axis] + (-1,) + shape_tensor[axis + len(dims) :] + tensor = ops.reshape(tensor, flat_shape) + + return ops.gather(tensor, stacked_indices, axis=axis) + +def _as_spec_tuple(slice_spec): + """Convert slice_spec to tuple.""" + if isinstance(slice_spec, (list, tuple)): + is_index = True + for s in slice_spec: + if s is None or s is Ellipsis or isinstance(s, (list, tuple, slice)): + is_index = False + break + if not is_index: + return tuple(slice_spec) + return (slice_spec,) + +def getitem(self, slice_spec): + if ( + isinstance(slice_spec, bool) + or ( + isinstance(slice_spec, mindspore.Tensor) + and slice_spec.dtype == mindspore.bool_ + ) + ): + return ops.boolean_mask(tensor=self, mask=slice_spec) + + if not isinstance(slice_spec, tuple): + slice_spec = _as_spec_tuple(slice_spec) + + result_t = _slice_helper(self, slice_spec) + return result_t + +def setitem(a, slice_spec, updates): + """Implementation of ndarray._with_index_*.""" + if ( + isinstance(slice_spec, bool) + or ( + isinstance(slice_spec, mindspore.Tensor) + and slice_spec.dtype == mindspore.bool_ + ) + ): + slice_spec = nonzero(slice_spec) + + if not isinstance(slice_spec, tuple): + slice_spec = _as_spec_tuple(slice_spec) + + a_dtype = a.dtype + result_t = _slice_helper(a, slice_spec, True, updates) + return result_t.astype(a_dtype) + +def tensor_scatter_add(input, indeices, updates): + return ops.tensor_scatter_add(input, indeices, updates) + +def tensor_scatter_max(input, indeices, updates): + return ops.tensor_scatter_max(input, indeices, updates) + +def tensor_scatter_min(input, indeices, updates): + return ops.tensor_scatter_min(input, indeices, updates) + +def strided_slice_update(input, begin, end, strides, update, begin_mask=0, end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=0): + strided_slice_grad = _get_cache_prim(StridedSliceGrad)(begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask) + updated_tensor = strided_slice_grad(update, input.shape, begin, end, strides) + return ops.assign(input, where(updated_tensor != 0, updated_tensor, input)) + +__all__ = [ + # adjoint, + 'argwhere', + 'cat', + 'concat', + 'concatenate', + 'conj', + 'chunk', + # dsplit, + # column_stack + # dstack + 'gather', + 'gather_nd', + 'tf_gather', + # hsplit + 'hstack', + 'index_fill', + 'index_add', + 'inplace_index_add', + # index_copy + # index_reduce + 'index_select', + 'masked_select', + # movedim + # moveaxis + 'narrow', + # narrow_copy + 'nonzero', + 'permute', + 'reshape', + 'view', + # row_stack + 'select', + 'scatter', + 'tf_scatter_nd_update', + 'tf_scatter_nd', + # diagonal_scatter + # select_scatter + # slice_scatter + 'scatter_add', + # scatter_reduce + 'scatter_nd_update', + 'scatter_update', + 'split', + 'squeeze', + 'stack', + 'swapaxes', + 'swapdims', + 'take', + # take_along_dim + 'tensor_split', + 'tile', + 'transpose', + 't', + 'unbind', + # unravel_index + 'unsqueeze', + # vsplit + 'vstack', + 'where', + 'getitem', + 'setitem', + 'tensor_scatter_add', + 'tensor_scatter_max', + 'tensor_scatter_min', + 'strided_slice_update' +] diff --git a/mindnlp/core/ops/blas.py b/mindnlp/core/ops/blas.py new file mode 100644 index 000000000..6506a2277 --- /dev/null +++ b/mindnlp/core/ops/blas.py @@ -0,0 +1,185 @@ +"""blas op""" +import mindspore + +from mindspore import ops +from ..configs import use_pyboost, ON_ORANGE_PI +from ._inner import call_ms_func + +# addbmm +has_addbmm = hasattr(mindspore.mint, 'addbmm') +def addbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None): + if use_pyboost() and has_addbmm: + return call_ms_func(mindspore.mint.addbmm, input, batch1, batch2, beta=beta, alpha=alpha, out=out) + return call_ms_func(ops.addbmm, input, batch1, batch2, beta=beta, alpha=alpha, out=out) + +# addmm +has_addmm = hasattr(mindspore.mint, 'addmm') +def addmm(input, mat1, mat2, *, beta=1, alpha=1): + if use_pyboost() and has_addmm: + return mindspore.mint.addmm(input, mat1, mat2, beta=beta, alpha=alpha) + return ops.addmm(input, mat1, mat2, beta=beta, alpha=alpha) + +# addmv + + +# addr + + +# baddbmm +has_baddbmm = hasattr(mindspore.mint, 'baddbmm') +def baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None): + if use_pyboost() and has_baddbmm: + return call_ms_func(mindspore.mint.baddbmm, input, batch1, batch2, beta=beta, alpha=alpha, out=out) + return call_ms_func(ops.baddbmm, input, batch1, batch2, beta=beta, alpha=alpha, out=out) + +# bmm +has_bmm = hasattr(mindspore.mint, 'bmm') +def bmm(input, other, *, out=None): + if ON_ORANGE_PI: + input = input.to(mindspore.float16) + other = input.to(mindspore.float16) + if use_pyboost() and has_bmm: + return call_ms_func(mindspore.mint.bmm, input, other, out=out) + return call_ms_func(ops.bmm, input, other, out=out) + +# chain_matmul + + +# cholesky + +# cholesky_inverse + +# cholesky_solve + +# dot +has_dot = hasattr(mindspore.mint, 'dot') +def dot(input, other): + if use_pyboost() and has_dot: + return mindspore.mint.dot(input, other) + return (input * other).sum() + +# geqrf + +# ger + +# inner + +# inverse + +# det + +# logdet + +# slogdet + +# lu + +# lu_solve + + +# lu_unpack + +# matmul +has_matmul = hasattr(mindspore.mint, 'matmul') +def matmul(input, other, *, out=None): + if ON_ORANGE_PI: + input = input.to(mindspore.float16) + other = other.to(mindspore.float16) + if use_pyboost() and has_matmul: + return call_ms_func(mindspore.mint.matmul, input, other, out=out) + return call_ms_func(ops.matmul, input, other, out=out) + +# matrix_power + +# matrix_exp + +# mm +has_mm = hasattr(mindspore.mint, 'mm') +def mm(input, other): + return matmul(input, other) + +# mv + + +# orgqr + +# ormqr + +# outer +has_outer = hasattr(mindspore.mint, 'outer') +def outer(input, vec2, *, out=None): + if use_pyboost() and has_outer: + return call_ms_func(mindspore.mint.outer, input, vec2, out=out) + return call_ms_func(ops.outer, input, vec2, out=out) + +# pinverse + + +# qr + +# svd + +# svd_lowrank + +# pca_lowrank + + +# lobpcg + + +# trapz + + +# trapezoid + + +# cumulative_trapezoid + + +# triangular_solve + + +# vdot + +__all__ = [ + 'addbmm', + 'addmm', + # addmv + # addr + 'baddbmm', + 'bmm', + # chain_matmul + # cholesky + # cholesky_inverse + # cholesky_solve + 'dot', + # geqrf + # ger + # inner + # inverse + # det + # logdet + # slogdet + # lu + # lu_solve + # lu_unpack + 'matmul', + # matrix_power + # matrix_exp + 'mm', + # mv + # orgqr + # ormqr + 'outer', + # pinverse + # qr + # svd + # svd_lowrank + # pca_lowrank + # lobpcg + # trapz + # trapezoid + # cumulative_trapezoid + # triangular_solve + # vdot +] diff --git a/mindnlp/core/ops/comparison.py b/mindnlp/core/ops/comparison.py new file mode 100644 index 000000000..1be22e8c5 --- /dev/null +++ b/mindnlp/core/ops/comparison.py @@ -0,0 +1,214 @@ +"""comparison op""" +import numpy as np +import mindspore +from mindspore import ops +from ..configs import use_pyboost + +from ._inner import call_ms_func + +# allclose +has_allclose = hasattr(mindspore.mint, 'allclose') +def allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False): + if use_pyboost() and has_allclose: + return mindspore.mint.allclose(input, other, rtol=rtol, atol=atol, equal_nan=equal_nan) + return np.allclose(input.numpy(), other.numpy(), rtol, atol, equal_nan) + +# argsort +has_argsort = hasattr(mindspore.mint, 'argsort') +def argsort(input, dim=-1, descending=False, stable=False): + if use_pyboost() and has_argsort: + return mindspore.mint.argsort(input, dim=dim, descending=descending, stable=stable) + return sort(input, dim=dim, descending=descending, stable=stable)[1] + +# eq +has_eq = hasattr(mindspore.mint, 'eq') +def eq(input, other, *, out=None): + if use_pyboost() and has_eq: + return call_ms_func(mindspore.mint.eq, input, other, out=out) + return call_ms_func(ops.eq, input, other, out=out) + +# equal +has_equal = hasattr(mindspore.mint, 'equal') +def equal(input, other): + if use_pyboost() and has_equal: + return mindspore.mint.equal(input, other) + if input.shape != other.shape: + return False + out = eq(input, other) + return out.all() + +# ge +def ge(input, other): + return ops.ge(input, other) + +# gt +has_gt = hasattr(mindspore.mint, 'gt') +def gt(input, other, *, out=None): + if use_pyboost() and has_gt: + return call_ms_func(mindspore.mint.gt, input, other, out=out) + return call_ms_func(ops.gt, input, other, out=out) + + +# greater +has_greater = hasattr(mindspore.mint, 'greater') +def greater(input, other, *, out=None): + return gt(input, other, out=out) + +# isclose +has_isclose = hasattr(mindspore.mint, 'isclose') +def isclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False): + if use_pyboost() and has_isclose: + return mindspore.mint.isclose(input, other, rtol, atol, equal_nan) + return mindspore.tensor(np.isclose(input.numpy(), other.numpy(), rtol, atol, equal_nan)) + +# isfinite +has_isfinite = hasattr(mindspore.mint, 'isfinite') +def isfinite(input): + if use_pyboost() and has_isfinite: + return mindspore.mint.isfinite(input) + return ops.isfinite(input) + +# isin +def isin(elements, test_elements): + elements = elements.ravel().expand_dims(-1) + test_elements = test_elements.ravel() + included = ops.equal(elements, test_elements) + # F.reduce_sum only supports float + res = ops.sum(included.int(), -1).astype(mindspore.bool_) + + return res + +# isinf +has_isinf = hasattr(mindspore.mint, 'isinf') +def isinf(input): + if use_pyboost() and has_isinf: + return mindspore.mint.isinf(input) + if input.dtype in (mindspore.int32, mindspore.int64): + input = input.to(mindspore.float32) + return ops.isinf(input) + +# isposinf + +# isneginf + +# isnan +has_isnan = hasattr(mindspore.mint, 'isnan') +def isnan(input): + if use_pyboost() and has_isnan: + return mindspore.mint.isnan(input) + if input.dtype in (mindspore.int32, mindspore.int64): + input = input.to(mindspore.float32) + return ops.isnan(input) + +# isreal + +# kthvalue + +# le +has_le = hasattr(mindspore.mint, 'le') +def le(input, other, *, out=None): + if use_pyboost() and has_le: + return call_ms_func(mindspore.mint.le, input, other, out=out) + return call_ms_func(ops.le, input, other, out=out) + +# less_equal +has_less_equal = hasattr(mindspore.mint, 'less_equal') +def less_equal(input, other, *, out=None): + return le(input, other, out=out) + +# lt +has_lt = hasattr(mindspore.mint, 'lt') +def lt(input, other, *, out=None): + if use_pyboost() and has_lt: + return call_ms_func(mindspore.mint.lt, input, other, out=out) + return call_ms_func(ops.lt, input, other, out=out) + +# less +has_less = hasattr(mindspore.mint, 'less') +def less(input, other, *, out=None): + return lt(input, other, out=out) + +# maximum +has_maximum = hasattr(mindspore.mint, 'maximum') +def maximum(input, other, *, out=None): + if use_pyboost() and has_maximum: + return call_ms_func(mindspore.mint.maximum, input, other, out=out) + return call_ms_func(ops.maximum, input, other, out=out) + +# minimum +has_minimum = hasattr(mindspore.mint, 'minimum') +def minimum(input, other, *, out=None): + if use_pyboost() and has_minimum: + return call_ms_func(mindspore.mint.minimum, input, other, out=out) + return call_ms_func(ops.minimum, input, other, out=out) + + +# fmax +def fmax(input, other): + return ops.fmax(input, other) + +# fmin +def fmin(input, other): + return ops.fmin(input, other) + +# ne +has_ne = hasattr(mindspore.mint, 'ne') +def ne(input, other, *, out=None): + if use_pyboost() and has_ne: + return call_ms_func(mindspore.mint.ne, input, other, out=out) + return call_ms_func(ops.ne, input, other, out=out) + +# not_equal +has_not_equal = hasattr(mindspore.mint, 'not_equal') +def not_equal(input, other): + return ne(input, other) + +# sort +has_sort = hasattr(mindspore.mint, 'sort') +def sort(input, *, dim=-1, descending=False, stable=False): + if use_pyboost() and has_sort: + return mindspore.mint.sort(input, dim=dim, descending=descending, stable=stable) + return ops.sort(input, dim, descending) + +# topk +has_topk = hasattr(mindspore.mint, 'topk') +def topk(input, k, dim=-1, largest=True, sorted=True): + if use_pyboost() and has_topk: + return mindspore.mint.topk(input, k, dim, largest, sorted) + return ops.topk(input, k, dim, largest, sorted) + +# msort +def msort(input): + return sort(input, dim=0) + +__all__ = [ + 'allclose', + 'argsort', + 'eq', + 'equal', + 'ge', + 'gt', + 'greater', + 'isclose', + 'isfinite', + 'isin', + 'isinf', + # isposinf, + # isneginf, + 'isnan', + # isreal, + # kthvalue, + 'le', + 'less_equal', + 'lt', + 'less', + 'maximum', + 'minimum', + 'fmax', + 'fmin', + 'ne', + 'not_equal', + 'sort', + 'topk', + 'msort', +] diff --git a/mindnlp/core/ops/complex.py b/mindnlp/core/ops/complex.py new file mode 100644 index 000000000..11e063edc --- /dev/null +++ b/mindnlp/core/ops/complex.py @@ -0,0 +1,7 @@ +def real(input): + return execute('real', input) + +def imag(input): + return execute('imag', input) + +__all__ = ['real', 'imag'] \ No newline at end of file diff --git a/mindnlp/core/ops/creation.py b/mindnlp/core/ops/creation.py new file mode 100644 index 000000000..657c3e84f --- /dev/null +++ b/mindnlp/core/ops/creation.py @@ -0,0 +1,240 @@ +"""creation ops""" +import numpy as np +from ml_dtypes import bfloat16 as np_bfloat16 +import mindspore +try: + from mindspore._c_expression import Tensor as CTensor # pylint: disable=no-name-in-module, import-error +except: + from mindspore._c_expression import TensorPy as CTensor # pylint: disable=no-name-in-module, import-error +from mindspore import ops +from mindspore.ops._primitive_cache import _get_cache_prim +from ..configs import use_pyboost, ON_ORANGE_PI +from .._bind import get_default_dtype, get_default_device + +def as_strided(self, size, stride, storage_offset=None): + if len(size) != len(stride): + raise RuntimeError("mismatch in length of strides and shape.") + index = np.arange(0, size[0]*stride[0], stride[0]) + for i in np.arange(1, len(size)): + tmp = np.arange(0, size[i]*stride[i], stride[i]) + index = np.expand_dims(index, -1) + index = index + tmp + if storage_offset is not None: + index = index + storage_offset + + if index.size == 0: + input_indices = mindspore.numpy.empty(index.shape, dtype=mindspore.int32) + else: + input_indices = mindspore.tensor(index.astype(np.int32)) + out = ops.gather(self.reshape(-1), input_indices, 0) + return out + +# from_numpy +def from_numpy(ndarray): + return mindspore.Tensor(ndarray) + +# frombuffer + +# zeros +_zeros = ops.Zeros() +has_zeros = hasattr(mindspore.mint, 'zeros') +def zeros(*size, dtype=None, device=None, requires_grad=False, **kwargs): + if dtype is None: + dtype = get_default_dtype() + if len(size) == 0: + size = kwargs.get('size', None) + if size == () or size == []: + size = ((),) + if isinstance(size[0], (tuple, list)): + size = size[0] + if use_pyboost() and has_zeros: + if device == 'cpu': + return mindspore.Tensor(np.zeros(size), dtype=dtype) + return mindspore.mint.zeros(size, dtype=dtype) + size = tuple(size) + return _zeros(size, dtype) + +# zeros_like +has_zeros_like = hasattr(mindspore.mint, 'zeros_like') +def zeros_like(input, *, dtype=None, memory_format=None): + if dtype is None: + dtype = input.dtype + if use_pyboost() and has_zeros_like: + return mindspore.mint.zeros_like(input, dtype=dtype) + return ops.zeros_like(input, dtype=dtype) + +# ones +_ones = ops.Ones() +has_ones = hasattr(mindspore.mint, 'ones') +def ones(*size, dtype=None, device=None): + if isinstance(size[0], (tuple, list)): + size = size[0] + if dtype is None: + dtype = get_default_dtype() + if dtype == bool: + dtype = mindspore.bool_ + if use_pyboost() and has_ones: + return mindspore.mint.ones(size, dtype=dtype) + return _ones(size, dtype) + +# ones_like +has_ones_like = hasattr(mindspore.mint, 'ones_like') +def ones_like(input, *, dtype=None, device=None): + if dtype is None: + dtype = input.dtype + if use_pyboost() and has_ones_like: + return mindspore.mint.ones_like(input, dtype=dtype) + return ops.ones_like(input, dtype=dtype) + +# arange +has_arange = hasattr(mindspore.mint, 'arange') +def arange(start=0, end=None, step=1, *, dtype=None, device=None): + if ON_ORANGE_PI and dtype in (None, mindspore.int64): + dtype = mindspore.int32 + if use_pyboost() and has_arange: + return mindspore.mint.arange(start, end, step, dtype=dtype) + return ops.arange(start, end, step, dtype=dtype) + +# range +def range(start=0, end=None, step=1, dtype=None): + if end is None: + start, end = 0, start + out = ops.range(start, end+1, step) + if dtype is not None: + out = out.to(dtype) + return out + +# linspace +has_linspace = hasattr(mindspore.mint, 'linspace') +def linspace(start, end, steps, *, dtype=None): + if dtype is None: + dtype = mindspore.float32 + if use_pyboost() and has_linspace: + return mindspore.mint.linspace(start, end, steps, dtype=dtype) + return ops.linspace(start, end, steps).to(dtype) + +# logspace +def logspace(start, end, steps, base=10.0, *, dtype=None): + return ops.logspace(start, end, steps, base, dtype=dtype) + +# eye +has_eye = hasattr(mindspore.mint, 'eye') +def eye(n, m=None, *, dtype=None): + if use_pyboost() and has_eye: + return mindspore.mint.eye(n, m, dtype) + return ops.eye(n, m, dtype) + +# empty +has_empty = hasattr(mindspore.mint, 'empty') +def empty(*size, dtype=None, device=None, requires_grad=False, pin_memory=False, **kwargs): + size = size or kwargs.get('size', None) + if device is None: + device= get_default_device() + + if isinstance(size[0], (tuple, list)): + size = size[0] + + if dtype is None: + dtype = get_default_dtype() + + if device: + if not isinstance(device, str) and hasattr(device, "type"): + device = device.type + if device.lower() == 'cpu': + device = device.upper() + elif device.lower() in ['cuda', 'npu']: + device = 'Ascend' + else: + device = 'CPU' + + # To avoid the problem in irecv and recv of using empty. + + if has_empty: + out = mindspore.mint.empty(size, dtype=dtype, device=device) + else: + out = CTensor(dtype, size) + out = mindspore.Tensor(out) + if requires_grad: + out.requires_grad = True + return out + +# empty_like +has_empty_like = hasattr(mindspore.mint, 'empty_like') +def empty_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=None): + return mindspore.mint.empty_like(input, dtype=dtype, device=device) + +# empty_strided + + +# full +has_full = hasattr(mindspore.mint, 'full') +def full(size, fill_value, *, dtype=None, device=None): + if isinstance(fill_value, np.generic): + fill_value = fill_value.item() + if use_pyboost() and has_full: + return mindspore.mint.full(size, fill_value, dtype=dtype) + return ops.full(size, fill_value, dtype=dtype) + +# full_like +has_full_like = hasattr(mindspore.mint, 'full_like') +def full_like(input, fill_value, *, dtype=None, device=None): + if use_pyboost() and has_full_like: + return mindspore.mint.full_like(input, fill_value, dtype=dtype) + if dtype is None: + dtype = input.dtype + return full(input.shape, fill_value, dtype=dtype) + +# quantize_per_tensor + + +# quantize_per_channel + + +# dequantize + + +# complex +def complex(real, imag): + _complex = _get_cache_prim(ops.Complex)() + return _complex(real, imag) + +# polar +has_polar = hasattr(mindspore.mint, 'polar') +def polar(abs, angle): + if use_pyboost() and has_polar: + return mindspore.mint.polar(abs, angle) + return ops.polar(abs, angle) + +# heaviside +def heaviside(input, values): + return ops.heaviside(input, values) + +_TypeDict = { + mindspore.float16: np.float16, + mindspore.float32: np.float32, + mindspore.float64: np.float64, + mindspore.bfloat16: np_bfloat16, + mindspore.int8: np.int8, + mindspore.int16: np.int16, + mindspore.int32: np.int32, + mindspore.int64: np.int64, + mindspore.uint8: np.uint8, + mindspore.bool_: np.bool_, + mindspore.complex64: np.complex64, + mindspore.complex128: np.complex128, +} + + +def frombuffer(buffer, *, dtype=None, count=-1, offset=0, requires_grad=False): + np_dtype = _TypeDict[dtype] + output = np.frombuffer(buffer=buffer, dtype=np_dtype, count=count, offset=offset) + if dtype == mindspore.bfloat16: + return mindspore.Tensor(output.astype(np.float32), dtype=dtype) + return mindspore.Tensor(output, dtype=dtype) + + +__all__ = ['arange', 'as_strided', 'complex', 'empty', 'empty_like', + 'eye', 'from_numpy', 'full', 'full_like', 'frombuffer', + 'heaviside', 'linspace', 'logspace', 'ones', 'ones_like', + 'polar', 'range', 'zeros', 'zeros_like' +] \ No newline at end of file diff --git a/mindnlp/core/ops/fft_op.py b/mindnlp/core/ops/fft_op.py new file mode 100644 index 000000000..9ea151c05 --- /dev/null +++ b/mindnlp/core/ops/fft_op.py @@ -0,0 +1,38 @@ +"""fft""" +from mindspore import ops +from mindspore.ops._primitive_cache import _get_cache_prim +from ..configs import use_pyboost +from .array import narrow +from ._inner import pad + +def rfft(input, n=None, dim=-1, norm="backward"): + if use_pyboost(): + return ops.rfft(input, n, dim, norm) + if input.shape[dim] < n: + pad_inf = (0, n - input.shape[dim]) + pad_dims = (0, 0) * (input.ndim - (dim + 1)) + pad_inf + input = pad(input, pad_dims) + else: + input = narrow(input, dim, 0, n) + _rfft = _get_cache_prim(ops.FFTWithSize)(input.ndim, False, True, norm) + return _rfft(input) + +def irfft(input, n=None, dim=-1, norm="backward"): + if use_pyboost(): + return ops.irfft(input, n, dim, norm) + if input.shape[dim] < n: + pad_inf = (0, n - input.shape[dim]) + pad_dims = (0, 0) * (input.ndim - (dim + 1)) + pad_inf + input = pad(input, pad_dims) + else: + input = narrow(input, dim, 0, n) + _irfft = _get_cache_prim(ops.FFTWithSize)(input.ndim, True, True, norm) + return _irfft(input) + +def fftn(input, s=None, dim=None, norm=None): + return ops.fftn(input, s, dim, norm) + +def fft(input, s=None, dim=-1, norm=None): + return ops.fft(input, s, dim, norm) + +__all__ = ['fft', 'fftn', 'irfft', 'rfft'] \ No newline at end of file diff --git a/mindnlp/core/ops/inplace.py b/mindnlp/core/ops/inplace.py new file mode 100644 index 000000000..107a82dce --- /dev/null +++ b/mindnlp/core/ops/inplace.py @@ -0,0 +1,94 @@ + +from mindspore.common.generator import default_generator + + +generator_step_ = 12 + +def inplace_copy(self, other): + if self.device != other.device: + other = other.to(self.device) + if self.device.type == 'cpu': + # execute('assign', self, other) + # # self._data.assign_value_cpp(other._data) + self.data = other + else: + execute('inplace_copy', self, other) + return self + +def inplace_zero(input): + device = input.device + if input.device == 'npu': + execute('inplace_zero', input) + elif input.device.type == 'cpu': + out = execute('zeros', input.shape, input.dtype, device=device) + input.data = out + return input + +def inplace_fill(input, value): + device = input.device + if input.device == 'npu': + if isinstance(value, (int, float, bool)): + execute('inplace_fill_scalar', input, value) + execute('inplace_fill_tensor', input, value) + elif input.device.type == 'cpu': + out = execute('full', input.shape, value, device=device) + input.data = out + return input + +def inplace_normal(input, mean=0, std=1, *, generator=None): + if generator is None: + generator = default_generator + seed, offset = generator._step(generator_step_) + if input.device.type == 'npu': + execute('inplace_normal', input, mean, std, seed, offset) + elif input.device.type == 'cpu': + core.normal(mean, std, size=input.size, generator=generator, out=input) + + return input + +# uniform_ +def inplace_uniform(input, *args, **kwargs): + if len(args) == 1: + from_ = args[0] + to_ = None + elif len(args) == 2: + from_ = args[0] + to_ = args[1] + elif len(args) == 3: + from_ = args[0] + to_ = args[1] + else: + from_ = 0 + to_ = 1 + + from_ = kwargs.get("from", 0) if from_ is None else from_ + # to_ = kwargs.get("to", 1) + generator_ = kwargs.get("generator", None) + if generator_ is None: + generator_ = default_generator + seed, offset = generator_._step(generator_step_) + if input.device.type == 'npu': + execute("inplace_uniform", input, from_, to_, seed, offset) + elif input.device.type == 'cpu': + input.data = core.rand(input.shape, generator=generator_, dtype=input.dtype) * (to_ - from_) + from_ + return input + +def inplace_add(input, other, alpha): + execute('inplace_add_ext', input, other, alpha) + return input + +def inplace_scatter(input, dim, index, src): + if not isinstance(src, core.Tensor): + return execute('inplace_scatter_value', input, dim, index, src) + return execute('inplace_scatter', input, dim, index, src) + + +__all__ = [ + 'inplace_copy', + 'inplace_zero', + 'inplace_normal', + 'inplace_fill', + 'inplace_uniform', + 'inplace_add', + 'inplace_scatter' +] diff --git a/mindnlp/core/ops/optim.py b/mindnlp/core/ops/optim.py new file mode 100644 index 000000000..7265adcbf --- /dev/null +++ b/mindnlp/core/ops/optim.py @@ -0,0 +1,44 @@ +"""optim op""" +import mindspore +from mindspore import ops +from mindspore.ops._primitive_cache import _get_cache_prim + +DEVICE_TARGET = mindspore.get_context('device_target') + +_adadelta = ops.ApplyAdadelta() +def raw_adadelta(param, square_avg, acc_delta, lr, rho, eps, grad): + return _adadelta(param, square_avg, acc_delta, lr, rho, eps, grad) + +_adam = ops.Adam() +def raw_adam(param, exp_avg, exp_avg_sq, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad): + # var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad + if DEVICE_TARGET == 'GPU' and param.dtype != mindspore.float32: + beta1_power, beta2_power, lr, beta1, beta2, epsilon = mindspore.tensor(beta1_power, dtype=param.dtype), \ + mindspore.tensor(beta2_power, dtype=param.dtype), \ + mindspore.tensor(lr, dtype=param.dtype), \ + mindspore.tensor(beta1, dtype=param.dtype), \ + mindspore.tensor(beta2, dtype=param.dtype), \ + mindspore.tensor(epsilon, dtype=param.dtype) + return _adam(param, exp_avg, exp_avg_sq, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad) + +_adam_amsgrad = ops.ApplyAdamWithAmsgradV2() +def raw_adam_amsgrad(param, exp_avg, exp_avg_sq, max_exp_avg_sq, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad): + # var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad + + if DEVICE_TARGET == 'GPU' and param.dtype != mindspore.float32: + beta1_power, beta2_power, lr, beta1, beta2, epsilon = mindspore.tensor(beta1_power, dtype=param.dtype), \ + mindspore.tensor(beta2_power, dtype=param.dtype), \ + mindspore.tensor(lr, dtype=param.dtype), \ + mindspore.tensor(beta1, dtype=param.dtype), \ + mindspore.tensor(beta2, dtype=param.dtype), \ + mindspore.tensor(epsilon, dtype=param.dtype) + + return _adam_amsgrad(param, exp_avg, exp_avg_sq, max_exp_avg_sq, + beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad) + + +def raw_sgd(param, grad, lr, dampening, weight_decay, nesterov, accum, momentum, stat): + _sgd = _get_cache_prim(ops.SGD)(dampening, weight_decay, nesterov) + return _sgd(param, grad, lr, accum, momentum, stat) + +__all__ = ['raw_adadelta', 'raw_adam', 'raw_adam_amsgrad', 'raw_sgd'] diff --git a/mindnlp/core/ops/other.py b/mindnlp/core/ops/other.py new file mode 100644 index 000000000..d786a96cd --- /dev/null +++ b/mindnlp/core/ops/other.py @@ -0,0 +1,994 @@ +"""other op""" + +import copy +import numpy as np +import mindspore +from mindspore import ops +from mindspore.common.initializer import initializer +from mindspore.ops._primitive_cache import _get_cache_prim + +from ..configs import use_pyboost, ON_ORANGE_PI +from .reduction import any +from .comparison import eq +from ._inner import call_ms_func + +# atleast_2d + + +# atleast_3d + + +# bincount +has_bincount = hasattr(mindspore.mint, "bincount") + + +def bincount(input, weights=None, minlength=0): + if use_pyboost() and has_bincount: + return mindspore.mint.bincount(input, weights, minlength) + return ops.bincount(input, weights, minlength) + + +# block_diag + + +# broadcast_tensors +def broadcast_tensors(*tensors): + target_shape = broadcast_shapes(*[t.shape for t in tensors]) + + broadcasted_tensors = [t.broadcast_to(target_shape) for t in tensors] + + return broadcasted_tensors + + +def manual_expand(tensor, shape): + assert ( + len(shape) >= tensor.dim() + ), "Target shape must have equal or more dimensions than the tensor." + + for _ in range(len(shape) - tensor.dim()): + tensor = tensor.unsqueeze(0) + + repeats = [] + for i, (tensor_dim, target_dim) in enumerate(zip(tensor.shape, shape)): + if target_dim == -1: + repeats.append(1) + else: + repeats.append(target_dim // tensor_dim if tensor_dim == 1 else 1) + + return tensor.tile(tuple(repeats)) + + +# broadcast_to +has_broadcast_to = hasattr(mindspore.mint, "broadcast_to") + + +def broadcast_to(input, shape): + if ON_ORANGE_PI and not use_pyboost(): + # return input.expand(mindspore.tensor(shape)) + return manual_expand(input, shape) + if use_pyboost() and has_broadcast_to: + return mindspore.mint.broadcast_to(input, shape) + return ops.broadcast_to(input, shape) + + +# broadcast_shapes +def broadcast_shapes(*shapes): + reversed_shapes = [list(reversed(shape)) for shape in shapes] + + max_dim = max(len(shape) for shape in reversed_shapes) + + result_shape = [1] * max_dim + + for i in range(max_dim): + current_dim_size = 1 + for shape in reversed_shapes: + if i < len(shape): + if shape[i] == 1: + continue + if current_dim_size == 1: + current_dim_size = shape[i] + elif current_dim_size != shape[i]: + raise ValueError(f"Shapes {shapes} are not broadcastable.") + result_shape[i] = current_dim_size + + return tuple(reversed(result_shape)) + + +# bucketize + +# cartesian_prod + + +# cdist +has_cdist = hasattr(mindspore.mint, "cdist") + + +def cdist(x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary"): + if use_pyboost() and has_cdist: + return mindspore.mint.cdist(x1, x2, p, compute_mode) + return ops.cdist(x1, x2, float(p)) + + +# clone +has_clone = hasattr(mindspore.mint, "clone") + + +def clone(input): + if use_pyboost() and has_clone: + return mindspore.mint.clone(input) + return copy.deepcopy(input) + + +# combinations + + +# corrcoef + + +# cov + + +# cross + +# cummax + +# cummin + +# cumprod + +# cumsum +has_cumsum = hasattr(mindspore.mint, "cumsum") + + +def cumsum(input, dim, dtype=None, out=None): + if ( + use_pyboost() and has_cumsum and not ON_ORANGE_PI + ): # since cann8.0 community remove aclnn cumsum + output = mindspore.mint.cumsum(input, dim, dtype) + else: + if input.dtype == mindspore.bool_: + input = input.to(mindspore.int32) + output = ops.cumsum(input, dim, dtype) + if out is not None: + out.assign_value(output) + return output + + +# diag +has_diag = hasattr(mindspore.mint, "diag") + + +def diag(input, diagonal=0): + if use_pyboost() and has_diag: + return mindspore.mint.diag(input, diagonal) + return ops.diag(input) + + +# diag_embed + + +# diagflat + + +# diagonal + +# diff + + +# einsum + + +def einsum_label_to_index(label): + """ + Args: + label (str): The label representing a dimension in an Einstein sum. + It should be a single character from the alphabet (upper or lower case) or '.'. + + Returns: + NoneType: This function returns None. + + Raises: + None. + """ + if label == ".": + return 52 + NUM_OF_LETTERS = ord("z") - ord("a") + 1 + return ( + (ord(label) - ord("A")) + if (label.isupper()) + else (NUM_OF_LETTERS + (ord(label) - ord("a"))) + ) + + +def maybe_wrap_dim(dim: int, dim_post_expr: int, wrap_scalar: bool = True): + r""" + This function takes three parameters: dim, dim_post_expr, and wrap_scalar. + + Args: + - dim (int): Represents the dimension to be wrapped. + - dim_post_expr (int): Represents the value used to wrap the dimension. + - wrap_scalar (bool, optional): Specifies whether a scalar value should be wrapped. Default is True. + + Returns: + None: This function does not return a value directly. + + Raises: + AssertionError: Raised if the value of dim_post_expr is less than or equal to 0 and wrap_scalar is False. + AssertionError: Raised if the value of dim is less than the minimum or greater than the maximum allowed range. + AssertionError: Raised if the value of dim is negative and cannot be wrapped due to invalid dim_post_expr. + + """ + if dim_post_expr <= 0: + assert wrap_scalar + dim_post_expr = 1 + min = -dim_post_expr + max = dim_post_expr - 1 + assert not (dim < min or dim > max) + if dim < 0: + dim += dim_post_expr + return dim + + +def dim_list_to_bitset(opt_dims, ndims): + r""" + Converts a list of optional dimensions to a bitset representation. + + Args: + opt_dims (List[int]): The list of optional dimensions to be converted to a bitset representation. + ndims (int): The total number of dimensions. + + Returns: + List[bool]: A list representing the bitset, where True indicates the presence of the dimension and False indicates its absence. + + Raises: + None + """ + if opt_dims: + seen = [False] * (max(opt_dims) + 1) + for dim in opt_dims: + dim = maybe_wrap_dim(dim, ndims) + seen[dim] = True + else: + seen = [True for _ in range(ndims)] + return seen + + +def sumproduct_pair(left_, right_, sum_dims_, keep_dim_): + """ + Calculate the sum-product pair of two arrays along specified dimensions. + + Args: + left_ (array): The left input array. + right_ (array): The right input array. + sum_dims_ (list): A list of dimensions along which to calculate the sum-product pair. + keep_dim_ (bool): A flag indicating whether to keep the dimensions in the result. + + Returns: + None. The function performs the sum-product pair calculation and returns None. + + Raises: + AssertionError: If the number of dimensions of the input arrays do not match, + or if non-broadcast dimensions do not match. + """ + assert left_.ndim == right_.ndim, "number of dimensions must match" + if len(sum_dims_) == 0: + return ops.mul(left_, right_) + + dim = left_.ndim + sum_dims = dim_list_to_bitset(sum_dims_, dim) + + lro, lo, ro = [], [], [] + lro_size, lo_size, ro_size, sum_size = 1, 1, 1, 1 + left = left_ + right = right_ + + for i in range(dim): + sl = left.shape[i] > 1 + sr = right.shape[i] > 1 + if sum_dims[i]: + if sl and sr: + assert ( + left.shape[i] == right.shape[i] + ), "non-broadcast dimensions must match" + sum_size *= left.shape[i] + elif sl: + left = ops.sum(left, i, keepdim=True) + elif sr: + right = ops.sum(right, i, keepdim=True) + elif sl and sr: + assert ( + left.shape[i] == right.shape[i] + ), "non-broadcast dimensions must match" + lro.append(i) + lro_size *= left.shape[i] + elif sl: + lo.append(i) + lo_size *= left.shape[i] + else: + ro.append(i) + ro_size *= right.shape[i] + + out_size = [] + for d in lro: + out_size.append(left.shape[d]) + for d in lo: + out_size.append(left.shape[d]) + for d in sum_dims_: + out_size.append(1) + for d in ro: + out_size.append(right.shape[d]) + + lpermutation = lro.copy() + lpermutation += lo + lpermutation += sum_dims_ + lpermutation += ro + + rpermutation = lro.copy() + rpermutation += sum_dims_ + rpermutation += ro + rpermutation += lo + + opermutation = [-1] * (len(lro) + len(lo) + len(sum_dims_) + len(ro)) + i = 0 + for it in lro: + opermutation[it] = i + i += 1 + for it in lo: + opermutation[it] = i + i += 1 + for it in sum_dims_: + opermutation[it] = i + i += 1 + for it in ro: + opermutation[it] = i + i += 1 + + left = ops.transpose(left, tuple(lpermutation)).reshape(lro_size, lo_size, sum_size) + right = ops.transpose(right, tuple(rpermutation)).view(lro_size, sum_size, ro_size) + + result = ops.bmm(left, right) + result = result.view(*out_size).transpose(*opermutation) + + if not keep_dim_: + sizes = list(result.shape) + for i in range(dim - 1, 0, -1): + if sum_dims[i]: + sizes.pop(i) + result = result.view(*sizes) + + return result + + +ELLIPSIS = 52 + +has_einsum = hasattr(mindspore.mint, "einsum") + + +def einsum(equation, *operands): + """ + Args: + equation (str): A string representing the Einstein summation equation to be computed. + The equation should follow the Einstein summation convention with subscripts in [a-zA-Z], + commas separating operands, and '->' indicating the output structure. + It must include at least one operand. An ellipsis '...' can be used to represent multiple dimensions. + + Returns: + None: This function does not return a value. + + Raises: + AssertionError: If the function is called without providing at least one operand. + AssertionError: If an invalid subscript is given in the equation string. + AssertionError: If the number of subscripts in the equation does not match the number of dimensions for an operand. + AssertionError: If fewer operands are provided than specified in the equation. + AssertionError: If more operands are provided than specified in the equation. + RuntimeError: If operands do not broadcast with remapped shapes [original->remapped]. + """ + if use_pyboost() and has_einsum: + return mindspore.mint.einsum(equation, *operands) + assert operands, "einsum(): must provide at least one operand" + if isinstance(operands[0], tuple): + operands = operands[0] + + arrow_pos = equation.find("->") + num_ops = len(operands) + op_labels = [[] for _ in range(num_ops)] + lhs = equation[0:arrow_pos] + + curr_op = 0 + found_ell = False + ell_skip = 0 + for i, label in enumerate(lhs): + if label == " ": + continue + if label == ".": + if ell_skip != 0: + ell_skip -= 1 + continue + assert ( + not found_ell + ), f"einsum(): found {curr_op} for operand for which an ellipsis was already found" + assert ( + i + 2 < len(lhs) and lhs[i + 1] == "." + ), f"einsum(): found {curr_op} for operand that is not part of any ellipsis" + ell_skip = 2 + op_labels[curr_op].append(ELLIPSIS) + found_ell = True + elif label == ",": + curr_op += 1 + assert ( + curr_op < num_ops + ), "einsum(): fewer operands were provided than specified in the equation" + found_ell = False + else: + assert str.isalpha( + label + ), f"einsum(): invalid subscript given at index {i} in the equation string, subscripts must be in [a-zA-Z]" + op_labels[curr_op].append(einsum_label_to_index(label)) + + assert ( + curr_op == num_ops - 1 + ), "einsum(): more operands were provided than specified in the equation" + # Labels must be within [a-zA-Z]. + TOTAL_LABELS = 52 + label_count = [0] * TOTAL_LABELS + # The maximum number of dimensions covered by any ellipsis, needed when + # unsqueezing missing dimensions from operands to permute and broadcast + ell_num_dim = 0 + + # Compute label frequency and number of dimensions covered by ellipsis + # We do this after parsing labels to make it more readable and simpler + # to compute the number of dimensions covered by ellipsis. + for i, operand in enumerate(operands): + labels = op_labels[i] + ndims = operand.ndim + nlabels = len(labels) + has_ellipsis = False + + for label in labels: + if label == ELLIPSIS: + nlabels -= 1 + has_ellipsis = True + ell_num_dim = max(ell_num_dim, ndims - nlabels) + else: + label_count[label] += 1 + if has_ellipsis: + assert nlabels <= ndims, ( + f"einsum(): the number of subscripts in the equation ({nlabels}" + f") is more than the number of dimensions ({ndims}) for operand {i}" + ) + else: + assert nlabels == ndims, ( + f"einsum(): the number of subscripts in the equation ({nlabels}" + f") does not match the number of dimensions (" + f"{ndims}) for operand {i} and no ellipsis was given" + ) + + # We want to align the dimensions of every input tensor to have + # shape out_dims + sum_dims. For this, we create a mapping of label + # to index into the permuted shape. + label_perm_index = [-1] * TOTAL_LABELS + # Current index in the permuted shape + perm_index = 0 + # Start index of ellipsis dimensions in the permuted shape + ell_index = 0 + found_ell = False + + if arrow_pos == -1: + # Implicit output is ellipsis (...) + labels seen only once + perm_index = ell_num_dim + found_ell = True + for label, _label_count in enumerate(label_count): + if _label_count == 1: + label_perm_index[label] = perm_index + perm_index += 1 + else: + rhs = equation[arrow_pos + 2 :] + ell_skip = 0 + for i, label in enumerate(rhs): + if label == " ": + continue + if label == ".": + if ell_skip != 0: + ell_skip -= 1 + continue + assert ( + not found_ell + ), "einsum(): found '.' for output but an ellipsis (...) was already found" + assert ( + i + 2 < len(rhs) and rhs[i + 1] == "." + ), "einsum(): found '.' for output that is not part of any ellipsis (...)" + ell_skip = 2 + ell_index = perm_index + perm_index += ell_num_dim + found_ell = True + else: + assert str.isalpha(label), ( + f"einsum(): invalid subscript given at index {len(lhs) + 2 + i} " + f"in the equation string, subscripts must be in [a-zA-Z]" + ) + + index = einsum_label_to_index(label) + label_perm_index[index] = perm_index + perm_index += 1 + + out_size = perm_index + if not found_ell: + ell_index = perm_index + perm_index += ell_num_dim + + for label in range(TOTAL_LABELS): + if label_count[label] > 0 and label_perm_index[label] == -1: + label_perm_index[label] = perm_index + perm_index += 1 + + # Here we unsqueeze missing dimensions to make all operands have the same + # number of dimensions. We take diagonals for repeated labels within the + # same operand. Finally we permute the operands to align dimensions as + # per the perm_out_index we computed above. + permuted_operands = [] + for i, operand in enumerate(operands): + perm_shape = [-1] * perm_index + label_dim = [-1] * TOTAL_LABELS + operand = operands[i] + labels = op_labels[i] + original_sizes = operand.shape + + j = 0 + for label in labels: + if label == ELLIPSIS: + # Add missing dimensions covered by the ellipsis + num_missing_dim = ell_num_dim - (len(original_sizes) - len(labels) + 1) + for k in range(num_missing_dim): + operand = ops.unsqueeze(operand, j) + for k in range(ell_num_dim): + perm_shape[ell_index + k] = j + j += 1 + elif label_dim[label] != -1: + dim = label_dim[label] + operand = ops.diagonal(operand, offset=0, dim1=dim, dim2=j) + operand = ops.moveaxis(operand, -1, dim) + else: + label_dim[label] = j + perm_shape[label_perm_index[label]] = j + j += 1 + + # Add dimensions for missing labels + for idx, index in enumerate(perm_shape): + if index == -1: + operand = ops.unsqueeze(operand, -1) + perm_shape[idx] = j + j += 1 + + operand = ops.transpose(operand, tuple(perm_shape)) + permuted_operands.append(operand) + + # Check if operands broadcast and keep track of last operand with + # dimension size != 1 for optimizing reductions + dim_last_op = [0] * perm_index + has_zero_size_dim = False + for dim in range(perm_index): + broadcast_size = permuted_operands[0].shape[dim] + for i in range(1, len(operands)): + dim_size = permuted_operands[i].shape[dim] + if broadcast_size != dim_size and broadcast_size != 1 and dim_size != 1: + raise RuntimeError( + "einsum(): operands do not broadcast with remapped shapes [original->remapped]" + ) + if dim_size != 1: + broadcast_size = dim_size + dim_last_op[dim] = i + has_zero_size_dim = has_zero_size_dim or (broadcast_size == 0) + + # Compute result + result = permuted_operands[0] + if has_zero_size_dim: + out_shape = [-1] * out_size + for i in range(out_size): + out_shape[i] = permuted_operands[dim_last_op[i]].shape[i] + return ops.zeros(out_shape) + + # Sum out or squeeze dimensions that are size 1 for all later operands + dim = out_size + for i in range(dim, perm_index): + if dim_last_op[i] == 0: + if result.shape[dim] == 1: + result = ops.squeeze(result, dim) + dim -= 1 + else: + result = ops.sum(result, dim) + dim -= 1 + dim += 1 + + for i in range(1, num_ops): + operand = permuted_operands[i] + sum_dims = [] + + # Sum out or squeeze dimensions that are size 1 for all later operands + dim = out_size + for j in range(dim, perm_index): + if dim_last_op[j] < i: + operand = ops.squeeze(operand, dim) + dim -= 1 + elif dim_last_op[j] == i: + if result.shape[dim] == 1: + operand = ops.sum(operand, dim) + result = ops.squeeze(result, dim) + dim -= 1 + else: + sum_dims.append(dim) + dim += 1 + if len(sum_dims) == 0: + result = result.mul(operand) + elif len(sum_dims) == len(result.shape): + result = result.flatten().dot(operand.flatten()) + else: + result = sumproduct_pair(result, operand, sum_dims, False) + return result + + +# flatten +has_flatten = hasattr(mindspore.mint, "flatten") + + +def flatten(input, start_dim=0, end_dim=-1): + if use_pyboost() and has_flatten: + return mindspore.mint.flatten(input, start_dim, end_dim) + if end_dim < 0: + end_dim = input.ndim + end_dim + new_shape = input.shape[:start_dim] + (-1,) + input.shape[end_dim + 1 :] + return ops.reshape(input, new_shape) + + +# flip +has_flip = hasattr(mindspore.mint, "flip") + + +def flip(input, dims): + if use_pyboost() and has_flip: + return mindspore.mint.flip(input, dims) + return ops.flip(input, dims) + + +# fliplr + + +# flipud + + +# kron + + +# rot90 + + +# gcd + + +# histc +has_histc = hasattr(mindspore.mint, "histc") + + +def histc(input, bins, min, max, *, out=None): + if use_pyboost() and has_histc: + return call_ms_func( + mindspore.mint.histc, input, bins=bins, min=min, max=max, out=out + ) + return call_ms_func(ops.histc, input, bins=bins, min=min, max=max, out=out) + + +# histogram + + +# histogramdd + + +# meshgrid +has_meshgrid = hasattr(mindspore.mint, "meshgrid") + + +def meshgrid(*tensors, indexing=None): + if use_pyboost() and has_meshgrid: + return mindspore.mint.meshgrid(*tensors, indexing) + if isinstance(tensors[0], (list, tuple)): + tensors = tensors[0] + if len(tensors) == 1: + return tensors + if indexing is None: + indexing = "ij" + return ops.meshgrid(*tensors, indexing=indexing) + + +# lcm + + +# logcumsumexp + +# ravel + + +# renorm + + +# repeat_interleave +has_repeat_interleave = hasattr(mindspore.mint, "repeat_interleave") + + +def repeat_interleave(*args, **kwargs): + if use_pyboost() and has_repeat_interleave: + return mindspore.mint.repeat_interleave(*args, **kwargs) + + input, repeats, dim = args.get("input"), args.get("repeats"), args.get("dim") + if input.dtype == mindspore.bool_: + input = input.int() + return input.repeat(repeats, dim).bool() + return input.repeat(repeats, dim) + + +# roll +DEVICE_TARGET = mindspore.get_context("device_target") +has_roll = hasattr(mindspore.mint, "roll") + + +def roll(input, shifts, dims=None): + if use_pyboost() and has_roll: + return mindspore.mint.roll(input, shifts, dims) + if DEVICE_TARGET == "CPU": + return mindspore.numpy.roll(input, shifts, dims) + return ops.roll(input, shifts, dims) + + +# searchsorted +has_searchsorted = hasattr(mindspore.mint, "searchsorted") + + +def searchsorted( + sorted_sequence, + values, + *, + out_int32=False, + right=False, + side=None, + out=None, + sorter=None, +): + if use_pyboost() and has_searchsorted: + return call_ms_func( + mindspore.mint.searchsorted, + sorted_sequence, + values, + out_int32=out_int32, + right=right, + side=side, + out=out, + sorter=sorter, + ) + return call_ms_func( + ops.searchsorted, + sorted_sequence, + values, + out_int32=out_int32, + right=right, + out=out, + ) + + +# tensordot + +# trace + +# tril +has_tril = hasattr(mindspore.mint, "tril") + + +def tril(input, diagonal=0, *, out=None): + if use_pyboost() and has_tril: + return call_ms_func(mindspore.mint.tril, input, diagonal, out=out) + return call_ms_func(ops.tril, input, diagonal, out=out) + + +# tril_indices + +# triu +has_triu = hasattr(mindspore.mint, "triu") + + +def triu(input, diagonal=0, *, out=None): + if use_pyboost() and has_triu: + return call_ms_func(mindspore.mint.triu, input, diagonal, out=out) + return call_ms_func(ops.triu, input, diagonal, out=out) + + +# triu_indices + + +# unflatten +def unflatten(x, dim, sizes): + new_shape = x.shape[:dim] + sizes + return ops.reshape(x, new_shape) + + +# vander + + +# view_as_real + +# view_as_complex + + +# resolve_conj + + +# resolve_neg + + +def masked_fill(input, mask, value): + masked_fill_ = _get_cache_prim(ops.MaskedFill)() + return masked_fill_(input, mask, mindspore.tensor(value, dtype=input.dtype)) + + +class finfo: + def __init__(self, bits, min, max, eps, tiny, smallest_normal, resolution, dtype): + self.bits = bits + self.min = min + self.max = max + self.eps = eps + self.tiny = tiny + self.smallest_normal = smallest_normal + self.resolution = resolution + self.dtype = dtype + + +finfo_dtype = { + mindspore.bfloat16: finfo( + bits=16, + resolution=0.01, + min=-3.38953e38, + max=3.38953e38, + eps=0.0078125, + smallest_normal=1.17549e-38, + tiny=1.17549e-38, + dtype="bfloat16", + ), + mindspore.float16: finfo( + bits=16, + resolution=0.001, + min=-65504, + max=65504, + eps=0.000976562, + smallest_normal=6.10352e-05, + tiny=6.10352e-05, + dtype="float16", + ), + mindspore.float32: finfo( + bits=32, + resolution=1e-06, + min=-3.40282e38, + max=3.40282e38, + eps=1.19209e-07, + smallest_normal=1.17549e-38, + tiny=1.17549e-38, + dtype="float32", + ), + mindspore.float64: finfo( + bits=64, + resolution=1e-15, + min=-1.79769e308, + max=1.79769e308, + eps=2.22045e-16, + smallest_normal=2.22507e-308, + tiny=2.22507e-308, + dtype='float64', + ), +} + + +def finfo(dtype): + return finfo_dtype[dtype] + + +def iinfo(dtype): + return np.iinfo(mindspore.dtype_to_nptype(dtype)) + + +def contains(self, key): + r""" + Args: + self (object): The object instance on which the method is called. + key (object): The key to be checked for containment in the object. + + Returns: + None: This function returns None, indicating whether the key is contained in the object. + + Raises: + None + """ + eq_res = eq(self, key) + res = any(eq_res) + return bool(res) + + +def initialize(self, init_method): + r""" + Initializes the object with the given initialization method. + + Args: + self (object): The instance of the class. + init_method (str): The method used for initialization. + This parameter determines how the data is initialized. + Valid values for `init_method` are: + - "random": Initializes the data with random values. + - "zeros": Initializes the data with zeros. + - "ones": Initializes the data with ones. + Default value is "random". + + Returns: + None. This function does not return any value. + + Raises: + None. + + Note: + This function sets the data of the object using the specified `init_method` and the object's shape and data type. + """ + self.assign_value(initializer(init_method, self.shape, self.dtype)) + + +_stop_gradient = ops.StopGradient() + + +def stop_gradient(input): + return _stop_gradient(input) + + +def _get_unfold_indices(input_shape, dimension, size, step): + if dimension < 0: + dimension += len(input_shape) + indices = [] + for i in range(0, input_shape[dimension] - size + 1, step): + indices.append(list(range(i, i + size))) + + return indices, dimension + + +def unfold(input, dimension, size, step): + _indices, _dimension = _get_unfold_indices(input.shape, dimension, size, step) + indices = mindspore.Tensor(_indices).astype(mindspore.int32) + output = ops.gather(input, indices, axis=_dimension) + output = ops.moveaxis(output, _dimension + 1, -1) + + return output + + +__all__ = [ + "bincount", + "broadcast_shapes", + "broadcast_tensors", + "broadcast_to", + "cdist", + "clone", + "contains", + "cumsum", + "diag", + "dim_list_to_bitset", + "einsum", + "einsum_label_to_index", + "finfo", + "flatten", + "flip", + "iinfo", + "initialize", + "manual_expand", + "masked_fill", + "maybe_wrap_dim", + "meshgrid", + "repeat_interleave", + "roll", + "searchsorted", + "stop_gradient", + "sumproduct_pair", + "tril", + "triu", + "unflatten", + "unfold", + "histc", +] diff --git a/mindnlp/core/ops/pointwise.py b/mindnlp/core/ops/pointwise.py new file mode 100644 index 000000000..530df82d4 --- /dev/null +++ b/mindnlp/core/ops/pointwise.py @@ -0,0 +1,622 @@ +"""pointwise op""" +import mindspore +from mindspore import ops +from ..configs import use_pyboost +from ._inner import call_ms_func + +# abs +has_abs = hasattr(mindspore.mint, 'abs') +def abs(input, *, out=None): + if use_pyboost() and has_abs: + return call_ms_func(mindspore.mint.abs, input, out=out) + return call_ms_func(ops.abs, input, out=out) + +# absolute +def absolute(input, *, out=None): + return abs(input, out=out) + +# acos +has_acos = hasattr(mindspore.mint, 'acos') +def acos(input, *, out=None): + if use_pyboost() and has_acos: + return call_ms_func(mindspore.mint.acos, input, out=out) + return call_ms_func(ops.acos, input, out=out) + +# arccos +def arrcos(input, out=None): + return acos(input, out=out) + +# acosh +has_acosh = hasattr(mindspore.mint, 'acosh') +def acosh(input, *, out=None): + if use_pyboost and has_acosh: + return call_ms_func(mindspore.mint.acosh, input, out=out) + return call_ms_func(ops.acosh, input, out=out) + +# arccosh +has_arccosh = hasattr(mindspore.mint, 'arccosh') +def arccosh(input): + return acosh(input) + +# add +has_add = hasattr(mindspore.mint, 'add') +def add(input, other, *, alpha=1, out=None): + if use_pyboost() and has_add: + return call_ms_func(mindspore.mint.add, input, other, alpha=alpha, out=out) + if alpha != 1: + other = mul(alpha, other) + return call_ms_func(ops.add, input, other, out=out) + +# addcdiv +def addcdiv(input, tensor1, tensor2, *, value=1): + return ops.addcdiv(input, tensor1, tensor2, value) + +# addcmul +def addcmul(input, tensor1, tensor2, *, value=1): + return ops.addcmul(input, tensor1, tensor2, value) + +# angle +def angle(input): + return ops.angle(input) + +# asin +has_asin = hasattr(mindspore.mint, 'asin') +def asin(input, *, out=None): + if use_pyboost and has_asin: + return call_ms_func(mindspore.mint.asin, input, out=out) + return call_ms_func(ops.asin, input, out=out) + +# arcsin +has_arcsin = hasattr(mindspore.mint, 'arcsin') +def arcsin(input, *, out=None): + return asin(input, out=out) + +# asinh +has_asinh = hasattr(mindspore.mint, 'asinh') +def asinh(input, *, out=None): + if use_pyboost and has_asinh: + return call_ms_func(mindspore.mint.asinh, input, out=out) + return call_ms_func(ops.asinh, input, out=out) + +# arcsinh +has_arcsinh = hasattr(mindspore.mint, 'arcsinh') +def arcsinh(input, *, out=None): + return asinh(input, out=out) + +# atan +has_atan = hasattr(mindspore.mint, 'atan') +def atan(input, *, out=None): + if use_pyboost and has_atan: + return call_ms_func(mindspore.mint.atan, input, out=out) + return call_ms_func(ops.atan, input, out=out) + +# arctan +has_arctan = hasattr(mindspore.mint, 'arctan') +def arctan(input, *, out=None): + return atan(input, out=out) + +# atanh +has_atanh = hasattr(mindspore.mint, 'atanh') +def atanh(input, *, out=None): + if use_pyboost and has_atanh: + return call_ms_func(mindspore.mint.atanh, input, out=out) + return call_ms_func(ops.atanh, input, out=out) + +# arctanh +has_arctanh = hasattr(mindspore.mint, 'arctanh') +def arctanh(input, *, out=None): + return atanh(input, out=out) + +# atan2 +has_atan2 = hasattr(mindspore.mint, 'atan2') +def atan2(input, other, *, out=None): + if use_pyboost() and has_atan2: + return call_ms_func(mindspore.mint.atan2, input, other, out=out) + return call_ms_func(ops.atan2, input, other, out=out) + +# arctan2 +has_arctan2 = hasattr(mindspore.mint, 'arctan2') +def arctan2(input, other, out=None): + return atan2(input, other, out=out) + +# bitwise_not + +# bitwise_and +has_bitwise_and = hasattr(mindspore.mint, 'bitwise_and') +def bitwise_and(input, other, *, out=None): + if use_pyboost() and has_bitwise_and: + return call_ms_func(mindspore.mint.bitwise_and, input, other, out=out) + return call_ms_func(ops.bitwise_and, input, other, out=out) + +# bitwise_or +has_bitwise_or = hasattr(mindspore.mint, 'bitwise_or') +def bitwise_or(input, other, *, out=None): + if use_pyboost() and has_bitwise_or: + return call_ms_func(mindspore.mint.bitwise_or, input, other, out=out) + return call_ms_func(ops.bitwise_or, input, other, out=out) + +# bitwise_xor +has_bitwise_xor = hasattr(mindspore.mint, 'bitwise_xor') +def bitwise_xor(input, other, *, out=None): + if use_pyboost() and has_bitwise_xor: + return call_ms_func(mindspore.mint.bitwise_xor, input, other, out=out) + return call_ms_func(ops.bitwise_xor, input, other, out=out) + +# bitwise_left_shift +def bitwise_left_shift(input, other): + return ops.bitwise_left_shift(input, other) + +# bitwise_right_shift +def bitwise_right_shift(input, other): + return ops.bitwise_right_shift(input, other) + +# ceil +has_ceil = hasattr(mindspore.mint, 'ceil') +def ceil(input, *, out=None): + if use_pyboost() and has_ceil: + return call_ms_func(mindspore.mint.ceil, input, out=out) + return call_ms_func(ops.ceil, input, out=out) + +# clamp +has_clamp = hasattr(mindspore.mint, 'clamp') +def clamp(input, min=None, max=None, *, out=None): + if use_pyboost() and has_clamp: + return call_ms_func(mindspore.mint.clamp, input, min, max, out=out) + return call_ms_func(ops.clamp, input, min, max, out=out) + +# clip +has_clip = hasattr(mindspore.mint, 'clip') +def clip(input, min=None, max=None): + return clamp(input, min, max) + +# conj_physical + + +# copysign + + +# cos +has_cos = hasattr(mindspore.mint, 'cos') +def cos(input, *, out=None): + if use_pyboost() and has_cos: + return call_ms_func(mindspore.mint.cos, input, out=out) + return call_ms_func(ops.cos, input, out=out) + +# cosh +has_cosh = hasattr(mindspore.mint, 'cosh') +def cosh(input, *, out=None): + if use_pyboost() and has_cosh: + return call_ms_func(mindspore.mint.cosh, input, out=out) + return call_ms_func(ops.cosh, input, out=out) + +# deg2rad +def deg2rad(input): + return ops.deg2rad(input) + +# div +has_div = hasattr(mindspore.mint, 'div') +def div(input, other, *, rounding_mode=None, out=None): + if use_pyboost() and has_div: + return call_ms_func(mindspore.mint.div, input, other, rounding_mode=rounding_mode, out=out) + return call_ms_func(ops.div, input, other, rounding_mode=rounding_mode, out=out) + +# divide +has_divide = hasattr(mindspore.mint, 'divide') +def divide(input, other, rounding_mode=None): + return div(input, other, rounding_mode=rounding_mode) + +# digamma +def digamma(input): + return ops.digamma(input) + +# erf +has_erf = hasattr(mindspore.mint, 'erf') +def erf(input, *, out=None): + if use_pyboost() and has_erf: + return call_ms_func(mindspore.mint.erf, input, out=out) + return call_ms_func(ops.erf, input, out=out) + +# erfc +has_erfc = hasattr(mindspore.mint, 'erfc') +def erfc(input, *, out=None): + if use_pyboost() and has_erfc: + return call_ms_func(mindspore.mint.erfc, input, out=out) + return call_ms_func(ops.erfc, input, out=out) + +# erfinv +has_erfinv = hasattr(mindspore.mint, 'erfinv') +def erfinv(input, *, out=None): + if use_pyboost() and has_erfinv: + return call_ms_func(mindspore.mint.erfinv, input, out=out) + return call_ms_func(ops.erfinv, input, out=out) + + +# exp +has_exp = hasattr(mindspore.mint, 'exp') +has_inplace_exp = hasattr(mindspore.Tensor, 'exp_') +def exp(input, out=None): + if has_inplace_exp: + return inplace_exp(input, out) + + if use_pyboost() and has_exp: + output = mindspore.mint.exp(input) + else: + output = ops.exp(input) + if out is not None: + # out.data = output + out.assign_value(output) + else: + return output + +def inplace_exp(input, out=None): + if out is None: + if use_pyboost() and has_exp: + output = mindspore.mint.exp(input) + else: + output = ops.exp(input) + return output + + if out is input: + return out.exp_() + else: + out.copy_(input) + return out.exp_() + +# exp2 +has_exp2 = hasattr(mindspore.mint, 'exp2') +def exp2(input): + if use_pyboost() and has_exp2: + return mindspore.mint.exp2(input) + return pow(2, input) + +# expm1 +has_expm1 = hasattr(mindspore.mint, 'expm1') +def expm1(input, *, out=None): + if use_pyboost() and has_expm1: + return call_ms_func(mindspore.mint.expm1, input, out=out) + return call_ms_func(ops.expm1, input, out=out) + +# fake_quantize_per_channel_affine + + +# fake_quantize_per_tensor_affine + + +# fix + + +# float_power +has_float_power = hasattr(mindspore.mint, 'float_power') +def float_power(input, exponent): + if use_pyboost() and has_float_power: + return mindspore.mint.float_power(input, exponent) + return ops.float_power(input, exponent) + +# floor +has_floor = hasattr(mindspore.mint, 'floor') +def floor(input, *, out=None): + if use_pyboost() and has_floor: + return call_ms_func(mindspore.mint.floor, input, out=out) + return call_ms_func(ops.floor, input, out=out) + +# floor_divide +def floor_divide(input, other): + return ops.floor_divide(input, other) + +# fmod +has_fmod = hasattr(mindspore.mint, 'fmod') +def fmod(input, other): + if use_pyboost() and has_fmod: + return mindspore.mint.fmod(input, other) + return ops.fmod(input, other) + +# frac +has_frac = hasattr(mindspore.mint, 'frac') +def frac(input): + if use_pyboost() and has_frac: + return mindspore.mint.frac(input) + return fmod(input, 1) + +# frexp + + +# imag +def imag(input): + return ops.imag(input) + +# ldexp + + +# lerp +has_lerp = hasattr(mindspore.mint, 'lerp') +def lerp(input, end, weight): + if use_pyboost() and has_lerp: + return mindspore.mint.lerp(input, end, weight) + return ops.lerp(input, end, weight) + +# lgamma +def lgamma(input): + return ops.lgamma(input) + +# log +has_log = hasattr(mindspore.mint, 'log') +def log(input, *, out=None): + if use_pyboost() and has_log: + return call_ms_func(mindspore.mint.log, input, out=out) + return call_ms_func(ops.log, input, out=out) + +# log10 + +# log1p +has_log1p = hasattr(mindspore.mint, 'log1p') +def log1p(input, *, out=None): + if use_pyboost() and has_log1p: + return call_ms_func(mindspore.mint.log1p, input, out=out) + return call_ms_func(ops.log1p, input, out=out) + +# log2 +has_log2 = hasattr(mindspore.mint, 'log2') +def log2(input): + if use_pyboost() and has_log2: + return mindspore.mint.log2(input) + return ops.log2(input) + +# logaddexp + + +# logaddexp2 + + +# logical_and +has_logical_and = hasattr(mindspore.mint, 'logical_and') +def logical_and(input, other, *, out=None): + if use_pyboost() and has_logical_and: + return call_ms_func(mindspore.mint.logical_and, input, other, out=out) + return call_ms_func(ops.logical_and, input, other, out=out) + +# logical_not +has_logical_not = hasattr(mindspore.mint, 'logical_not') +def logical_not(input, *, out=None): + if use_pyboost() and has_logical_not: + return call_ms_func(mindspore.mint.logical_not, input, out=out) + return call_ms_func(ops.logical_not, input, out=out) + +# logical_or +has_logical_or = hasattr(mindspore.mint, 'logical_or') +def logical_or(input, other, *, out=None): + if use_pyboost() and has_logical_or: + return call_ms_func(mindspore.mint.logical_or, input, other, out=out) + return call_ms_func(ops.logical_or, input, other, out=out) + +# logical_xor +has_logical_xor = hasattr(mindspore.mint, 'logical_xor') +def logical_xor(input, other, *, out=None): + if use_pyboost() and has_logical_xor: + return call_ms_func(mindspore.mint.logical_xor, input, other, out=out) + return call_ms_func(ops.logical_xor, input, other, out=out) + +# logit +def logit(input, eps=None): + return ops.logit(input, eps) + +# hypot +def hypot(input, other): + return ops.hypot(input, other) + +# i0 + +# igamma +def igamma(input, other): + return ops.igamma(input, other) + +# igammac +def igammac(input, other): + return ops.igammac(input, other) + +# mul +has_mul = hasattr(mindspore.mint, 'mul') +def mul(input, other, *, out=None): + if use_pyboost() and has_mul: + return call_ms_func(mindspore.mint.mul, input, other, out=out) + return call_ms_func(ops.mul, input, other, out=out) + +# multiply +def multiply(input, other): + return mul(input, other) + +# mvlgamma +def mvlgamma(input, p): + return ops.mvlgamma(input, p) + +# nan_to_num +has_nan_to_num = hasattr(mindspore.mint, 'nan_to_num') +def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): + if use_pyboost() and has_nan_to_num: + return call_ms_func(mindspore.mint.nan_to_num, input, nan, posinf, neginf, out=out) + return call_ms_func(ops.nan_to_num, input, nan, posinf, neginf, out=out) + +# neg +has_neg = hasattr(mindspore.mint, 'neg') +def neg(input, *, out=None): + if use_pyboost() and has_neg: + return call_ms_func(mindspore.mint.neg, input, out=out) + return call_ms_func(ops.neg, input, out=out) + +# negative +has_negative = hasattr(mindspore.mint, 'negative') +def negative(input): + return neg(input) + +# nextafter +def nextafter(input, other): + return ops.nextafter(input, other) + +# polygamma +def polygamma(n, input): + return ops.polygamma(n, input) + +# positive +def positive(input): + return input + +# pow +has_pow = hasattr(mindspore.mint, 'pow') +def pow(input, exponent, *, out=None): + if use_pyboost() and has_pow: + return call_ms_func(mindspore.mint.pow, input, exponent, out=out) + return call_ms_func(ops.pow, input, exponent, out=out) + +# quantized_batch_norm + + +# quantized_max_pool1d + + +# quantized_max_pool2d + + +# rad2deg +def rad2deg(input): + return ops.rad2deg(input) + +# real +def real(input): + return ops.real(input) + +# reciprocal +has_reciprocal = hasattr(mindspore.mint, 'reciprocal') +def reciprocal(input, *, out=None): + if use_pyboost() and has_reciprocal: + return call_ms_func(mindspore.mint.reciprocal, input, out=out) + return call_ms_func(ops.reciprocal, input, out=out) + +# remainder +has_remainder = hasattr(mindspore.mint, 'remainder') +def remainder(input, other, *, out=None): + if use_pyboost() and has_remainder: + return call_ms_func(mindspore.mint.remainder, input, other, out=out) + return call_ms_func(ops.remainder, input, other, out=out) + +# round +has_round = hasattr(mindspore.mint, 'round') +def round(input, *, decimals=0): + if use_pyboost() and has_round: + return mindspore.mint.round(input, decimals=decimals) + return ops.round(input, decimals=decimals) + +# rsqrt +has_rsqrt = hasattr(mindspore.mint, 'rsqrt') +def rsqrt(input, *, out=None): + if use_pyboost() and has_rsqrt: + return call_ms_func(mindspore.mint.rsqrt, input, out=out) + return call_ms_func(ops.rsqrt, input, out=out) + +# sigmoid +has_sigmoid = hasattr(mindspore.mint, 'sigmoid') +def sigmoid(input, *, out=None): + if use_pyboost() and has_sigmoid: + return call_ms_func(mindspore.mint.sigmoid, input, out=out) + return call_ms_func(ops.sigmoid, input, out=out) + +# sign +has_sign = hasattr(mindspore.mint, 'sign') +def sign(input, *, out=None): + if use_pyboost() and has_sign: + return call_ms_func(mindspore.mint.sign, input, out=out) + return call_ms_func(ops.sign, input, out=out) + +# sgn + +# signbit + +# sin +has_sin = hasattr(mindspore.mint, 'sin') +def sin(input, *, out=None): + if use_pyboost() and has_sin: + return call_ms_func(mindspore.mint.sin, input, out=out) + return call_ms_func(ops.sin, input, out=out) + +# sinc +has_sinc = hasattr(mindspore.mint, 'sinc') +def sinc(input, *, out=None): + if use_pyboost() and has_sinc: + return call_ms_func(mindspore.mint.sinc, input, out=out) + return call_ms_func(ops.sinc, input, out=out) + +# sinh +has_sinh = hasattr(mindspore.mint, 'sinh') +def sinh(input, *, out=None): + if use_pyboost() and has_sinh: + return call_ms_func(mindspore.mint.sinh, input, out=out) + return call_ms_func(ops.sinh, input, out=out) + +# softmax +def softmax(input, dim, *, dtype=None): + if use_pyboost(): + return mindspore.mint.nn.functional.softmax(input, dim, dtype=dtype) + return ops.softmax(input, dim, dtype=dtype) + +# sqrt +has_sqrt = hasattr(mindspore.mint, 'sqrt') +def sqrt(input, *, out=None): + if use_pyboost() and has_sqrt: + return call_ms_func(mindspore.mint.sqrt, input, out=out) + return call_ms_func(ops.sqrt, input, out=out) + +# square +has_square = hasattr(mindspore.mint, 'square') +def square(input, *, out=None): + if use_pyboost() and has_square: + return call_ms_func(mindspore.mint.square, input, out=out) + return call_ms_func(ops.square, input, out=out) + +# sub +has_sub = hasattr(mindspore.mint, 'sub') +def sub(input, other, *, alpha=1, out=None): + if use_pyboost() and has_sub: + return call_ms_func(mindspore.mint.sub, input, other, alpha=alpha, out=out) + return call_ms_func(ops.sub, input, other, out=out) + +# subtract +def subtract(input, other): + return sub(input, other) + +# tan +has_tan = hasattr(mindspore.mint, 'tan') +def tan(input, *, out=None): + if use_pyboost() and has_tan: + return call_ms_func(mindspore.mint.tan, input, out=out) + return call_ms_func(ops.tan, input, out=out) + +# tanh +has_tanh = hasattr(mindspore.mint, 'tanh') +def tanh(input, *, out=None): + if use_pyboost() and has_tanh: + return call_ms_func(mindspore.mint.tanh, input, out=out) + return call_ms_func(ops.tanh, input, out=out) + +# true_divide +def true_divide(input, other): + return div(input, other) + +# trunc +has_trunc = hasattr(mindspore.mint, 'trunc') +def trunc(input, *, out=None): + if use_pyboost() and has_trunc: + return call_ms_func(mindspore.mint.trunc, input, out=out) + return call_ms_func(ops.trunc, input, out=out) + +# xlogy +has_xlogy = hasattr(mindspore.mint, 'xlogy') +def xlogy(input, other, *, out=None): + if use_pyboost() and has_xlogy: + return call_ms_func(mindspore.mint.xlogy, input, other, out=out) + return call_ms_func(ops.xlogy, input, other, out=out) + +# relu +def relu(input): + if use_pyboost(): + return mindspore.mint.nn.functional.relu(input) + return ops.relu(input) + +__all__ = ['abs', 'absolute', 'acos', 'acosh', 'add', 'addcdiv', 'addcmul', 'angle', 'arccosh', 'arcsin', 'arcsinh', 'arctan', 'arctan2', 'arctanh', 'arrcos', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'ceil', 'clamp', 'clip', 'cos', 'cosh', 'deg2rad', 'digamma', 'div', 'divide', 'erf', 'erfc', 'erfinv', 'exp', 'exp2', 'expm1', 'float_power', 'floor', 'floor_divide', 'fmod', 'frac', 'hypot', 'igamma', 'igammac', 'imag', 'lerp', 'lgamma', 'log', 'log1p', 'log2', 'logical_and', 'logical_not', 'logical_or', 'logical_xor', 'logit', 'mul', 'multiply', 'mvlgamma', 'nan_to_num', 'neg', 'negative', 'nextafter', 'polygamma', 'positive', 'pow', 'rad2deg', 'real', 'reciprocal', 'remainder', 'round', 'rsqrt', 'sigmoid', 'sign', 'sin', 'sinc', 'sinh', 'softmax', 'sqrt', 'square', 'sub', 'subtract', 'tan', 'tanh', 'true_divide', 'trunc', 'xlogy', 'relu'] \ No newline at end of file diff --git a/mindnlp/core/ops/random.py b/mindnlp/core/ops/random.py new file mode 100644 index 000000000..f66f58f30 --- /dev/null +++ b/mindnlp/core/ops/random.py @@ -0,0 +1,136 @@ +"""random op""" +import numpy as np +import mindspore +from mindspore import ops +from mindspore.ops._primitive_cache import _get_cache_prim +from ..configs import use_pyboost, DEVICE_TARGET +from .other import cumsum, searchsorted +from .comparison import topk +from .pointwise import div, log +from .._bind import get_default_dtype +from ._inner import call_ms_func + +# bernoulli +has_bernoulli = hasattr(mindspore.mint, 'bernoulli') +def bernoulli(input, *, generator=None, out=None): + if use_pyboost() and has_bernoulli: + return call_ms_func(mindspore.mint.bernoulli, input, generator=generator, out=out) + random_numbers = rand(*input.shape, dtype=mindspore.float32) + samples = random_numbers < 0.5 + samples = samples.int() + if out is None: + return samples + else: + return out.copy_(samples) + +# multinomial +has_multinomial = hasattr(mindspore.mint, 'multinomial') +def multinomial(input, num_samples, replacement=False, *, generator=None): + """custom multinomial""" + if use_pyboost() and has_multinomial: + return mindspore.mint.multinomial(input, num_samples, replacement=replacement, generator=generator) + if replacement: + # with replacement + cumulative_probs = cumsum(input, dim=-1) + uniform_samples = rand(*input.shape[:-1] + (num_samples,)) + if cumulative_probs.dtype == mindspore.float16: + cumulative_probs = cumulative_probs.astype(mindspore.float32) + samples = searchsorted(cumulative_probs, uniform_samples, right=True) + else: + # without replacement + n_dist = 1 + if input.ndim > 1: + n_dist = input.shape[-2] + random_uniform = rand(*(n_dist * input.shape[-1],)) + if n_dist != 1: + random_uniform = random_uniform.reshape(n_dist, input.shape[-1]) + + vals = div(log(random_uniform), input + 1e-10) + _, samples = topk(vals, num_samples) + + return samples.astype(mindspore.int64) + +# normal +has_normal = hasattr(mindspore.mint, 'normal') +def normal(mean=0.0, std=1.0, size=None, *, generator=None, out=None): + if use_pyboost() and has_normal: + return call_ms_func(mindspore.mint.normal, mean, std, size, generator, out=out) + if size is None: + if isinstance(mean, mindspore.Tensor): + size = mean.shape + else: + size = () + return call_ms_func(ops.normal, size, mean, std, out=out) + +# poisson + + +# rand +has_rand = hasattr(mindspore.mint, 'rand') +def rand(*size, generator=None, out=None, dtype=None, device=None, pin_memory=False): + if size[0] == []: + size = () + elif isinstance(size[0], (tuple, list)): + size = size[0] + if dtype is None: + dtype = get_default_dtype() + if use_pyboost() and has_rand: + return call_ms_func(mindspore.mint.rand, *size, generator=generator, dtype=dtype, out=out) + return call_ms_func(ops.rand, *size, dtype=dtype, out=out) + +# rand_like +has_rand_like = hasattr(mindspore.mint, 'rand_like') +def rand_like(input, *, dtype=None): + if use_pyboost() and has_rand_like: + return mindspore.mint.rand_like(input, dtype=dtype) + return ops.rand_like(input, dtype=dtype) + +# randint +has_randint = hasattr(mindspore.mint, 'randint') +def randint(*args, **kwargs): + if use_pyboost() and has_randint: + return mindspore.mint.randint(*args, **kwargs) + return ops.randint(*args, **kwargs) + +# randint_like +def randint_like(*args, **kwargs): + if use_pyboost() and has_randint: + return mindspore.mint.randint_like(*args, **kwargs) + return ops.randint_like(*args, **kwargs) + +# randn +has_randn = hasattr(mindspore.mint, 'randn') +def randn(*size, generator=None, dtype=None, **kwargs): + if dtype is None: + dtype = get_default_dtype() + if use_pyboost() and has_randn: + return mindspore.mint.randn(*size, generator=generator, dtype=dtype) + return ops.randn(*size, dtype=dtype) + +# randn_like +has_randn_like = hasattr(mindspore.mint, 'randn_like') +def randn_like(input, *, dtype=None): + if use_pyboost() and has_randn_like: + return mindspore.mint.randn_like(input, dtype=dtype) + return ops.randn_like(input, dtype=dtype) + +# randperm +has_randperm = hasattr(mindspore.mint, 'randperm') +def randperm(n, *, generator=None, dtype=mindspore.int64): + """randperm""" + if use_pyboost() and has_randperm: + return mindspore.mint.randperm(n, generator=generator, dtype=dtype) + if DEVICE_TARGET == 'CPU': + seed, offset = 0, 0 + randperm_v2_op = _get_cache_prim(ops.RandpermV2)(seed, offset, dtype) + return randperm_v2_op(n) + + randperm_op = _get_cache_prim(ops.Randperm)(max_length=n, dtype=dtype) + return randperm_op(mindspore.tensor([n])) + +def gamma(shape, alpha, beta): + if DEVICE_TARGET != 'Ascend': + return mindspore.tensor(np.random.gamma(alpha, 1/beta, shape)) + return ops.gamma(shape, alpha, beta) + +__all__ = ['bernoulli', 'gamma', 'multinomial', 'normal', 'rand', 'rand_like', 'randint', 'randn', 'randn_like', 'randperm', 'randint_like'] diff --git a/mindnlp/core/ops/reduction.py b/mindnlp/core/ops/reduction.py new file mode 100644 index 000000000..1789c991a --- /dev/null +++ b/mindnlp/core/ops/reduction.py @@ -0,0 +1,216 @@ +"""reduction op""" +import mindspore +from mindspore import ops +from mindspore.ops._primitive_cache import _get_cache_prim +from ..configs import use_pyboost, DEVICE_TARGET + +from ._inner import call_ms_func + +# argmax +has_argmax = hasattr(mindspore.mint, 'argmax') +def argmax(input, dim=None, keepdim=False): + if use_pyboost() and has_argmax: + return mindspore.mint.argmax(input, dim, keepdim) + return ops.argmax(input, dim, keepdim) + +# argmin +has_argmin = hasattr(mindspore.mint, 'argmin') +def argmin(input, dim=None, keepdim=False): + if use_pyboost() and has_argmin: + return mindspore.mint.argmin(input, dim, keepdim) + return ops.argmin(input, dim, keepdim) + +# amax +has_amax = hasattr(mindspore.mint, 'amax') +def amax(input, dim, keepdim=False): + if use_pyboost() and has_amax: + return mindspore.mint.amax(input, dim, keepdim) + return ops.amax(input, dim, keepdim) + +# amin +has_amin = hasattr(mindspore.mint, 'amin') +def amin(input, dim, keepdim=False): + if use_pyboost() and has_amin: + return mindspore.mint.amin(input, dim, keepdim) + return ops.amin(input, dim, keepdim) + +# aminmax +def aminmax(input, *, dim=None, keepdim=False): + if dim is None: + dim = () + return amin(input, dim, keepdim), amax(input, dim, keepdim) + +# all +has_all = hasattr(mindspore.mint, 'all') +def all(input, dim=None, keepdim=False, *, dtype=None): + if use_pyboost() and has_all: + return mindspore.mint.all(input, dim, keepdim).to(input.dtype) + return ops.all(input, dim, keepdim).to(input.dtype) + +# any +has_any = hasattr(mindspore.mint, 'any') +def any(input, dim=None, keepdim=False, *, out=None): + if use_pyboost() and has_any: + if dim is None: + return call_ms_func(mindspore.mint.any, input, out=out) + else: + return call_ms_func(mindspore.mint.any, input, dim, keepdim, out=out) + return call_ms_func(ops.any, input, dim, out=out) + +# max +has_max = hasattr(mindspore.mint, 'max') +def max(*args, **kwargs): + return mindspore.mint.max(*args, **kwargs) + +# min +has_min = hasattr(mindspore.mint, 'min') +def min(*args, **kwargs): + return mindspore.mint.min(*args, **kwargs) + +# dist + + +# logsumexp +has_logsumexp = hasattr(mindspore.mint, 'logsumexp') +def logsumexp(input, dim, keepdim=False): + if use_pyboost() and has_logsumexp: + return mindspore.mint.logsumexp(input, dim, keepdim) + return ops.logsumexp(input, dim, keepdim) + +# mean +has_mean = hasattr(mindspore.mint, 'mean') +def mean(input, dim=None, keepdim=False, *, dtype=None): + if use_pyboost() and has_mean: + return mindspore.mint.mean(input, dim, keepdim, dtype=dtype) + out = ops.mean(input, dim, keepdim) + if dtype is not None: + out = out.astype(dtype) + return out + +# nanmean + + +# median +has_median = hasattr(mindspore.mint, 'median') +def median(input, dim=-1, keepdim=False): + if use_pyboost() and has_median: + return mindspore.mint.median(input, dim, keepdim) + return ops.median(input, dim, keepdim) + +# nanmedian +def nanmedian(input, dim=-1, keepdim=False): + return ops.nanmedian(input, dim, keepdim) + +# mode + + +# norm +has_norm = hasattr(mindspore.mint, 'norm') +def norm(input, p='fro', dim=None, keepdim=False, out=None, dtype=None): + if use_pyboost() and has_norm: + return call_ms_func(mindspore.mint.norm, input, p, dim, keepdim, out=out, dtype=dtype) + return call_ms_func(ops.norm, input, p, dim, keepdim, out=out, dtype=dtype) + +# nansum +has_nansum = hasattr(mindspore.mint, 'nansum') +def nansum(input, dim=None, keepdim=False, *, dtype=None): + if use_pyboost() and has_nansum: + return mindspore.mint.nansum(input, dim, keepdim, dtype=dtype) + return ops.nansum(input, dim, keepdim, dtype=dtype) + +# prod +has_prod = hasattr(mindspore.mint, 'prod') +def prod(input, dim=None, keepdim=False, *, dtype=None): + if use_pyboost() and has_prod: + return mindspore.mint.prod(input, dim, keepdim, dtype=dtype) + return ops.prod(input, dim, keepdim).to(dtype) + +# quantile +def quantile(input, q, dim=None, keepdim=False, *, interpolation='linear'): + return ops.quantile(input, q, dim, keepdim) + +# nanquantile +def nanquantile(input, q, dim=None, keepdim=False, *, interpolation='linear'): + return ops.quantile(input, q, dim, keepdim) + +# std +has_std = hasattr(mindspore.mint, 'std') +def std(input, dim=None, *, correction=1, keepdim=False): + if use_pyboost() and has_std: + return mindspore.mint.std(input, dim=dim, correction=correction, keepdim=keepdim) + if DEVICE_TARGET == 'GPU': + unbiased = bool(correction) + if dim is None: + dim = () + if isinstance(dim, int): + dim = (dim,) + _std = _get_cache_prim(ops.ReduceStd)(dim, unbiased, keepdim) + _std.set_device('CPU') + return _std(input)[0] + return ops.std(input, dim, correction, keepdim) + +# std_mean +has_std_mean = hasattr(mindspore.mint, 'std_mean') +def std_mean(input, dim=None, *, correction=1, keepdim=False): + if use_pyboost and has_std_mean: + return mindspore.mint.std_mean(input, dim=dim, correction=correction, keepdim=keepdim) + return std(input, dim, correction=correction, keepdim=keepdim), \ + mean(input, dim, keepdim) + +# sum +has_sum = hasattr(mindspore.mint, 'sum') +def sum(input, dim=None, keepdim=False, *, dtype=None): + if 0 in input.shape: + return mindspore.tensor(0, dtype=dtype) + if use_pyboost() and has_sum: + return mindspore.mint.sum(input, dim, keepdim, dtype=dtype) + return ops.sum(input, dim, keepdim, dtype=dtype) + +# unique +has_unique = hasattr(mindspore.mint, 'unique') +def unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None): + if use_pyboost() and has_unique: + return mindspore.mint.unique(input, sorted, return_inverse, return_counts, dim) + + out, inverse = ops.unique(input) + outs = (out,) + if return_inverse: + outs += (inverse,) + if return_counts: + counts = (out == input).sum(0, keepdims=True) + outs += (counts,) + return outs if len(outs) > 1 else outs[0] + +# unique_consecutive +has_unique_consecutive = hasattr(mindspore.mint, 'unique_consecutive') +def unique_consecutive(input, return_inverse=False, return_counts=False, dim=None): + if use_pyboost() and has_unique_consecutive: + return mindspore.mint.unique_consecutive(input, return_inverse, return_counts, dim) + return ops.unique_consecutive(input, return_inverse, return_counts, dim) + +# var +has_var = hasattr(mindspore.mint, 'var') +def var(input, dim=None, *, correction=1, keepdim=False): + if use_pyboost and has_var: + return mindspore.mint.var(input, dim=dim, correction=correction, keepdim=keepdim) + return pow(std(input, dim, correction=correction, keepdim=keepdim), 2) + +# var_mean +has_var_mean = hasattr(mindspore.mint, 'var_mean') +def var_mean(input, dim=None, *, correction=1, keepdim=False): + if use_pyboost and has_var_mean: + return mindspore.mint.var_mean(input, dim=dim, correction=correction, keepdim=keepdim) + return pow(std(input, dim, correction=correction, keepdim=keepdim), 2), \ + mean(input, dim, keepdim) + +# count_nonzero +has_count_nonzero = hasattr(mindspore.mint, 'count_nonzero') +def count_nonzero(input, dim=None): + if use_pyboost() and has_count_nonzero: + return mindspore.mint.count_nonzero(input, dim) + if dim is None: + dim = () + return ops.count_nonzero(input, dim) + +__all__ = ['all', 'amax', 'amin', 'aminmax', 'any', 'argmax', 'argmin', 'count_nonzero', 'logsumexp', 'max', 'mean', 'median', 'min', 'nanmedian', 'nanquantile', 'nansum', 'norm', 'prod', 'quantile', 'std', 'std_mean', 'sum', 'unique', 'unique_consecutive', 'var', 'var_mean'] + \ No newline at end of file diff --git a/mindnlp/core/ops/spectral.py b/mindnlp/core/ops/spectral.py new file mode 100644 index 000000000..7e8f6764c --- /dev/null +++ b/mindnlp/core/ops/spectral.py @@ -0,0 +1,28 @@ +"""spectral""" +from mindspore import ops +# stft +def stft(input, n_fft, hop_length=None, win_length=None, + window=None, center=True, pad_mode='reflect', + normalized=False, onesided=None, return_complex=None): + return ops.stft(input, n_fft, hop_length, win_length, window, + center, pad_mode.upper(), normalized, onesided, return_complex) + +# istft + + +# bartlett_window + + +# blackman_window + + +# hamming_window + + +# hann_window +def hann_window(window_length, periodic=True, *, dtype=None): + return ops.hann_window(window_length, periodic, dtype=dtype) + +# kaiser_window + +__all__ = ['hann_window', 'stft'] \ No newline at end of file diff --git a/mindnlp/core/ops/tensor.py b/mindnlp/core/ops/tensor.py new file mode 100644 index 000000000..f204a22e7 --- /dev/null +++ b/mindnlp/core/ops/tensor.py @@ -0,0 +1,17 @@ +"""tensor op""" +import mindspore +from mindspore._c_expression import typing # pylint: disable=no-name-in-module, import-error + +def is_floating_point(input): + return isinstance(input.dtype, typing.Float) + +def is_tensor(input): + return isinstance(input, mindspore.Tensor) + +def numel(input): + return input.numel() + +def as_tensor(data, dtype=None, **kwarg): + return mindspore.Tensor(data, dtype) + +__all__ = ['as_tensor', 'is_floating_point', 'is_tensor', 'numel'] \ No newline at end of file diff --git a/mindnlp/core/ops/utils.py b/mindnlp/core/ops/utils.py new file mode 100644 index 000000000..a176e28e1 --- /dev/null +++ b/mindnlp/core/ops/utils.py @@ -0,0 +1,21 @@ +def sum_to(x, shape): + """Sum elements along axes to output an array of a given shape. + + Args: + x (ndarray): Input array. + shape: + + Returns: + ndarray: Output array of the shape. + """ + if x is None: + return None + ndim = len(shape) + lead = x.dim() - ndim + lead_axis = tuple(range(lead)) + + axis = tuple([i + lead for i, sx in enumerate(shape) if sx == 1]) + y = x.sum(lead_axis + axis, keepdim=True) + if lead > 0: + y = y.squeeze(lead_axis) + return y diff --git a/mindnlp/core/optim/__init__.py b/mindnlp/core/optim/__init__.py new file mode 100644 index 000000000..8b5ea8b7e --- /dev/null +++ b/mindnlp/core/optim/__init__.py @@ -0,0 +1,6 @@ +"""optimizers""" +from .optimizer import Optimizer +from .sgd import SGD +from .adam import Adam +from .adamw import AdamW +from .lr_scheduler import * diff --git a/mindnlp/core/optim/adam.py b/mindnlp/core/optim/adam.py new file mode 100644 index 000000000..bc935d0b8 --- /dev/null +++ b/mindnlp/core/optim/adam.py @@ -0,0 +1,135 @@ +"""adam""" +# pylint: disable=unneeded-not, use-dict-literal +# mypy: allow-untyped-defs +from typing import Tuple, Union + +from mindnlp import core +from mindnlp.core import Tensor +from .. import ops +from .optimizer import ( + _get_scalar_dtype, + Optimizer, + ParamsT, +) + +__all__ = ["Adam"] + + + +class Adam(Optimizer): + def __init__( + self, + params: ParamsT, + lr: Union[float, Tensor] = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0, + amsgrad: bool = False, + *, + maximize: bool = False, + ): + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + amsgrad=amsgrad, + maximize=maximize, + ) + super().__init__(params, defaults) + + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("amsgrad", False) + group.setdefault("maximize", False) + fused = group.setdefault("fused", None) + for p in group["params"]: + p_state = self.state.get(p, []) + if len(p_state) != 0 and not ops.is_tensor(p_state["step"]): + step_val = float(p_state["step"]) + p_state["step"] = ( + core.tensor( + step_val, + dtype=_get_scalar_dtype(is_fused=fused), + ) + if group["capturable"] or group["fused"] + else core.tensor(step_val, dtype=_get_scalar_dtype()) + ) + + def step(self, grads=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + + loss = None + + for group in self.param_groups: + amsgrad = group['amsgrad'] + maximize = group["maximize"] + + for p in group['params']: + grad = p.grad if not maximize else -p.grad + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = ops.zeros_like(p) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = ops.zeros_like(p) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = ops.zeros_like(p) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + if amsgrad: + max_exp_avg_sq = state['max_exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + + if group['weight_decay'] != 0: + grad = grad.add(p, alpha=group['weight_decay']) + # # Decay the first and second moment running average coefficient + # exp_avg.mul_(beta1).add_(grad, 1 - beta1) + # exp_avg_sq.mul_(beta2).addcmul_(grad, grad, 1 - beta2) + # if amsgrad: + # # Maintains the maximum of all 2nd moment running avg. till now + # core.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + # # Use the max. for normalizing running avg. of gradient + # denom = max_exp_avg_sq.sqrt().add_(group['eps']) + # else: + # denom = exp_avg_sq.sqrt().add_(group['eps']) + + # bias_correction1 = 1 - beta1 ** state['step'] + # bias_correction2 = 1 - beta2 ** state['step'] + # step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 + + # p.addcdiv_(exp_avg, denom, -step_size) + + beta1_power = beta1 ** state['step'] + beta2_power = beta2 ** state['step'] + if amsgrad: + ops.optim.raw_adam_amsgrad(p, exp_avg, exp_avg_sq, max_exp_avg_sq, + beta1_power, beta2_power, group['lr'], beta1, beta2, group['eps'], grad) + else: + ops.optim.raw_adam(p, exp_avg, exp_avg_sq, beta1_power, beta2_power, + group['lr'], beta1, beta2, group['eps'], grad) diff --git a/mindnlp/core/optim/adamw.py b/mindnlp/core/optim/adamw.py new file mode 100644 index 000000000..bdc7594b8 --- /dev/null +++ b/mindnlp/core/optim/adamw.py @@ -0,0 +1,116 @@ +"""adamw optimizer""" +# pylint: disable=unneeded-not, use-dict-literal +# mypy: allow-untyped-defs +from typing import Tuple, Union + +from mindnlp import core +from mindnlp.core import Tensor +from .. import ops +from .optimizer import ( + _get_scalar_dtype, + Optimizer, + ParamsT, +) + +__all__ = ["AdamW"] + + + +class AdamW(Optimizer): + def __init__( + self, + params: ParamsT, + lr: Union[float, Tensor] = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 1e-2, + amsgrad: bool = False, + *, + maximize: bool = False, + ): + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + amsgrad=amsgrad, + maximize=maximize, + ) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("amsgrad", False) + group.setdefault("maximize", False) + group.setdefault("foreach", None) + group.setdefault("capturable", False) + group.setdefault("differentiable", False) + fused = group.setdefault("fused", None) + for p in group["params"]: + p_state = self.state.get(p, []) + if len(p_state) != 0 and not ops.is_tensor(p_state["step"]): + step_val = float(p_state["step"]) + p_state["step"] = ( + core.tensor( + step_val, + dtype=_get_scalar_dtype(is_fused=fused), + ) + if group["capturable"] or group["fused"] + else core.tensor(step_val, dtype=_get_scalar_dtype()) + ) + + def step(self, grads=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + + for group in self.param_groups: + amsgrad = group['amsgrad'] + maximize = group["maximize"] + for p in group['params']: + grad = p.grad if not maximize else -p.grad + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = ops.zeros_like(p) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = ops.zeros_like(p) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = ops.zeros_like(p) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + if amsgrad: + max_exp_avg_sq = state['max_exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + + if group['weight_decay'] != 0: + ops.assign(p, (1 - group['lr'] * group['weight_decay']) * p) + beta1_power = beta1 ** state['step'] + beta2_power = beta2 ** state['step'] + if amsgrad: + ops.optim.raw_adam_amsgrad(p, exp_avg, exp_avg_sq, max_exp_avg_sq, + beta1_power, beta2_power, group['lr'], beta1, beta2, group['eps'], grad) + else: + ops.optim.raw_adam(p, exp_avg, exp_avg_sq, beta1_power, beta2_power, + group['lr'], beta1, beta2, group['eps'], grad) diff --git a/mindnlp/core/optim/lr_scheduler.py b/mindnlp/core/optim/lr_scheduler.py new file mode 100644 index 000000000..36ecdbe45 --- /dev/null +++ b/mindnlp/core/optim/lr_scheduler.py @@ -0,0 +1,2166 @@ +"""lr scheduler""" +# mypy: allow-untyped-defs +import math +import types +import warnings +from bisect import bisect_right +from collections import Counter +from functools import partial +from typing import ( + Any, + Callable, + cast, + Dict, + Iterable, + List, + Literal, + Optional, + Sequence, + SupportsFloat, + TypedDict, + Union, +) +from weakref import ref + +from mindnlp.core import Tensor + +from .optimizer import Optimizer + +__all__ = [ + "LambdaLR", + "MultiplicativeLR", + "StepLR", + "MultiStepLR", + "ConstantLR", + "LinearLR", + "ExponentialLR", + "SequentialLR", + "CosineAnnealingLR", + "ChainedScheduler", + "ReduceLROnPlateau", + "CyclicLR", + "CosineAnnealingWarmRestarts", + "OneCycleLR", + "PolynomialLR", + "LRScheduler", +] + +EPOCH_DEPRECATION_WARNING = ( + "The epoch parameter in `scheduler.step()` was not necessary and is being " + "deprecated where possible. Please use `scheduler.step()` to step the " + "scheduler. During the deprecation, if epoch is different from None, the " + "closed form is used instead of the new chainable form, where available. " + "Please open an issue if you are unable to replicate your use case: " + "https://github.com/pytorch/pytorch/issues/new/choose." +) + +inf = float('inf') + +def _check_verbose_deprecated_warning(verbose): + """Raises a warning when verbose is not the default value.""" + if verbose != "deprecated": + warnings.warn( + "The verbose parameter is deprecated. Please use get_last_lr() " + "to access the learning rate.", + UserWarning, + ) + return verbose + return False + + +def _format_param(name: str, optimizer: Optimizer, param): + """Return correctly formatted lr/momentum for each param group.""" + + def _copy(_param): + return _param.clone() if isinstance(_param, Tensor) else _param + + if isinstance(param, (list, tuple)): + if len(param) != len(optimizer.param_groups): + raise ValueError( + f"{name} must have the same length as optimizer.param_groups. " + f"{name} has {len(param)} values, param_groups has {len(optimizer.param_groups)}." + ) + else: + param = [param] * len(optimizer.param_groups) + + return list(map(_copy, param)) + + +class LRScheduler: + _get_lr_called_within_step: bool = False + + def __init__(self, optimizer: Optimizer, last_epoch=-1, verbose="deprecated"): + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError(f"{type(optimizer).__name__} is not an Optimizer") + self.optimizer = optimizer + + # Initialize epoch and base learning rates + if last_epoch == -1: + for group in optimizer.param_groups: + initial_lr = group["lr"] + if isinstance(initial_lr, Tensor): + initial_lr = initial_lr.clone() + group.setdefault("initial_lr", initial_lr) + else: + for i, group in enumerate(optimizer.param_groups): + if "initial_lr" not in group: + raise KeyError( + "param 'initial_lr' is not specified " + f"in param_groups[{i}] when resuming an optimizer" + ) + self.base_lrs: List[float] = [ + group["initial_lr"] for group in optimizer.param_groups + ] + self.last_epoch = last_epoch + + # Following https://github.com/pytorch/pytorch/issues/20124 + # We would like to ensure that `lr_scheduler.step()` is called after + # `optimizer.step()` + def patch_track_step_called(opt: Optimizer): + if hasattr(opt.step, "_wrapped_by_lr_sched"): + # we've already patched + return opt.step + + def wrap_step(step_fn): + opt_ref = ref(self.optimizer) + func = step_fn.__func__ + + def wrapper(*args, **kwargs): + opt = opt_ref() + opt._opt_called = True # type: ignore[union-attr] + return func.__get__(opt, opt.__class__)(*args, **kwargs) + + wrapper._wrapped_by_lr_sched = True # type: ignore[attr-defined] + return wrapper + + opt.step = wrap_step(opt.step) # type: ignore[method-assign] + + patch_track_step_called(self.optimizer) + self.verbose = _check_verbose_deprecated_warning(verbose) + self._initial_step() + + def _initial_step(self): + """Initialize step counts and performs a step""" + self._step_count = 0 + self.step() + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + """ + return { + key: value for key, value in self.__dict__.items() if key != "optimizer" + } + + def load_state_dict(self, state_dict: Dict[str, Any]): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def get_last_lr(self) -> List[float]: + """Return last computed learning rate by current scheduler.""" + return self._last_lr + + def get_lr(self) -> List[float]: + # Compute learning rate using chainable form of the scheduler + raise NotImplementedError + + def print_lr( + self, + is_verbose: bool, + group: Dict[str, Any], + lr: float, + epoch: Optional[int] = None, + ): + """Display the current learning rate. + + .. deprecated:: 2.4 + ``print_lr()`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + """ + warnings.warn( + "`LRScheduler.print_lr()` is being deprecated. To fetch the learning rate, " + "please use `get_last_lr()` instead. For more details, " + "see https://github.com/pytorch/pytorch/issues/99270.", + UserWarning, + ) + if is_verbose: + if epoch is None: + print(f"Adjusting learning rate of group {group} to {lr:.4e}.") + else: + epoch_str = ("%.2f" if isinstance(epoch, float) else "%.5d") % epoch + print( + f"Epoch {epoch_str}: adjusting learning rate of group {group} to {lr:.4e}." + ) + + def step(self, epoch: Optional[int] = None): + # Raise a warning if old pattern is detected + # https://github.com/pytorch/pytorch/issues/20124 + if self._step_count == 1: + if not hasattr(self.optimizer.step, "_wrapped_by_lr_sched"): + warnings.warn( + "Seems like `optimizer.step()` has been overridden after learning rate scheduler " + "initialization. Please, make sure to call `optimizer.step()` before " + "`lr_scheduler.step()`. See more details at " + "https://pycore.org/docs/stable/optim.html#how-to-adjust-learning-rate", + UserWarning, + ) + + # Just check if there were two first lr_scheduler.step() calls before optimizer.step() + elif not getattr(self.optimizer, "_opt_called", False): + warnings.warn( + "Detected call of `lr_scheduler.step()` before `optimizer.step()`. " + "In PyTorch 1.1.0 and later, you should call them in the opposite order: " + "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this " + "will result in PyTorch skipping the first value of the learning rate schedule. " + "See more details at " + "https://pycore.org/docs/stable/optim.html#how-to-adjust-learning-rate", + UserWarning, + ) + self._step_count += 1 + + with _enable_get_lr_call(self): + if epoch is None: + self.last_epoch += 1 + values = self.get_lr() + else: + warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning) + self.last_epoch = epoch + if hasattr(self, "_get_closed_form_lr"): + values = cast(List[float], self._get_closed_form_lr()) + else: + values = self.get_lr() + + for i, data in enumerate(zip(self.optimizer.param_groups, values)): + param_group, lr = data + if isinstance(param_group["lr"], Tensor): + lr_val = lr.item() if isinstance(lr, Tensor) else lr # type: ignore[attr-defined] + param_group["lr"].fill_(lr_val) + else: + param_group["lr"] = lr + + self._last_lr: List[float] = [ + group["lr"] for group in self.optimizer.param_groups + ] + + +def _warn_get_lr_called_within_step(lr_scheduler: LRScheduler): + if not lr_scheduler._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", + UserWarning, + stacklevel=2, + ) + + +# Including _LRScheduler for backwards compatibility +# Subclass instead of assign because we want __name__ of _LRScheduler to be _LRScheduler (assigning would make it LRScheduler). +class _LRScheduler(LRScheduler): + pass + + +class _enable_get_lr_call: + def __init__(self, o: LRScheduler): + self.o = o + + def __enter__(self): + self.o._get_lr_called_within_step = True + return self + + def __exit__(self, type, value, traceback): + self.o._get_lr_called_within_step = False + + + +class LambdaLR(LRScheduler): + """Sets the learning rate of each parameter group to the initial lr + times a given function. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + lr_lambda (function or list): A function which computes a multiplicative + factor given an integer parameter epoch, or a list of such + functions, one for each group in optimizer.param_groups. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer has two groups. + >>> lambda1 = lambda epoch: epoch // 30 + >>> lambda2 = lambda epoch: 0.95 ** epoch + >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2]) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__( + self, + optimizer: Optimizer, + lr_lambda: Union[Callable[[int], float], List[Callable[[int], float]]], + last_epoch=-1, + verbose="deprecated", + ): + self.optimizer = optimizer + + self.lr_lambdas: List[Callable[[int], float]] + if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple): + self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups) + else: + if len(lr_lambda) != len(optimizer.param_groups): + raise ValueError( + f"Expected {len(optimizer.param_groups)} lr_lambdas, but got {len(lr_lambda)}" + ) + self.lr_lambdas = list(lr_lambda) + super().__init__(optimizer, last_epoch, verbose) + + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + The learning rate lambda functions will only be saved if they are callable objects + and not if they are functions or lambdas. + + When saving or loading the scheduler, please make sure to also save or load the state of the optimizer. + """ + + state_dict = { + key: value + for key, value in self.__dict__.items() + if key not in ("optimizer", "lr_lambdas") + } + state_dict["lr_lambdas"] = [None] * len(self.lr_lambdas) + + for idx, fn in enumerate(self.lr_lambdas): + if not isinstance(fn, types.FunctionType): + state_dict["lr_lambdas"][idx] = fn.__dict__.copy() + + return state_dict + + + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + When saving or loading the scheduler, please make sure to also save or load the state of the optimizer. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + + lr_lambdas = state_dict.pop("lr_lambdas") + self.__dict__.update(state_dict) + # Restore state_dict keys in order to prevent side effects + # https://github.com/pytorch/pytorch/issues/32756 + state_dict["lr_lambdas"] = lr_lambdas + + for idx, fn in enumerate(lr_lambdas): + if fn is not None: + self.lr_lambdas[idx].__dict__.update(fn) + + + def get_lr(self): + _warn_get_lr_called_within_step(self) + + return [ + base_lr * lmbda(self.last_epoch) + for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs) + ] + + + + +class MultiplicativeLR(LRScheduler): + """Multiply the learning rate of each parameter group by the factor given + in the specified function. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + lr_lambda (function or list): A function which computes a multiplicative + factor given an integer parameter epoch, or a list of such + functions, one for each group in optimizer.param_groups. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + + Example: + >>> # xdoctest: +SKIP + >>> lmbda = lambda epoch: 0.95 + >>> scheduler = MultiplicativeLR(optimizer, lr_lambda=lmbda) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__( + self, + optimizer: Optimizer, + lr_lambda: Union[Callable[[int], float], List[Callable[[int], float]]], + last_epoch=-1, + verbose="deprecated", + ): + self.optimizer = optimizer + + self.lr_lambdas: List[Callable[[int], float]] + if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple): + self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups) + else: + if len(lr_lambda) != len(optimizer.param_groups): + raise ValueError( + f"Expected {len(optimizer.param_groups)} lr_lambdas, but got {len(lr_lambda)}" + ) + self.lr_lambdas = list(lr_lambda) + super().__init__(optimizer, last_epoch, verbose) + + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + The learning rate lambda functions will only be saved if they are callable objects + and not if they are functions or lambdas. + """ + state_dict = { + key: value + for key, value in self.__dict__.items() + if key not in ("optimizer", "lr_lambdas") + } + state_dict["lr_lambdas"] = [None] * len(self.lr_lambdas) + + for idx, fn in enumerate(self.lr_lambdas): + if not isinstance(fn, types.FunctionType): + state_dict["lr_lambdas"][idx] = fn.__dict__.copy() + + return state_dict + + + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + lr_lambdas = state_dict.pop("lr_lambdas") + self.__dict__.update(state_dict) + # Restore state_dict keys in order to prevent side effects + # https://github.com/pytorch/pytorch/issues/32756 + state_dict["lr_lambdas"] = lr_lambdas + + for idx, fn in enumerate(lr_lambdas): + if fn is not None: + self.lr_lambdas[idx].__dict__.update(fn) + + + def get_lr(self): + _warn_get_lr_called_within_step(self) + + if self.last_epoch > 0: + return [ + group["lr"] * lmbda(self.last_epoch) + for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups) + ] + else: + return [group["lr"] for group in self.optimizer.param_groups] + + + + +class StepLR(LRScheduler): + """Decays the learning rate of each parameter group by gamma every + step_size epochs. Notice that such decay can happen simultaneously with + other changes to the learning rate from outside this scheduler. When + last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + step_size (int): Period of learning rate decay. + gamma (float): Multiplicative factor of learning rate decay. + Default: 0.1. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.05 if epoch < 30 + >>> # lr = 0.005 if 30 <= epoch < 60 + >>> # lr = 0.0005 if 60 <= epoch < 90 + >>> # ... + >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__( + self, + optimizer: Optimizer, + step_size: int, + gamma=0.1, + last_epoch=-1, + verbose="deprecated", + ): + self.step_size = step_size + self.gamma = gamma + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + _warn_get_lr_called_within_step(self) + + if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0): + return [group["lr"] for group in self.optimizer.param_groups] + return [group["lr"] * self.gamma for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + return [ + base_lr * self.gamma ** (self.last_epoch // self.step_size) + for base_lr in self.base_lrs + ] + + + + +class MultiStepLR(LRScheduler): + """Decays the learning rate of each parameter group by gamma once the + number of epoch reaches one of the milestones. Notice that such decay can + happen simultaneously with other changes to the learning rate from outside + this scheduler. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + milestones (list): List of epoch indices. Must be increasing. + gamma (float): Multiplicative factor of learning rate decay. + Default: 0.1. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.05 if epoch < 30 + >>> # lr = 0.005 if 30 <= epoch < 80 + >>> # lr = 0.0005 if epoch >= 80 + >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__( + self, + optimizer: Optimizer, + milestones: Iterable[int], + gamma=0.1, + last_epoch=-1, + verbose="deprecated", + ): + self.milestones = Counter(milestones) + self.gamma = gamma + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + _warn_get_lr_called_within_step(self) + + if self.last_epoch not in self.milestones: + return [group["lr"] for group in self.optimizer.param_groups] + return [ + group["lr"] * self.gamma ** self.milestones[self.last_epoch] + for group in self.optimizer.param_groups + ] + + def _get_closed_form_lr(self): + milestones = sorted(self.milestones.elements()) + return [ + base_lr * self.gamma ** bisect_right(milestones, self.last_epoch) + for base_lr in self.base_lrs + ] + + + + +class ConstantLR(LRScheduler): + """Multiply the learning rate of each parameter group by a small constant factor until the + number of epoch reaches a pre-defined milestone: total_iters. + Notice that such multiplication of the small constant factor can + happen simultaneously with other changes to the learning rate from outside this scheduler. + When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + factor (float): The number we multiply learning rate until the milestone. Default: 1./3. + total_iters (int): The number of steps that the scheduler multiplies the learning rate by the factor. + Default: 5. + last_epoch (int): The index of the last epoch. Default: -1. + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.025 if epoch == 0 + >>> # lr = 0.025 if epoch == 1 + >>> # lr = 0.025 if epoch == 2 + >>> # lr = 0.025 if epoch == 3 + >>> # lr = 0.05 if epoch >= 4 + >>> scheduler = ConstantLR(optimizer, factor=0.5, total_iters=4) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__( + self, + optimizer: Optimizer, + factor=1.0 / 3, + total_iters=5, + last_epoch=-1, + verbose="deprecated", + ): + if factor > 1.0 or factor < 0: + raise ValueError( + "Constant multiplicative factor expected to be between 0 and 1." + ) + + self.factor = factor + self.total_iters = total_iters + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + _warn_get_lr_called_within_step(self) + + if self.last_epoch == 0: + return [group["lr"] * self.factor for group in self.optimizer.param_groups] + + if self.last_epoch != self.total_iters: + return [group["lr"] for group in self.optimizer.param_groups] + + return [ + group["lr"] * (1.0 / self.factor) for group in self.optimizer.param_groups + ] + + def _get_closed_form_lr(self): + return [ + base_lr + * (self.factor + (self.last_epoch >= self.total_iters) * (1 - self.factor)) + for base_lr in self.base_lrs + ] + + + + +class LinearLR(LRScheduler): + """Decays the learning rate of each parameter group by linearly changing small + multiplicative factor until the number of epoch reaches a pre-defined milestone: total_iters. + Notice that such decay can happen simultaneously with other changes to the learning rate + from outside this scheduler. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + start_factor (float): The number we multiply learning rate in the first epoch. + The multiplication factor changes towards end_factor in the following epochs. + Default: 1./3. + end_factor (float): The number we multiply learning rate at the end of linear changing + process. Default: 1.0. + total_iters (int): The number of iterations that multiplicative factor reaches to 1. + Default: 5. + last_epoch (int): The index of the last epoch. Default: -1. + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.025 if epoch == 0 + >>> # lr = 0.03125 if epoch == 1 + >>> # lr = 0.0375 if epoch == 2 + >>> # lr = 0.04375 if epoch == 3 + >>> # lr = 0.05 if epoch >= 4 + >>> scheduler = LinearLR(optimizer, start_factor=0.5, total_iters=4) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__( + self, + optimizer: Optimizer, + start_factor=1.0 / 3, + end_factor=1.0, + total_iters=5, + last_epoch=-1, + verbose="deprecated", + ): + if start_factor > 1.0 or start_factor <= 0: + raise ValueError( + "Starting multiplicative factor expected to be greater than 0 and less or equal to 1." + ) + + if end_factor > 1.0 or end_factor < 0: + raise ValueError( + "Ending multiplicative factor expected to be between 0 and 1." + ) + + self.start_factor = start_factor + self.end_factor = end_factor + self.total_iters = total_iters + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + _warn_get_lr_called_within_step(self) + + if self.last_epoch == 0: + return [ + group["lr"] * self.start_factor for group in self.optimizer.param_groups + ] + + if self.last_epoch > self.total_iters: + return [group["lr"] for group in self.optimizer.param_groups] + + return [ + group["lr"] + * ( + 1.0 + + (self.end_factor - self.start_factor) + / ( + self.total_iters * self.start_factor + + (self.last_epoch - 1) * (self.end_factor - self.start_factor) + ) + ) + for group in self.optimizer.param_groups + ] + + def _get_closed_form_lr(self): + return [ + base_lr + * ( + self.start_factor + + (self.end_factor - self.start_factor) + * min(self.total_iters, self.last_epoch) + / self.total_iters + ) + for base_lr in self.base_lrs + ] + + + + +class ExponentialLR(LRScheduler): + """Decays the learning rate of each parameter group by gamma every epoch. + When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + gamma (float): Multiplicative factor of learning rate decay. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + """ + + def __init__( + self, optimizer: Optimizer, gamma: float, last_epoch=-1, verbose="deprecated" + ): + self.gamma = gamma + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + _warn_get_lr_called_within_step(self) + + if self.last_epoch == 0: + return [group["lr"] for group in self.optimizer.param_groups] + return [group["lr"] * self.gamma for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + return [base_lr * self.gamma**self.last_epoch for base_lr in self.base_lrs] + + + + +class SequentialLR(LRScheduler): + """Receives the list of schedulers that is expected to be called sequentially during + optimization process and milestone points that provides exact intervals to reflect + which scheduler is supposed to be called at a given epoch. + + Args: + optimizer (Optimizer): Wrapped optimizer. + schedulers (list): List of chained schedulers. + milestones (list): List of integers that reflects milestone points. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool | str): Does nothing. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer uses lr = 1. for all groups + >>> # lr = 0.1 if epoch == 0 + >>> # lr = 0.1 if epoch == 1 + >>> # lr = 0.9 if epoch == 2 + >>> # lr = 0.81 if epoch == 3 + >>> # lr = 0.729 if epoch == 4 + >>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=2) + >>> scheduler2 = ExponentialLR(optimizer, gamma=0.9) + >>> scheduler = SequentialLR(optimizer, schedulers=[scheduler1, scheduler2], milestones=[2]) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__( # pylint: disable=super-init-not-called + self, + optimizer: Optimizer, + schedulers: List[LRScheduler], + milestones: List[int], + last_epoch=-1, + verbose="deprecated", + ): + if len(schedulers) < 1: + raise ValueError( + f"{self.__class__.__name__} expects at least one scheduler, but got no scheduler." + ) + + for scheduler_idx, scheduler in enumerate(schedulers): + if not hasattr(scheduler, "optimizer"): + raise TypeError( + f"{self.__class__.__name__} at index {scheduler_idx} should have `optimizer` as its attribute." + ) + if isinstance(scheduler, ReduceLROnPlateau): + raise ValueError( + f"{self.__class__.__name__} does not support `ReduceLROnPlateau` scheduler as it " + "requires additional kwargs to be specified when calling `step`, " + f"but got one at index {scheduler_idx} in the given schedulers sequence." + ) + if optimizer != scheduler.optimizer: + raise ValueError( + f"{self.__class__.__name__} expects all schedulers to belong to the same optimizer, but " + f"got scheduler {scheduler.__class__.__name__} at index {scheduler_idx} has {scheduler.optimizer}, " + f"which is different from {optimizer.__class__.__name__}." + ) + + if len(milestones) != len(schedulers) - 1: + raise ValueError( + "Sequential Schedulers expects number of schedulers provided to be one more " + f"than the number of milestone points, but got number of schedulers {len(schedulers)} and the " + f"number of milestones to be equal to {len(milestones)}" + ) + _check_verbose_deprecated_warning(verbose) + self._schedulers = schedulers + self._milestones = milestones + self.last_epoch = last_epoch + 1 + self.optimizer = optimizer + + # Reset learning rates back to initial values + for group in self.optimizer.param_groups: + group["lr"] = group["initial_lr"] + + # "Undo" the step performed by other schedulers + for scheduler in self._schedulers: + scheduler.last_epoch -= 1 + + # Perform the initial step for only the first scheduler + self._schedulers[0]._initial_step() + + self._last_lr = schedulers[0].get_last_lr() + + def step(self): + self.last_epoch += 1 + idx = bisect_right(self._milestones, self.last_epoch) + scheduler = self._schedulers[idx] + if idx > 0 and self._milestones[idx - 1] == self.last_epoch: + scheduler.step(0) + else: + scheduler.step() + + self._last_lr = scheduler.get_last_lr() + + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + The wrapped scheduler states will also be saved. + """ + state_dict = { + key: value + for key, value in self.__dict__.items() + if key not in ("optimizer", "_schedulers") + } + state_dict["_schedulers"] = [None] * len(self._schedulers) + + for idx, s in enumerate(self._schedulers): + state_dict["_schedulers"][idx] = s.state_dict() + + return state_dict + + + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + _schedulers = state_dict.pop("_schedulers") + self.__dict__.update(state_dict) + # Restore state_dict keys in order to prevent side effects + # https://github.com/pytorch/pytorch/issues/32756 + state_dict["_schedulers"] = _schedulers + + for idx, s in enumerate(_schedulers): + self._schedulers[idx].load_state_dict(s) + + + + +class PolynomialLR(LRScheduler): + """Decays the learning rate of each parameter group using a polynomial function + in the given total_iters. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + total_iters (int): The number of steps that the scheduler decays the learning rate. Default: 5. + power (float): The power of the polynomial. Default: 1.0. + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + + Example: + >>> # xdoctest: +SKIP("undefined vars") + >>> # Assuming optimizer uses lr = 0.001 for all groups + >>> # lr = 0.001 if epoch == 0 + >>> # lr = 0.00075 if epoch == 1 + >>> # lr = 0.00050 if epoch == 2 + >>> # lr = 0.00025 if epoch == 3 + >>> # lr = 0.0 if epoch >= 4 + >>> scheduler = PolynomialLR(optimizer, total_iters=4, power=1.0) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__( + self, + optimizer: Optimizer, + total_iters=5, + power=1.0, + last_epoch=-1, + verbose="deprecated", + ): + self.total_iters = total_iters + self.power = power + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + _warn_get_lr_called_within_step(self) + + if self.last_epoch == 0 or self.last_epoch > self.total_iters: + return [group["lr"] for group in self.optimizer.param_groups] + + decay_factor = ( + (1.0 - self.last_epoch / self.total_iters) + / (1.0 - (self.last_epoch - 1) / self.total_iters) + ) ** self.power + return [group["lr"] * decay_factor for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + return [ + ( + base_lr + * (1.0 - min(self.total_iters, self.last_epoch) / self.total_iters) + ** self.power + ) + for base_lr in self.base_lrs + ] + + + + +class CosineAnnealingLR(LRScheduler): + r"""Set the learning rate of each parameter group using a cosine annealing + schedule, where :math:`\eta_{max}` is set to the initial lr and + :math:`T_{cur}` is the number of epochs since the last restart in SGDR: + + .. math:: + \begin{aligned} + \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right), + & T_{cur} \neq (2k+1)T_{max}; \\ + \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min}) + \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right), + & T_{cur} = (2k+1)T_{max}. + \end{aligned} + + When last_epoch=-1, sets initial lr as lr. Notice that because the schedule + is defined recursively, the learning rate can be simultaneously modified + outside this scheduler by other operators. If the learning rate is set + solely by this scheduler, the learning rate at each step becomes: + + .. math:: + \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right) + + It has been proposed in + `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only + implements the cosine annealing part of SGDR, and not the restarts. + + Args: + optimizer (Optimizer): Wrapped optimizer. + T_max (int): Maximum number of iterations. + eta_min (float): Minimum learning rate. Default: 0. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + + .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: + https://arxiv.org/abs/1608.03983 + """ + + def __init__( + self, + optimizer: Optimizer, + T_max: int, + eta_min=0, + last_epoch=-1, + verbose="deprecated", + ): + self.T_max = T_max + self.eta_min = eta_min + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + _warn_get_lr_called_within_step(self) + + if self.last_epoch == 0: + return [group["lr"] for group in self.optimizer.param_groups] + elif self._step_count == 1 and self.last_epoch > 0: + return [ + self.eta_min + + (base_lr - self.eta_min) + * (1 + math.cos((self.last_epoch) * math.pi / self.T_max)) + / 2 + for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) + ] + elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0: + return [ + group["lr"] + + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2 + for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) + ] + return [ + (1 + math.cos(math.pi * self.last_epoch / self.T_max)) + / (1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) + * (group["lr"] - self.eta_min) + + self.eta_min + for group in self.optimizer.param_groups + ] + + def _get_closed_form_lr(self): + return [ + self.eta_min + + (base_lr - self.eta_min) + * (1 + math.cos(math.pi * self.last_epoch / self.T_max)) + / 2 + for base_lr in self.base_lrs + ] + + + + +class ChainedScheduler(LRScheduler): + """Chains list of learning rate schedulers. It takes a sequence of chainable learning + rate schedulers and performs consecutive step() functions belonging to them by just + one call. + + Args: + schedulers (sequence): sequence of chained schedulers. + optimizer (Optimizer, optional): Wrapped optimizer. Default: None. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer uses lr = 1. for all groups + >>> # lr = 0.09 if epoch == 0 + >>> # lr = 0.081 if epoch == 1 + >>> # lr = 0.729 if epoch == 2 + >>> # lr = 0.6561 if epoch == 3 + >>> # lr = 0.59049 if epoch >= 4 + >>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=2) + >>> scheduler2 = ExponentialLR(optimizer, gamma=0.9) + >>> scheduler = ChainedScheduler([scheduler1, scheduler2], optimizer=optimizer) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__( # pylint: disable=super-init-not-called + self, schedulers: Sequence[LRScheduler], optimizer: Optional[Optimizer] = None + ): + if len(schedulers) < 1: + raise ValueError( + f"{self.__class__.__name__} expects at least one scheduler to be chained, but got no scheduler." + ) + + optimizer = optimizer or schedulers[0].optimizer + for scheduler_idx, scheduler in enumerate(schedulers): + if not hasattr(scheduler, "optimizer"): + raise TypeError( + f"{self.__class__.__name__} at index {scheduler_idx} should have `optimizer` as its attribute." + ) + if isinstance(scheduler, ReduceLROnPlateau): + raise ValueError( + f"{self.__class__.__name__} does not support `ReduceLROnPlateau` scheduler as it " + "requires additional kwargs to be specified when calling `step`, " + f"but got one at index {scheduler_idx} in the given schedulers sequence." + ) + if optimizer != scheduler.optimizer: + raise ValueError( + f"{self.__class__.__name__} expects all schedulers to belong to the same optimizer, but " + f"got scheduler {scheduler.__class__.__name__} at index {scheduler_idx} has {scheduler.optimizer}, " + f"which is different from {optimizer.__class__.__name__}." + ) + self._schedulers = schedulers + self.optimizer = optimizer + self._last_lr = [ + group["lr"] for group in self._schedulers[-1].optimizer.param_groups + ] + + def step(self): + for scheduler in self._schedulers: + scheduler.step() + self._last_lr = [ + group["lr"] for group in self._schedulers[-1].optimizer.param_groups + ] + + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + The wrapped scheduler states will also be saved. + """ + state_dict = { + key: value + for key, value in self.__dict__.items() + if key not in ("optimizer", "_schedulers") + } + state_dict["_schedulers"] = [None] * len(self._schedulers) + + for idx, s in enumerate(self._schedulers): + state_dict["_schedulers"][idx] = s.state_dict() + + return state_dict + + + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + _schedulers = state_dict.pop("_schedulers") + self.__dict__.update(state_dict) + # Restore state_dict keys in order to prevent side effects + # https://github.com/pytorch/pytorch/issues/32756 + state_dict["_schedulers"] = _schedulers + + for idx, s in enumerate(_schedulers): + self._schedulers[idx].load_state_dict(s) + + + + +class ReduceLROnPlateau(LRScheduler): + """Reduce learning rate when a metric has stopped improving. + Models often benefit from reducing the learning rate by a factor + of 2-10 once learning stagnates. This scheduler reads a metrics + quantity and if no improvement is seen for a 'patience' number + of epochs, the learning rate is reduced. + + Args: + optimizer (Optimizer): Wrapped optimizer. + mode (str): One of `min`, `max`. In `min` mode, lr will + be reduced when the quantity monitored has stopped + decreasing; in `max` mode it will be reduced when the + quantity monitored has stopped increasing. Default: 'min'. + factor (float): Factor by which the learning rate will be + reduced. new_lr = lr * factor. Default: 0.1. + patience (int): The number of allowed epochs with no improvement after + which the learning rate will be reduced. + For example, consider the case of having no patience (`patience = 0`). + In the first epoch, a baseline is established and is always considered good as there's no previous baseline. + In the second epoch, if the performance is worse than the baseline, + we have what is considered an intolerable epoch. + Since the count of intolerable epochs (1) is greater than the patience level (0), + the learning rate is reduced at the end of this epoch. + From the third epoch onwards, the learning rate continues to be reduced at the end of each epoch + if the performance is worse than the baseline. If the performance improves or remains the same, + the learning rate is not adjusted. + Default: 10. + threshold (float): Threshold for measuring the new optimum, + to only focus on significant changes. Default: 1e-4. + threshold_mode (str): One of `rel`, `abs`. In `rel` mode, + dynamic_threshold = best * ( 1 + threshold ) in 'max' + mode or best * ( 1 - threshold ) in `min` mode. + In `abs` mode, dynamic_threshold = best + threshold in + `max` mode or best - threshold in `min` mode. Default: 'rel'. + cooldown (int): Number of epochs to wait before resuming + normal operation after lr has been reduced. Default: 0. + min_lr (float or list): A scalar or a list of scalars. A + lower bound on the learning rate of all param groups + or each group respectively. Default: 0. + eps (float): Minimal decay applied to lr. If the difference + between new and old lr is smaller than eps, the update is + ignored. Default: 1e-8. + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + + Example: + >>> # xdoctest: +SKIP + >>> optimizer = core.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + >>> scheduler = ReduceLROnPlateau(optimizer, 'min') + >>> for epoch in range(10): + >>> train(...) + >>> val_loss = validate(...) + >>> # Note that step should be called after validate() + >>> scheduler.step(val_loss) + """ + + def __init__( # pylint: disable=super-init-not-called + self, + optimizer: Optimizer, + mode: Literal["min", "max"] = "min", + factor=0.1, + patience=10, + threshold=1e-4, + threshold_mode: Literal["rel", "abs"] = "rel", + cooldown=0, + min_lr: Union[List[float], float] = 0, + eps=1e-8, + verbose="deprecated", + ): + if factor >= 1.0: + raise ValueError("Factor should be < 1.0.") + self.factor = factor + + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError(f"{type(optimizer).__name__} is not an Optimizer") + self.optimizer = optimizer + + if isinstance(min_lr, (list, tuple)): + if len(min_lr) != len(optimizer.param_groups): + raise ValueError( + f"expected {len(optimizer.param_groups)} min_lrs, got {len(min_lr)}" + ) + self.min_lrs = list(min_lr) + else: + self.min_lrs = [min_lr] * len(optimizer.param_groups) + + self.patience = patience + + self.verbose = _check_verbose_deprecated_warning(verbose) + self.cooldown = cooldown + self.cooldown_counter = 0 + self.mode = mode + self.threshold = threshold + self.threshold_mode = threshold_mode + self.best: float + self.num_bad_epochs: int + self.mode_worse: float # the worse value for the chosen mode + self.eps = eps + self.last_epoch = 0 + self._last_lr = [group["lr"] for group in self.optimizer.param_groups] + self._init_is_better( + mode=mode, threshold=threshold, threshold_mode=threshold_mode + ) + self._reset() + + def _reset(self): + """Resets num_bad_epochs counter and cooldown counter.""" + self.best = self.mode_worse + self.cooldown_counter = 0 + self.num_bad_epochs = 0 + + def step(self, metrics: SupportsFloat, epoch=None): # type: ignore[override] + # convert `metrics` to float, in case it's a zero-dim Tensor + current = float(metrics) + if epoch is None: + epoch = self.last_epoch + 1 + else: + warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning) + self.last_epoch = epoch + + if self.is_better(current, self.best): + self.best = current + self.num_bad_epochs = 0 + else: + self.num_bad_epochs += 1 + + if self.in_cooldown: + self.cooldown_counter -= 1 + self.num_bad_epochs = 0 # ignore any bad epochs in cooldown + + if self.num_bad_epochs > self.patience: + self._reduce_lr(epoch) + self.cooldown_counter = self.cooldown + self.num_bad_epochs = 0 + + self._last_lr = [group["lr"] for group in self.optimizer.param_groups] + + def _reduce_lr(self, epoch): + for i, param_group in enumerate(self.optimizer.param_groups): + old_lr = float(param_group["lr"]) + new_lr = max(old_lr * self.factor, self.min_lrs[i]) + if old_lr - new_lr > self.eps: + param_group["lr"] = new_lr + + @property + def in_cooldown(self): + return self.cooldown_counter > 0 + + def is_better(self, a, best): + if self.mode == "min" and self.threshold_mode == "rel": + rel_epsilon = 1.0 - self.threshold + return a < best * rel_epsilon + + elif self.mode == "min" and self.threshold_mode == "abs": + return a < best - self.threshold + + elif self.mode == "max" and self.threshold_mode == "rel": + rel_epsilon = self.threshold + 1.0 + return a > best * rel_epsilon + + else: # mode == 'max' and epsilon_mode == 'abs': + return a > best + self.threshold + + def _init_is_better(self, mode, threshold, threshold_mode): + if mode not in {"min", "max"}: + raise ValueError("mode " + mode + " is unknown!") + if threshold_mode not in {"rel", "abs"}: + raise ValueError("threshold mode " + threshold_mode + " is unknown!") + + if mode == "min": + self.mode_worse = inf + else: # mode == 'max': + self.mode_worse = -inf + + self.mode = mode + self.threshold = threshold + self.threshold_mode = threshold_mode + + def state_dict(self): + return { + key: value for key, value in self.__dict__.items() if key != "optimizer" + } + + def load_state_dict(self, state_dict): + self.__dict__.update(state_dict) + self._init_is_better( + mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode + ) + + + + +class CyclicLR(LRScheduler): + r"""Sets the learning rate of each parameter group according to + cyclical learning rate policy (CLR). The policy cycles the learning + rate between two boundaries with a constant frequency, as detailed in + the paper `Cyclical Learning Rates for Training Neural Networks`_. + The distance between the two boundaries can be scaled on a per-iteration + or per-cycle basis. + + Cyclical learning rate policy changes the learning rate after every batch. + `step` should be called after a batch has been used for training. + + This class has three built-in policies, as put forth in the paper: + + * "triangular": A basic triangular cycle without amplitude scaling. + * "triangular2": A basic triangular cycle that scales initial amplitude by half each cycle. + * "exp_range": A cycle that scales initial amplitude by :math:`\text{gamma}^{\text{cycle iterations}}` + at each cycle iteration. + + This implementation was adapted from the github repo: `bckenstler/CLR`_ + + Args: + optimizer (Optimizer): Wrapped optimizer. + base_lr (float or list): Initial learning rate which is the + lower boundary in the cycle for each parameter group. + max_lr (float or list): Upper learning rate boundaries in the cycle + for each parameter group. Functionally, + it defines the cycle amplitude (max_lr - base_lr). + The lr at any cycle is the sum of base_lr + and some scaling of the amplitude; therefore + max_lr may not actually be reached depending on + scaling function. + step_size_up (int): Number of training iterations in the + increasing half of a cycle. Default: 2000 + step_size_down (int): Number of training iterations in the + decreasing half of a cycle. If step_size_down is None, + it is set to step_size_up. Default: None + mode (str): One of {triangular, triangular2, exp_range}. + Values correspond to policies detailed above. + If scale_fn is not None, this argument is ignored. + Default: 'triangular' + gamma (float): Constant in 'exp_range' scaling function: + gamma**(cycle iterations) + Default: 1.0 + scale_fn (function): Custom scaling policy defined by a single + argument lambda function, where + 0 <= scale_fn(x) <= 1 for all x >= 0. + If specified, then 'mode' is ignored. + Default: None + scale_mode (str): {'cycle', 'iterations'}. + Defines whether scale_fn is evaluated on + cycle number or cycle iterations (training + iterations since start of cycle). + Default: 'cycle' + cycle_momentum (bool): If ``True``, momentum is cycled inversely + to learning rate between 'base_momentum' and 'max_momentum'. + Default: True + base_momentum (float or list): Lower momentum boundaries in the cycle + for each parameter group. Note that momentum is cycled inversely + to learning rate; at the peak of a cycle, momentum is + 'base_momentum' and learning rate is 'max_lr'. + Default: 0.8 + max_momentum (float or list): Upper momentum boundaries in the cycle + for each parameter group. Functionally, + it defines the cycle amplitude (max_momentum - base_momentum). + The momentum at any cycle is the difference of max_momentum + and some scaling of the amplitude; therefore + base_momentum may not actually be reached depending on + scaling function. Note that momentum is cycled inversely + to learning rate; at the start of a cycle, momentum is 'max_momentum' + and learning rate is 'base_lr' + Default: 0.9 + last_epoch (int): The index of the last batch. This parameter is used when + resuming a training job. Since `step()` should be invoked after each + batch instead of after each epoch, this number represents the total + number of *batches* computed, not the total number of epochs computed. + When last_epoch=-1, the schedule is started from the beginning. + Default: -1 + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + + Example: + >>> # xdoctest: +SKIP + >>> optimizer = core.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + >>> scheduler = core.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1) + >>> data_loader = core.utils.data.DataLoader(...) + >>> for epoch in range(10): + >>> for batch in data_loader: + >>> train_batch(...) + >>> scheduler.step() + + + .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186 + .. _bckenstler/CLR: https://github.com/bckenstler/CLR + """ + + def __init__( + self, + optimizer: Optimizer, + base_lr: Union[float, List[float]], + max_lr: Union[float, List[float]], + step_size_up=2000, + step_size_down: Optional[int] = None, + mode: Literal["triangular", "triangular2", "exp_range"] = "triangular", + gamma=1.0, + scale_fn: Optional[Callable[[float], float]] = None, + scale_mode: Literal["cycle", "iterations"] = "cycle", + cycle_momentum=True, + base_momentum=0.8, + max_momentum=0.9, + last_epoch=-1, + verbose="deprecated", + ): + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError(f"{type(optimizer).__name__} is not an Optimizer") + self.optimizer = optimizer + + base_lrs = _format_param("base_lr", optimizer, base_lr) + if last_epoch == -1: + for lr, group in zip(base_lrs, optimizer.param_groups): + if isinstance(group["lr"], Tensor): + lr_val = lr.item() if isinstance(lr, Tensor) else lr + group["lr"].fill_(lr_val) + else: + group["lr"] = lr + + self.max_lrs = _format_param("max_lr", optimizer, max_lr) + + step_size_up = float(step_size_up) + step_size_down = ( + float(step_size_down) if step_size_down is not None else step_size_up + ) + self.total_size = step_size_up + step_size_down + self.step_ratio = step_size_up / self.total_size + + if mode not in ["triangular", "triangular2", "exp_range"] and scale_fn is None: + raise ValueError("mode is invalid and scale_fn is None") + + self.mode = mode + self.gamma = gamma + + self._scale_fn_ref: Callable[[float], float] + self._scale_fn_custom = scale_fn + self.scale_mode = scale_mode + self._init_scale_fn() + + self.cycle_momentum = cycle_momentum + if cycle_momentum: + if ( + "momentum" not in optimizer.defaults + and "betas" not in optimizer.defaults + ): + raise ValueError( + "optimizer must support momentum or beta1 with `cycle_momentum` option enabled" + ) + + self.use_beta1 = "betas" in self.optimizer.defaults + self.base_momentums = _format_param( + "base_momentum", optimizer, base_momentum + ) + self.max_momentums = _format_param("max_momentum", optimizer, max_momentum) + if last_epoch == -1: + for m_momentum, b_momentum, group in zip( + self.max_momentums, self.base_momentums, optimizer.param_groups + ): + if self.use_beta1: + group["betas"] = (m_momentum, *group["betas"][1:]) + else: + group["momentum"] = m_momentum + group["max_momentum"] = m_momentum + group["base_momentum"] = b_momentum + + super().__init__(optimizer, last_epoch, verbose) + self.base_lrs = base_lrs + + def _init_scale_fn(self): + if self._scale_fn_custom is not None: + return + if self.mode == "triangular": + self._scale_fn_ref = self._triangular_scale_fn + self.scale_mode = "cycle" + elif self.mode == "triangular2": + self._scale_fn_ref = self._triangular2_scale_fn + self.scale_mode = "cycle" + elif self.mode == "exp_range": + self._scale_fn_ref = partial(self._exp_range_scale_fn, self.gamma) + self.scale_mode = "iterations" + + def scale_fn(self, x) -> float: + if self._scale_fn_custom is not None: + return self._scale_fn_custom(x) + else: + return self._scale_fn_ref(x) # static method + + @staticmethod + def _triangular_scale_fn(x: float) -> float: + return 1.0 + + @staticmethod + def _triangular2_scale_fn(x: float) -> float: + return 1 / (2.0 ** (x - 1)) + + @staticmethod + def _exp_range_scale_fn(gamma: float, x: float) -> float: + return gamma**x + + + def get_lr(self): + """Calculates the learning rate at batch index. This function treats + `self.last_epoch` as the last batch index. + + If `self.cycle_momentum` is ``True``, this function has a side effect of + updating the optimizer's momentum. + """ + + _warn_get_lr_called_within_step(self) + + cycle = math.floor(1 + self.last_epoch / self.total_size) + x = 1.0 + self.last_epoch / self.total_size - cycle + if x <= self.step_ratio: + scale_factor = x / self.step_ratio + else: + scale_factor = (x - 1) / (self.step_ratio - 1) + + lrs = [] + for base_lr, max_lr in zip(self.base_lrs, self.max_lrs): + base_height = (max_lr - base_lr) * scale_factor + if self.scale_mode == "cycle": + lr = base_lr + base_height * self.scale_fn(cycle) + else: + lr = base_lr + base_height * self.scale_fn(self.last_epoch) + lrs.append(lr) + + if self.cycle_momentum: + momentums = [] + for base_momentum, max_momentum in zip( + self.base_momentums, self.max_momentums + ): + base_height = (max_momentum - base_momentum) * scale_factor + if self.scale_mode == "cycle": + momentum = max_momentum - base_height * self.scale_fn(cycle) + else: + momentum = max_momentum - base_height * self.scale_fn( + self.last_epoch + ) + momentums.append(momentum) + for param_group, momentum in zip(self.optimizer.param_groups, momentums): + if self.use_beta1: + param_group["betas"] = (momentum, *param_group["betas"][1:]) + else: + param_group["momentum"] = momentum + + return lrs + + + def state_dict(self): + state = super().state_dict() + # We are dropping the `_scale_fn_ref` attribute because it is a + # `weakref.WeakMethod` and can't be pickled. + state.pop("_scale_fn_ref", None) + fn = state.pop("_scale_fn_custom") + state["_scale_fn_custom"] = None + if fn is not None and not isinstance(fn, types.FunctionType): + # The _scale_fn_custom will only be saved if it is a callable object + # and not if it is a function or lambda. + state["_scale_fn_custom"] = fn.__dict__.copy() + + return state + + def load_state_dict(self, state_dict): + fn = state_dict.pop("_scale_fn_custom") + super().load_state_dict(state_dict) + if fn is not None: + self._scale_fn_custom.__dict__.update(fn) + self._init_scale_fn() + + + + +class CosineAnnealingWarmRestarts(LRScheduler): + r"""Set the learning rate of each parameter group using a cosine annealing + schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}` + is the number of epochs since the last restart and :math:`T_{i}` is the number + of epochs between two warm restarts in SGDR: + + .. math:: + \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + + \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right) + + When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`. + When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`. + + It has been proposed in + `SGDR: Stochastic Gradient Descent with Warm Restarts`_. + + Args: + optimizer (Optimizer): Wrapped optimizer. + T_0 (int): Number of iterations until the first restart. + T_mult (int, optional): A factor by which :math:`T_{i}` increases after a restart. Default: 1. + eta_min (float, optional): Minimum learning rate. Default: 0. + last_epoch (int, optional): The index of the last epoch. Default: -1. + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + + .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: + https://arxiv.org/abs/1608.03983 + """ + + def __init__( + self, + optimizer: Optimizer, + T_0: int, + T_mult=1, + eta_min=0, + last_epoch=-1, + verbose="deprecated", + ): + if T_0 <= 0 or not isinstance(T_0, int): + raise ValueError(f"Expected positive integer T_0, but got {T_0}") + if T_mult < 1 or not isinstance(T_mult, int): + raise ValueError(f"Expected integer T_mult >= 1, but got {T_mult}") + if not isinstance(eta_min, (float, int)): + raise ValueError( + f"Expected float or int eta_min, but got {eta_min} of type {type(eta_min)}" + ) + self.T_0 = T_0 + self.T_i = T_0 + self.T_mult = T_mult + self.eta_min = eta_min + self.T_cur = last_epoch + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + _warn_get_lr_called_within_step(self) + + return [ + self.eta_min + + (base_lr - self.eta_min) + * (1 + math.cos(math.pi * self.T_cur / self.T_i)) + / 2 + for base_lr in self.base_lrs + ] + + + def step(self, epoch=None): + """Step could be called after every batch update + + Example: + >>> # xdoctest: +SKIP("Undefined vars") + >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult) + >>> iters = len(dataloader) + >>> for epoch in range(20): + >>> for i, sample in enumerate(dataloader): + >>> inputs, labels = sample['inputs'], sample['labels'] + >>> optimizer.zero_grad() + >>> outputs = net(inputs) + >>> loss = criterion(outputs, labels) + >>> loss.backward() + >>> optimizer.step() + >>> scheduler.step(epoch + i / iters) + + This function can be called in an interleaved way. + + Example: + >>> # xdoctest: +SKIP("Undefined vars") + >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult) + >>> for epoch in range(20): + >>> scheduler.step() + >>> scheduler.step(26) + >>> scheduler.step() # scheduler.step(27), instead of scheduler(20) + """ + + if epoch is None and self.last_epoch < 0: + epoch = 0 + + if epoch is None: + epoch = self.last_epoch + 1 + self.T_cur = self.T_cur + 1 + if self.T_cur >= self.T_i: + self.T_cur = self.T_cur - self.T_i + self.T_i = self.T_i * self.T_mult + else: + if epoch < 0: + raise ValueError(f"Expected non-negative epoch, but got {epoch}") + if epoch >= self.T_0: + if self.T_mult == 1: + self.T_cur = epoch % self.T_0 + else: + n = int( + math.log( + (epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult + ) + ) + self.T_cur = epoch - self.T_0 * (self.T_mult**n - 1) / ( + self.T_mult - 1 + ) + self.T_i = self.T_0 * self.T_mult ** (n) + else: + self.T_i = self.T_0 + self.T_cur = epoch + self.last_epoch = math.floor(epoch) + + with _enable_get_lr_call(self): + for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())): + param_group, lr = data + param_group["lr"] = lr + + self._last_lr = [group["lr"] for group in self.optimizer.param_groups] + + + +class _SchedulePhase(TypedDict): + end_step: float + start_lr: str + end_lr: str + start_momentum: str + end_momentum: str + + + +class OneCycleLR(LRScheduler): + r"""Sets the learning rate of each parameter group according to the + 1cycle learning rate policy. The 1cycle policy anneals the learning + rate from an initial learning rate to some maximum learning rate and then + from that maximum learning rate to some minimum learning rate much lower + than the initial learning rate. + This policy was initially described in the paper `Super-Convergence: + Very Fast Training of Neural Networks Using Large Learning Rates`_. + + The 1cycle learning rate policy changes the learning rate after every batch. + `step` should be called after a batch has been used for training. + + This scheduler is not chainable. + + Note also that the total number of steps in the cycle can be determined in one + of two ways (listed in order of precedence): + + #. A value for total_steps is explicitly provided. + #. A number of epochs (epochs) and a number of steps per epoch + (steps_per_epoch) are provided. + In this case, the number of total steps is inferred by + total_steps = epochs * steps_per_epoch + + You must either provide a value for total_steps or provide a value for both + epochs and steps_per_epoch. + + The default behaviour of this scheduler follows the fastai implementation of 1cycle, which + claims that "unpublished work has shown even better results by using only two phases". To + mimic the behaviour of the original paper instead, set ``three_phase=True``. + + Args: + optimizer (Optimizer): Wrapped optimizer. + max_lr (float or list): Upper learning rate boundaries in the cycle + for each parameter group. + total_steps (int): The total number of steps in the cycle. Note that + if a value is not provided here, then it must be inferred by providing + a value for epochs and steps_per_epoch. + Default: None + epochs (int): The number of epochs to train for. This is used along + with steps_per_epoch in order to infer the total number of steps in the cycle + if a value for total_steps is not provided. + Default: None + steps_per_epoch (int): The number of steps per epoch to train for. This is + used along with epochs in order to infer the total number of steps in the + cycle if a value for total_steps is not provided. + Default: None + pct_start (float): The percentage of the cycle (in number of steps) spent + increasing the learning rate. + Default: 0.3 + anneal_strategy (str): {'cos', 'linear'} + Specifies the annealing strategy: "cos" for cosine annealing, "linear" for + linear annealing. + Default: 'cos' + cycle_momentum (bool): If ``True``, momentum is cycled inversely + to learning rate between 'base_momentum' and 'max_momentum'. + Default: True + base_momentum (float or list): Lower momentum boundaries in the cycle + for each parameter group. Note that momentum is cycled inversely + to learning rate; at the peak of a cycle, momentum is + 'base_momentum' and learning rate is 'max_lr'. + Default: 0.85 + max_momentum (float or list): Upper momentum boundaries in the cycle + for each parameter group. Functionally, + it defines the cycle amplitude (max_momentum - base_momentum). + Note that momentum is cycled inversely + to learning rate; at the start of a cycle, momentum is 'max_momentum' + and learning rate is 'base_lr' + Default: 0.95 + div_factor (float): Determines the initial learning rate via + initial_lr = max_lr/div_factor + Default: 25 + final_div_factor (float): Determines the minimum learning rate via + min_lr = initial_lr/final_div_factor + Default: 1e4 + three_phase (bool): If ``True``, use a third phase of the schedule to annihilate the + learning rate according to 'final_div_factor' instead of modifying the second + phase (the first two phases will be symmetrical about the step indicated by + 'pct_start'). + last_epoch (int): The index of the last batch. This parameter is used when + resuming a training job. Since `step()` should be invoked after each + batch instead of after each epoch, this number represents the total + number of *batches* computed, not the total number of epochs computed. + When last_epoch=-1, the schedule is started from the beginning. + Default: -1 + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + + Example: + >>> # xdoctest: +SKIP + >>> data_loader = core.utils.data.DataLoader(...) + >>> optimizer = core.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + >>> scheduler = core.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10) + >>> for epoch in range(10): + >>> for batch in data_loader: + >>> train_batch(...) + >>> optimizer.step() + >>> scheduler.step() + + + .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates: + https://arxiv.org/abs/1708.07120 + """ + + def __init__( + self, + optimizer: Optimizer, + max_lr: Union[float, List[float]], + total_steps: Optional[int] = None, + epochs: Optional[int] = None, + steps_per_epoch: Optional[int] = None, + pct_start=0.3, + anneal_strategy: Literal["cos", "linear"] = "cos", + cycle_momentum=True, + base_momentum: Union[float, List[float]] = 0.85, + max_momentum: Union[float, List[float]] = 0.95, + div_factor=25.0, + final_div_factor=1e4, + three_phase=False, + last_epoch=-1, + verbose="deprecated", + ): + # Validate optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError(f"{type(optimizer).__name__} is not an Optimizer") + self.optimizer = optimizer + + # Validate total_steps + if total_steps is not None: + if total_steps <= 0 or not isinstance(total_steps, int): + raise ValueError( + f"Expected positive integer total_steps, but got {total_steps}" + ) + self.total_steps = total_steps + elif epochs is not None and steps_per_epoch is not None: + if not isinstance(epochs, int) or epochs <= 0: + raise ValueError(f"Expected positive integer epochs, but got {epochs}") + if not isinstance(steps_per_epoch, int) or steps_per_epoch <= 0: + raise ValueError( + f"Expected positive integer steps_per_epoch, but got {steps_per_epoch}" + ) + self.total_steps = epochs * steps_per_epoch + else: + raise ValueError( + "You must define either total_steps OR (epochs AND steps_per_epoch)" + ) + + self._schedule_phases: List[_SchedulePhase] + if three_phase: + self._schedule_phases = [ + { + "end_step": float(pct_start * self.total_steps) - 1, + "start_lr": "initial_lr", + "end_lr": "max_lr", + "start_momentum": "max_momentum", + "end_momentum": "base_momentum", + }, + { + "end_step": float(2 * pct_start * self.total_steps) - 2, + "start_lr": "max_lr", + "end_lr": "initial_lr", + "start_momentum": "base_momentum", + "end_momentum": "max_momentum", + }, + { + "end_step": self.total_steps - 1, + "start_lr": "initial_lr", + "end_lr": "min_lr", + "start_momentum": "max_momentum", + "end_momentum": "max_momentum", + }, + ] + else: + self._schedule_phases = [ + { + "end_step": float(pct_start * self.total_steps) - 1, + "start_lr": "initial_lr", + "end_lr": "max_lr", + "start_momentum": "max_momentum", + "end_momentum": "base_momentum", + }, + { + "end_step": self.total_steps - 1, + "start_lr": "max_lr", + "end_lr": "min_lr", + "start_momentum": "base_momentum", + "end_momentum": "max_momentum", + }, + ] + + # Validate pct_start + if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float): + raise ValueError( + f"Expected float between 0 and 1 pct_start, but got {pct_start}" + ) + + # Validate anneal_strategy + if anneal_strategy not in ["cos", "linear"]: + raise ValueError( + f"anneal_strategy must be one of 'cos' or 'linear', instead got {anneal_strategy}" + ) + else: + self._anneal_func_type = anneal_strategy + + # Initialize learning rate variables + max_lrs = _format_param("max_lr", self.optimizer, max_lr) + if last_epoch == -1: + for idx, group in enumerate(self.optimizer.param_groups): + group["initial_lr"] = max_lrs[idx] / div_factor + group["max_lr"] = max_lrs[idx] + group["min_lr"] = group["initial_lr"] / final_div_factor + + # Initialize momentum variables + self.cycle_momentum = cycle_momentum + if self.cycle_momentum: + if ( + "momentum" not in self.optimizer.defaults + and "betas" not in self.optimizer.defaults + ): + raise ValueError( + "optimizer must support momentum or beta1 with `cycle_momentum` option enabled" + ) + self.use_beta1 = "betas" in self.optimizer.defaults + max_momentums = _format_param("max_momentum", optimizer, max_momentum) + base_momentums = _format_param("base_momentum", optimizer, base_momentum) + if last_epoch == -1: + for m_momentum, b_momentum, group in zip( + max_momentums, base_momentums, optimizer.param_groups + ): + if self.use_beta1: + group["betas"] = (m_momentum, *group["betas"][1:]) + else: + group["momentum"] = m_momentum + group["max_momentum"] = m_momentum + group["base_momentum"] = b_momentum + + super().__init__(optimizer, last_epoch, verbose) + + def _anneal_func(self, *args, **kwargs): + if hasattr(self, "_anneal_func_type"): + if self._anneal_func_type == "cos": + return self._annealing_cos(*args, **kwargs) + elif self._anneal_func_type == "linear": + return self._annealing_linear(*args, **kwargs) + else: + raise ValueError(f"Unknown _anneal_func_type: {self._anneal_func_type}") + else: + # For BC + return self.anneal_func(*args, **kwargs) # type: ignore[attr-defined] + + @staticmethod + def _annealing_cos(start, end, pct): + "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0." + cos_out = math.cos(math.pi * pct) + 1 + return end + (start - end) / 2.0 * cos_out + + @staticmethod + def _annealing_linear(start, end, pct): + "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0." + return (end - start) * pct + start + + def get_lr(self): + _warn_get_lr_called_within_step(self) + + lrs = [] + step_num = self.last_epoch + + if step_num > self.total_steps: + raise ValueError( + f"Tried to step {step_num} times. The specified number of total steps is {self.total_steps}" # noqa: UP032 + ) + + for group in self.optimizer.param_groups: + start_step = 0.0 + for i, phase in enumerate(self._schedule_phases): + end_step = phase["end_step"] + if step_num <= end_step or i == len(self._schedule_phases) - 1: + pct = (step_num - start_step) / (end_step - start_step) + computed_lr = self._anneal_func( + group[phase["start_lr"]], group[phase["end_lr"]], pct + ) + if self.cycle_momentum: + computed_momentum = self._anneal_func( + group[phase["start_momentum"]], + group[phase["end_momentum"]], + pct, + ) + break + start_step = phase["end_step"] + + lrs.append(computed_lr) # type: ignore[possibly-undefined] + if self.cycle_momentum: + if self.use_beta1: + group["betas"] = (computed_momentum, *group["betas"][1:]) # type: ignore[possibly-undefined] + else: + group[ + "momentum" + ] = computed_momentum # type: ignore[possibly-undefined] + + return lrs diff --git a/mindnlp/core/optim/optimizer.py b/mindnlp/core/optim/optimizer.py new file mode 100644 index 000000000..148e976a9 --- /dev/null +++ b/mindnlp/core/optim/optimizer.py @@ -0,0 +1,756 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +"""Base optimizer.""" +import functools +import math +import warnings +from collections import defaultdict, OrderedDict +from copy import deepcopy +from itertools import chain +from typing import ( + Any, + Callable, + cast, + DefaultDict, + Dict, + Hashable, + Iterable, + List, + Optional, + Set, + Tuple, + TypeVar, + Union, +) +from typing_extensions import ParamSpec, Self, TypeAlias + +from mindnlp import core +from mindnlp.core.nn import Parameter +from .. import ops +from ..utils import hooks +from ..utils.hooks import RemovableHandle +from .._bind import get_default_dtype + +Args: TypeAlias = Tuple[Any, ...] +Kwargs: TypeAlias = Dict[str, Any] +StateDict: TypeAlias = Dict[str, Any] +TensorListList: TypeAlias = List[List[core.Tensor]] + + +GlobalOptimizerPreHook: TypeAlias = Callable[ + ["Optimizer", Args, Kwargs], Optional[Tuple[Args, Kwargs]] +] +GlobalOptimizerPostHook: TypeAlias = Callable[["Optimizer", Args, Kwargs], None] + +__all__ = [ + "Optimizer", + "register_optimizer_step_pre_hook", + "register_optimizer_step_post_hook", +] +_global_optimizer_pre_hooks: Dict[int, GlobalOptimizerPreHook] = OrderedDict() +_global_optimizer_post_hooks: Dict[int, GlobalOptimizerPostHook] = OrderedDict() +_foreach_supported_types = [core.Tensor, Parameter] + + +class _RequiredParameter: + """Singleton class representing a required parameter for an Optimizer.""" + + def __repr__(self) -> str: + return "" + + +required = _RequiredParameter() + + +def _get_value(x): + # item is significantly faster than a cpu tensor in eager mode + return x.item() if isinstance(x, core.Tensor) else x + + +def _stack_if_compiling(x): + return x + + +def _dispatch_sqrt( + x: float, +): # float annotation is needed because of torchscript type inference + return math.sqrt(x) + + +def _disable_dynamo_if_unsupported(single_tensor_fn=None): + # workaround for torchscript BC + # it requires all called functions to be in the + # global environment at the site at which the + # maybe_fallback closure is created + if single_tensor_fn: + globals()[single_tensor_fn.__name__] = single_tensor_fn + + def wrapper(func): + import inspect + + ps = inspect.signature(func).parameters + has_state_steps = True + try: + state_steps_ind = list(ps.keys()).index("state_steps") + except ValueError: + has_state_steps = False + + # Today, there are cases where we stack state steps + # and pass them as the value arg of foreach ops. + # Having state steps on cuda as the value arg is not supported in eager, + # but this only occurs in the rare case that the user explicitly deletes + # the capturable flag. If capturable=True, this is not a problem. + @functools.wraps(func) + def maybe_fallback(*args, **kwargs): + return func(*args, **kwargs) + + return maybe_fallback + + return wrapper + + + +def _view_as_real(params, *state_and_grads): + for i, p in enumerate(params): + if ops.is_complex(p): + params[i] = ops.view_as_real(params[i]) + for s in state_and_grads: + s[i] = ops.view_as_real(s[i]) + + +def _get_scalar_dtype(is_fused=None): + if is_fused: + return core.float32 + return ( + core.float64 if get_default_dtype() == core.float64 else core.float32 + ) + + + +def register_optimizer_step_pre_hook(hook: GlobalOptimizerPreHook) -> RemovableHandle: + r"""Register a pre hook common to all optimizers. + + The hook should have the following signature:: + + hook(optimizer, args, kwargs) -> None or modified args and kwargs + + Args: + hook (Callable): A user defined hook which is registered on all optimizers. + + Returns: + :class:`core.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(_global_optimizer_pre_hooks) + _global_optimizer_pre_hooks[handle.id] = hook + return handle + + +def register_optimizer_step_post_hook(hook: GlobalOptimizerPostHook) -> RemovableHandle: + r"""Register a post hook common to all optimizers. + + The hook should have the following signature:: + + hook(optimizer, args, kwargs) -> None + + Args: + hook (Callable): A user defined hook which is registered on all optimizers. + + Returns: + :class:`core.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(_global_optimizer_post_hooks) + _global_optimizer_post_hooks[handle.id] = hook + return handle + + +ParamsT: TypeAlias = Union[Iterable[core.Tensor], Iterable[Dict[str, Any]]] + +_P = ParamSpec("_P") +R = TypeVar("R") +T = TypeVar("T") + + +class Optimizer: + r"""Base class for all optimizers. + + .. warning:: + Parameters need to be specified as collections that have a deterministic + ordering that is consistent between runs. Examples of objects that don't + satisfy those properties are sets and iterators over values of dictionaries. + + Args: + params (iterable): an iterable of :class:`core.Tensor` s or + :class:`dict` s. Specifies what Tensors should be optimized. + defaults: (dict): a dict containing default values of optimization + options (used when a parameter group doesn't specify them). + """ + + OptimizerPreHook: TypeAlias = Callable[[Self, Args, Kwargs], Optional[Tuple[Args, Kwargs]]] # type: ignore[misc] + OptimizerPostHook: TypeAlias = Callable[[Self, Args, Kwargs], None] # type: ignore[misc] + + _optimizer_step_pre_hooks: Dict[int, OptimizerPreHook] + _optimizer_step_post_hooks: Dict[int, OptimizerPostHook] + _optimizer_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]' + _optimizer_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]' + _optimizer_load_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]' + _optimizer_load_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]' + + def __init__(self, params: ParamsT, defaults: Dict[str, Any]) -> None: # noqa: D107 + self.defaults = defaults + self._optimizer_step_pre_hooks = OrderedDict() + self._optimizer_step_post_hooks = OrderedDict() + self._optimizer_state_dict_pre_hooks = OrderedDict() + self._optimizer_state_dict_post_hooks = OrderedDict() + self._optimizer_load_state_dict_pre_hooks = OrderedDict() + self._optimizer_load_state_dict_post_hooks = OrderedDict() + + if isinstance(params, core.Tensor): + raise TypeError( + "params argument given to the optimizer should be " + "an iterable of Tensors or dicts, but got " + type(params) + ) + + self.state: DefaultDict[core.Tensor, Any] = defaultdict(dict) + self.param_groups: List[Dict[str, Any]] = [] + + param_groups = list(params) + if len(param_groups) == 0: + raise ValueError("optimizer got an empty parameter list") + if not isinstance(param_groups[0], dict): + param_groups = [{"params": param_groups}] + + for param_group in param_groups: + self.add_param_group(cast(dict, param_group)) + + # Allows _cuda_graph_capture_health_check to rig a poor man's TORCH_WARN_ONCE in python, + # which I don't think exists + # https://github.com/pytorch/pytorch/issues/72948 + self._warned_capturable_if_run_uncaptured = True + + def __getstate__(self) -> Dict[str, Any]: # noqa: D105 + return { + "defaults": self.defaults, + "state": self.state, + "param_groups": self.param_groups, + } + + def __setstate__(self, state: Dict[str, Any]) -> None: # noqa: D105 + self.__dict__.update(state) + if "_optimizer_step_pre_hooks" not in self.__dict__: + self._optimizer_step_pre_hooks = OrderedDict() + if "_optimizer_step_post_hooks" not in self.__dict__: + self._optimizer_step_post_hooks = OrderedDict() + if "_optimizer_state_dict_pre_hooks" not in self.__dict__: + self._optimizer_state_dict_pre_hooks = OrderedDict() + if "_optimizer_state_dict_post_hooks" not in self.__dict__: + self._optimizer_state_dict_post_hooks = OrderedDict() + if "_optimizer_load_state_dict_pre_hooks" not in self.__dict__: + self._optimizer_load_state_dict_pre_hooks = OrderedDict() + if "_optimizer_load_state_dict_post_hooks" not in self.__dict__: + self._optimizer_load_state_dict_post_hooks = OrderedDict() + self.defaults.setdefault("differentiable", False) + + def __repr__(self) -> str: # noqa: D105 + format_string = self.__class__.__name__ + " (" + for i, group in enumerate(self.param_groups): + format_string += "\n" + format_string += f"Parameter Group {i}\n" + for key in sorted(group.keys()): + if key != "params": + format_string += f" {key}: {group[key]}\n" + format_string += ")" + return format_string + + def _optimizer_step_code(self) -> None: + """Entry point for `core.profile.profiler`. + + When python tracing is enabled the profiler will hook into this + function at the CPython level to inspect the optimizer's parameters and + param groups. It is called it after `step()` since many optimizers + lazily initialize state. + + This is a workaround due to lack of a proper step hook on the optimizer, + . + """ + + def register_step_pre_hook(self, hook: OptimizerPreHook) -> RemovableHandle: + r"""Register an optimizer step pre hook which will be called before optimizer step. + + It should have the following signature:: + + hook(optimizer, args, kwargs) -> None or modified args and kwargs + + The ``optimizer`` argument is the optimizer instance being used. If + args and kwargs are modified by the pre-hook, then the transformed + values are returned as a tuple containing the new_args and new_kwargs. + + Args: + hook (Callable): The user defined hook to be registered. + + Returns: + :class:`core.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(self._optimizer_step_pre_hooks) + self._optimizer_step_pre_hooks[handle.id] = hook + return handle + + def register_step_post_hook(self, hook: OptimizerPostHook) -> RemovableHandle: + r"""Register an optimizer step post hook which will be called after optimizer step. + + It should have the following signature:: + + hook(optimizer, args, kwargs) -> None + + The ``optimizer`` argument is the optimizer instance being used. + + Args: + hook (Callable): The user defined hook to be registered. + + Returns: + :class:`core.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(self._optimizer_step_post_hooks) + self._optimizer_step_post_hooks[handle.id] = hook + return handle + + def register_state_dict_pre_hook( + self, hook: Callable[["Optimizer"], None], prepend: bool = False + ) -> RemovableHandle: # noqa: D101 + r"""Register a state dict pre-hook which will be called before :meth:`~core.optim.Optimizer.state_dict` is called. + + It should have the following signature:: + + hook(optimizer) -> None + + The ``optimizer`` argument is the optimizer instance being used. + The hook will be called with argument ``self`` before calling ``state_dict`` on ``self``. + The registered hook can be used to perform pre-processing before the ``state_dict`` + call is made. + + Args: + hook (Callable): The user defined hook to be registered. + prepend (bool): If True, the provided pre ``hook`` will be fired before + all the already registered pre-hooks on ``state_dict``. Otherwise, + the provided ``hook`` will be fired after all the already registered + pre-hooks. (default: False) + + Returns: + :class:`core.utils.hooks.RemoveableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(self._optimizer_state_dict_pre_hooks) + self._optimizer_state_dict_pre_hooks[handle.id] = hook + if prepend: + self._optimizer_state_dict_pre_hooks.move_to_end(handle.id, last=False) + return handle + + def register_state_dict_post_hook( + self, + hook: Callable[["Optimizer", StateDict], Optional[StateDict]], + prepend: bool = False, + ) -> RemovableHandle: + r"""Register a state dict post-hook which will be called after :meth:`~core.optim.Optimizer.state_dict` is called. + + It should have the following signature:: + + hook(optimizer, state_dict) -> state_dict or None + + The hook will be called with arguments ``self`` and ``state_dict`` after generating + a ``state_dict`` on ``self``. The hook may modify the state_dict inplace or optionally + return a new one. The registered hook can be used to perform post-processing + on the ``state_dict`` before it is returned. + + Args: + hook (Callable): The user defined hook to be registered. + prepend (bool): If True, the provided post ``hook`` will be fired before + all the already registered post-hooks on ``state_dict``. Otherwise, + the provided ``hook`` will be fired after all the already registered + post-hooks. (default: False) + + Returns: + :class:`core.utils.hooks.RemoveableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(self._optimizer_state_dict_post_hooks) + self._optimizer_state_dict_post_hooks[handle.id] = hook + if prepend: + self._optimizer_state_dict_post_hooks.move_to_end(handle.id, last=False) + return handle + + def state_dict(self) -> StateDict: + r"""Return the state of the optimizer as a :class:`dict`. + + It contains two entries: + + * ``state``: a Dict holding current optimization state. Its content + differs between optimizer classes, but some common characteristics + hold. For example, state is saved per parameter, and the parameter + itself is NOT saved. ``state`` is a Dictionary mapping parameter ids + to a Dict with state corresponding to each parameter. + * ``param_groups``: a List containing all parameter groups where each + parameter group is a Dict. Each parameter group contains metadata + specific to the optimizer, such as learning rate and weight decay, + as well as a List of parameter IDs of the parameters in the group. + + NOTE: The parameter IDs may look like indices but they are just IDs + associating state with param_group. When loading from a state_dict, + the optimizer will zip the param_group ``params`` (int IDs) and the + optimizer ``param_groups`` (actual ``nn.Parameter`` s) in order to + match state WITHOUT additional verification. + + A returned state dict might look something like: + + .. code-block:: text + + { + 'state': { + 0: {'momentum_buffer': tensor(...), ...}, + 1: {'momentum_buffer': tensor(...), ...}, + 2: {'momentum_buffer': tensor(...), ...}, + 3: {'momentum_buffer': tensor(...), ...} + }, + 'param_groups': [ + { + 'lr': 0.01, + 'weight_decay': 0, + ... + 'params': [0] + }, + { + 'lr': 0.001, + 'weight_decay': 0.5, + ... + 'params': [1, 2, 3] + } + ] + } + + """ + for pre_hook in self._optimizer_state_dict_pre_hooks.values(): + pre_hook(self) + + # Save order indices instead of Tensors + param_mappings: Dict[int, int] = {} + start_index = 0 + + def pack_group(group: Dict[str, Any]) -> Dict[str, Any]: + nonlocal start_index + packed = {k: v for k, v in group.items() if k != "params"} + param_mappings.update( + { + id(p): i + for i, p in enumerate(group["params"], start_index) + if id(p) not in param_mappings + } + ) + packed["params"] = [param_mappings[id(p)] for p in group["params"]] + start_index += len(packed["params"]) + return packed + + param_groups = [pack_group(g) for g in self.param_groups] + # Remap state to use order indices as keys + packed_state = { + (param_mappings[id(k)] if isinstance(k, core.Tensor) else k): v + for k, v in self.state.items() + } + + state_dict = { + "state": packed_state, + "param_groups": param_groups, + } + + for post_hook in self._optimizer_state_dict_post_hooks.values(): + hook_result = post_hook(self, state_dict) + if hook_result is not None: + state_dict = hook_result + return state_dict + + @staticmethod + def _process_value_according_to_param_policy( + param: core.Tensor, + value: core.Tensor, + param_id: int, + param_groups: List[Dict[Any, Any]], + key: Hashable = None, + ) -> core.Tensor: + # Floating-point types are a bit special here. They are the only ones + # that are assumed to always match the type of params. + # Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424 + # UNLESS fused or capturable, see note [special device hosting for step] + fused = False + capturable = False + assert param_groups is not None + for pg in param_groups: + if param_id in pg["params"]: + fused = pg["fused"] if "fused" in pg else False + capturable = pg["capturable"] if "capturable" in pg else False + break + if key == "step": + if capturable or fused: + return value.to(dtype=core.float32) + else: + return value + else: + if param.is_floating_point(): + return value.to(dtype=param.dtype) + else: + return value + + def register_load_state_dict_pre_hook( + self, + hook: Callable[["Optimizer", StateDict], Optional[StateDict]], + prepend: bool = False, + ) -> RemovableHandle: # noqa: D205 D400 + r"""Register a load_state_dict pre-hook which will be called before + :meth:`~core.optim.Optimizer.load_state_dict` is called. It should have the + following signature:: + + hook(optimizer, state_dict) -> state_dict or None + + The ``optimizer`` argument is the optimizer instance being used and the + ``state_dict`` argument is a shallow copy of the ``state_dict`` the user + passed in to ``load_state_dict``. The hook may modify the state_dict inplace + or optionally return a new one. If a state_dict is returned, it will be used + to be loaded into the optimizer. + + The hook will be called with argument ``self`` and ``state_dict`` before + calling ``load_state_dict`` on ``self``. The registered hook can be used to + perform pre-processing before the ``load_state_dict`` call is made. + + Args: + hook (Callable): The user defined hook to be registered. + prepend (bool): If True, the provided pre ``hook`` will be fired before + all the already registered pre-hooks on ``load_state_dict``. Otherwise, + the provided ``hook`` will be fired after all the already registered + pre-hooks. (default: False) + + Returns: + :class:`core.utils.hooks.RemoveableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(self._optimizer_load_state_dict_pre_hooks) + self._optimizer_load_state_dict_pre_hooks[handle.id] = hook + if prepend: + self._optimizer_load_state_dict_pre_hooks.move_to_end(handle.id, last=False) + return handle + + def register_load_state_dict_post_hook( + self, hook: Callable[["Optimizer"], None], prepend: bool = False + ) -> RemovableHandle: # noqa: D205 D400 + r"""Register a load_state_dict post-hook which will be called after + :meth:`~core.optim.Optimizer.load_state_dict` is called. It should have the + following signature:: + + hook(optimizer) -> None + + The ``optimizer`` argument is the optimizer instance being used. + + The hook will be called with argument ``self`` after calling + ``load_state_dict`` on ``self``. The registered hook can be used to + perform post-processing after ``load_state_dict`` has loaded the + ``state_dict``. + + Args: + hook (Callable): The user defined hook to be registered. + prepend (bool): If True, the provided post ``hook`` will be fired before + all the already registered post-hooks on ``load_state_dict``. Otherwise, + the provided ``hook`` will be fired after all the already registered + post-hooks. (default: False) + + Returns: + :class:`core.utils.hooks.RemoveableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(self._optimizer_load_state_dict_post_hooks) + self._optimizer_load_state_dict_post_hooks[handle.id] = hook + if prepend: + self._optimizer_load_state_dict_post_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined] + return handle + + def load_state_dict(self, state_dict: StateDict) -> None: + r"""Load the optimizer state. + + Args: + state_dict (dict): optimizer state. Should be an object returned + from a call to :meth:`state_dict`. + """ + # shallow copy, to be consistent with module API + state_dict = state_dict.copy() + + for pre_hook in self._optimizer_load_state_dict_pre_hooks.values(): + hook_result = pre_hook(self, state_dict) + if hook_result is not None: + state_dict = hook_result + + # Validate the state_dict + groups = self.param_groups + + # Deepcopy as we write into saved_groups later to update state + saved_groups = deepcopy(state_dict["param_groups"]) + + if len(groups) != len(saved_groups): + raise ValueError( + "loaded state dict has a different number of parameter groups" + ) + param_lens = (len(g["params"]) for g in groups) + saved_lens = (len(g["params"]) for g in saved_groups) + if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): + raise ValueError( + "loaded state dict contains a parameter group " + "that doesn't match the size of optimizer's group" + ) + + # Update the state + id_map = dict( + zip( + chain.from_iterable(g["params"] for g in saved_groups), + chain.from_iterable(g["params"] for g in groups), + ) + ) + + def _cast(param, value, param_id=None, param_groups=None, key=None): + r"""Make a deep copy of value, casting all tensors to device of param.""" + if isinstance(value, core.Tensor): + return Optimizer._process_value_according_to_param_policy( + param, value, param_id, param_groups, key + ) + elif isinstance(value, dict): + return { + k: _cast( + param, v, param_id=param_id, param_groups=param_groups, key=k + ) + for k, v in value.items() + } + elif isinstance(value, Iterable): + return type(value)(_cast(param, v, param_id=param_id, param_groups=param_groups) for v in value) # type: ignore[call-arg] + else: + return value + + # Copy state assigned to params (and cast tensors to appropriate types). + # State that is not assigned to params is copied as is (needed for + # backward compatibility). + state: DefaultDict[core.Tensor, Dict[Any, Any]] = defaultdict(dict) + for k, v in state_dict["state"].items(): + if k in id_map: + param = id_map[k] + state[param] = _cast( + param, v, param_id=k, param_groups=state_dict["param_groups"] + ) + else: + state[k] = v + + # Update parameter groups, setting their 'params' value + def update_group( + group: Dict[str, Any], new_group: Dict[str, Any] + ) -> Dict[str, Any]: + new_group["params"] = group["params"] + return new_group + + param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)] + self.__setstate__({"state": state, "param_groups": param_groups}) + + for post_hook in self._optimizer_load_state_dict_post_hooks.values(): + post_hook(self) + + + def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: + r"""Perform a single optimization step to update parameter. + + Args: + closure (Callable): A closure that reevaluates the model and + returns the loss. Optional for most optimizers. + + .. note:: + Unless otherwise specified, this function should not modify the + ``.grad`` field of the parameters. + """ + raise NotImplementedError + + def add_param_group(self, param_group: Dict[str, Any]) -> None: + r"""Add a param group to the :class:`Optimizer` s `param_groups`. + + This can be useful when fine tuning a pre-trained network as frozen layers can be made + trainable and added to the :class:`Optimizer` as training progresses. + + Args: + param_group (dict): Specifies what Tensors should be optimized along with group + specific optimization options. + """ + if not isinstance(param_group, dict): + raise TypeError(f"param_group must be a dict, but got {type(param_group)}") + + params = param_group["params"] + if isinstance(params, core.Tensor): + param_group["params"] = [params] + elif isinstance(params, set): + raise TypeError( + "optimizer parameters need to be organized in ordered collections, but " + "the ordering of tensors in sets will change between runs. Please use a list instead." + ) + else: + param_group["params"] = list(params) + + for param in param_group["params"]: + if not isinstance(param, core.Tensor): + raise TypeError( + "optimizer can only optimize Tensors, " + "but one of the params is " + type(param) + ) + for name, default in self.defaults.items(): + if default is required and name not in param_group: + raise ValueError( + f"parameter group didn't specify a value of required optimization parameter {name}" + ) + else: + param_group.setdefault(name, default) + + params = param_group["params"] + if len(params) != len(set(params)): + warnings.warn( + "optimizer contains a parameter group with duplicate parameters; " + "in future, this will cause an error; " + "see github.com/pytorch/pytorch/issues/40967 for more information", + stacklevel=3, + ) + + param_set: Set[core.Tensor] = set() + for group in self.param_groups: + param_set.update(set(group["params"])) + + if not param_set.isdisjoint(set(param_group["params"])): + raise ValueError("some parameters appear in more than one parameter group") + + self.param_groups.append(param_group) + + def zero_grad(self, set_to_none: bool = True) -> None: + r"""Reset the gradients of all optimized :class:`core.Tensor` s. + + Args: + set_to_none (bool): instead of setting to zero, set the grads to None. + This will in general have lower memory footprint, and can modestly improve performance. + However, it changes certain behaviors. For example: + 1. When the user tries to access a gradient and perform manual ops on it, + a None attribute or a Tensor full of 0s will behave differently. + 2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s + are guaranteed to be None for params that did not receive a gradient. + 3. ``core.optim`` optimizers have a different behavior if the gradient is 0 or None + (in one case it does the step with a gradient of 0 and in the other it skips + the step altogether). + """ + for group in self.param_groups: + for p in group["params"]: + if p.grad is not None: + if set_to_none: + p.grad = None diff --git a/mindnlp/core/optim/sgd.py b/mindnlp/core/optim/sgd.py new file mode 100644 index 000000000..f8a58a560 --- /dev/null +++ b/mindnlp/core/optim/sgd.py @@ -0,0 +1,90 @@ +"""sgd""" +# pylint: disable=use-dict-literal +# mypy: allow-untyped-defs +from mindnlp import core +from mindnlp.core import Tensor +from .optimizer import ( + Optimizer, +) +from .. import ops + +__all__ = ["SGD"] + + +class SGD(Optimizer): + def __init__( + self, + params, + lr: float = 1e-3, + momentum: float = 0, + dampening: float = 0, + weight_decay: float = 0, + nesterov=False, + *, + maximize: bool = False, + ): + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr}") + if momentum < 0.0: + raise ValueError(f"Invalid momentum value: {momentum}") + if weight_decay < 0.0: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict( + lr=lr, + momentum=momentum, + dampening=dampening, + weight_decay=weight_decay, + nesterov=nesterov, + maximize=maximize, + ) + if nesterov and (momentum <= 0 or dampening != 0): + raise ValueError("Nesterov momentum requires a momentum and zero dampening") + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("nesterov", False) + group.setdefault("maximize", False) + + def step(self): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + for group in self.param_groups: + weight_decay = float(group['weight_decay']) + momentum = core.tensor(group['momentum'], dtype=core.float32, device=group['params'][0].device) + lr = core.tensor(group['lr'], dtype=core.float32, device=group['params'][0].device) + + dampening = float(group['dampening']) + nesterov = group['nesterov'] + maximize = group["maximize"] + + for p in group['params']: + d_p = p.grad if not maximize else -p.grad + # if weight_decay != 0: + # d_p = d_p.add(p, alpha=weight_decay) + # if momentum != 0: + # param_state = self.state[p] + # if 'momentum_buffer' not in param_state: + # buf = param_state['momentum_buffer'] = d_p.clone() + # else: + # buf = param_state['momentum_buffer'] + # buf = buf.mul(momentum) + # buf = buf.add_(d_p, alpha=1 - dampening) + # if nesterov: + # d_p = d_p.add(momentum, buf) + # else: + # d_p = buf + # new_p = p.add(d_p, alpha=-group['lr']) + # assign(p, new_p) + stat = core.ones_like(p) + accum = core.zeros_like(p) + ops.optim.raw_sgd(p, d_p, lr, dampening, weight_decay, nesterov, accum, momentum, stat) + + return loss diff --git a/mindnlp/core/overrides.py b/mindnlp/core/overrides.py new file mode 100644 index 000000000..047e190a5 --- /dev/null +++ b/mindnlp/core/overrides.py @@ -0,0 +1,85 @@ +from typing import Callable, Iterable, Any +from mindnlp import core + +def is_tensor_like(inp): + """ + Returns ``True`` if the passed-in input is a Tensor-like. + + Currently, this occurs whenever there's a ``__torch_function__`` + attribute on the type of the input. + + Examples + -------- + A subclass of tensor is generally a Tensor-like. + + >>> class SubTensor(torch.Tensor): ... + >>> is_tensor_like(SubTensor([0])) + True + + Built-in or user types aren't usually Tensor-like. + + >>> is_tensor_like(6) + False + >>> is_tensor_like(None) + False + >>> class NotATensor: ... + >>> is_tensor_like(NotATensor()) + False + + But, they can be made Tensor-like by implementing __torch_function__. + + >>> class TensorLike: + ... @classmethod + ... def __torch_function__(cls, func, types, args, kwargs): + ... return -1 + >>> is_tensor_like(TensorLike()) + True + """ + return type(inp) is core.Tensor or hasattr(inp, "__torch_function__") + +def handle_torch_function( + public_api: Callable, + relevant_args: Iterable[Any], + *args, + **kwargs, +) -> Any: + """Implement a function with checks for ``__torch_function__`` overrides. + + See torch::autograd::handle_torch_function for the equivalent of this + function in the C++ implementation. + + Arguments + --------- + public_api : function + Function exposed by the public torch API originally called like + ``public_api(*args, **kwargs)`` on which arguments are now being + checked. + relevant_args : iterable + Iterable of arguments to check for __torch_function__ methods. + args : tuple + Arbitrary positional arguments originally passed into ``public_api``. + kwargs : tuple + Arbitrary keyword arguments originally passed into ``public_api``. + + Returns + ------- + object + Result from calling ``implementation`` or an ``__torch_function__`` + method, as appropriate. + + Raises + ------ + TypeError : if no implementation is found. + + Example + ------- + >>> def func(a): + ... if has_torch_function_unary(a): + ... return handle_torch_function(func, (a,), a) + ... return a + 0 + """ + # Check for __torch_function__ methods. + pass + +def has_torch_function(inp): + return hasattr(inp, "__torch_function__") \ No newline at end of file diff --git a/mindnlp/core/profiler/__init__.py b/mindnlp/core/profiler/__init__.py new file mode 100644 index 000000000..379e8c3e6 --- /dev/null +++ b/mindnlp/core/profiler/__init__.py @@ -0,0 +1,3 @@ +from mindspore.profiler import ProfilerActivity + +from .profiler import profile, tensorboard_trace_handler diff --git a/mindnlp/core/profiler/experimental_config.py b/mindnlp/core/profiler/experimental_config.py new file mode 100644 index 000000000..998e7b26a --- /dev/null +++ b/mindnlp/core/profiler/experimental_config.py @@ -0,0 +1,51 @@ +from mindspore.profiler import ProfilerLevel as msProfilerLevel +from mindspore.profiler import AicoreMetrics + + +class ProfilerLevel: + Level0 = msProfilerLevel.Level0 + Level1 = msProfilerLevel.Level1 + Level2 = msProfilerLevel.Level2 + Level_none = msProfilerLevel.LevelNone + + +class AiCMetrics: + PipeUtilization = AicoreMetrics.PipeUtilization + ArithmeticUtilization = AicoreMetrics.ArithmeticUtilization + Memory = AicoreMetrics.Memory + MemoryL0 = AicoreMetrics.MemoryL0 + MemoryUB = AicoreMetrics.MemoryUB + ResourceConflictRatio = AicoreMetrics.ResourceConflictRatio + L2Cache = AicoreMetrics.L2Cache + MemoryAccess = AicoreMetrics.AiCoreNone # ToImplenmentened + AiCoreNone = AicoreMetrics.AiCoreNone + + +class ExportType: + Db = "db" + Text = "text" + + +class _ExperimentalConfig: + def __init__(self, + profiler_level: int = ProfilerLevel.Level0, + aic_metrics: int = AiCMetrics.AiCoreNone, + l2_cache: bool = False, + msprof_tx: bool = False, + data_simplification: bool = True, + record_op_args: bool = False, + op_attr: bool = False, + gc_detect_threshold: float = None, + export_type: str = ExportType.Text): + self._profiler_level = profiler_level + self._aic_metrics = aic_metrics + if self._profiler_level != None: + if self._profiler_level != ProfilerLevel.Level0 and self._aic_metrics == AiCMetrics.AiCoreNone: + self._aic_metrics = AiCMetrics.PipeUtilization + self._l2_cache = l2_cache + self._msprof_tx = msprof_tx + self._data_simplification = data_simplification + self.record_op_args = record_op_args + self._export_type = export_type + self._op_attr = op_attr + self._gc_detect_threshold = gc_detect_threshold diff --git a/mindnlp/core/profiler/profiler.py b/mindnlp/core/profiler/profiler.py new file mode 100644 index 000000000..1da89fe5e --- /dev/null +++ b/mindnlp/core/profiler/profiler.py @@ -0,0 +1,76 @@ +from mindspore.profiler import ProfilerActivity +from typing import Optional, Iterable, Callable, Any + +from mindspore import Profiler as Profiler +try: + from mindspore.profiler import tensor_board_trace_handler +except: + from mindspore.profiler import tensorboard_trace_handler + tensor_board_trace_handler = None + +from mindspore.profiler.schedule import ProfilerAction +from .scheduler import Schedule +from .experimental_config import _ExperimentalConfig + + +if tensor_board_trace_handler is not None: + def tensorboard_trace_handler(dir_name: str = None, worker_name: str = None, + analyse_flag: bool = True, async_mode: bool = False): + def voidfunc(): + pass + if analyse_flag: + return (tensor_board_trace_handler, dir_name) + else: + return (voidfunc, dir_name) + + +class profile: + def __init__( + self, + *, + activities: Optional[Iterable[ProfilerActivity]] = None, + schedule: Optional[Schedule] = None, + on_trace_ready: Optional[tuple] = None, + record_shapes: bool = False, + profile_memory: bool = False, + with_stack: bool = False, + with_flops: bool = False, + with_modules: bool = False, + experimental_config: Optional[_ExperimentalConfig] = None, + # deprecated: + use_cuda: Optional[bool] = None, + ): + if on_trace_ready is not None: + if isinstance(on_trace_ready, tuple): + (on_trace_ready, dir_name) = on_trace_ready + else: + dir_name = ".data" + else: + dir_name = ".data" + + self.profiler = Profiler( + start_profile = False, + output_path = dir_name, + profiler_level = experimental_config._profiler_level, + activities = activities, + schedule = schedule.scheduler, + on_trace_ready = on_trace_ready, + profile_memory = profile_memory, + aicore_metrics = experimental_config._aic_metrics, + with_stack = with_stack, + data_simplification = experimental_config._data_simplification, + l2_cache = experimental_config._l2_cache, + mstx = experimental_config._msprof_tx + ) + + def start(self): + self.profiler.start() + + def stop(self): + self.profiler.stop() + + def step(self): + self.profiler.step() + +def analyse(profiler_path: str, max_process_number:int): + Profiler.analyse(profiler_path) diff --git a/mindnlp/core/profiler/scheduler.py b/mindnlp/core/profiler/scheduler.py new file mode 100644 index 000000000..b177cf9e5 --- /dev/null +++ b/mindnlp/core/profiler/scheduler.py @@ -0,0 +1,11 @@ +from mindspore.profiler import schedule + +class Schedule: + def __init__(self, wait: int, active: int, warmup: int = 0, repeat: int = 0, skip_first: int = 0) -> None: + self.scheduler = schedule( + wait=wait, + active=active, + warmup=warmup, + repeat=repeat, + skip_first=skip_first + ) diff --git a/mindnlp/core/random.py b/mindnlp/core/random.py new file mode 100644 index 000000000..e9d35838a --- /dev/null +++ b/mindnlp/core/random.py @@ -0,0 +1,158 @@ +# mypy: allow-untyped-defs +import contextlib +import warnings +from typing import Generator + +from mindnlp import core +from mindspore import default_generator, set_seed + +def get_rng_state(): + """ + Get the state of the default generator. + + Returns: + Tensor, generator state. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import numpy as np + >>> from mindspore import get_rng_state + >>> state = get_rng_state() + """ + return default_generator.get_state() + + +def set_rng_state(state): # pylint: disable=redefined-outer-name + """ + Set the state of the default generator. + + Args: + state (Tensor): the target state + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> from mindspore import set_rng_state, get_rng_state + >>> state = get_rng_state() + >>> set_rng_state(state) + """ + default_generator.set_state(state) + +def manual_seed(seed): + r"""Sets the seed for generating random numbers on all devices. Returns a + `core.Generator` object. + + Args: + seed (int): The desired seed. Value must be within the inclusive range + `[-0x8000_0000_0000_0000, 0xffff_ffff_ffff_ffff]`. Otherwise, a RuntimeError + is raised. Negative inputs are remapped to positive values with the formula + `0xffff_ffff_ffff_ffff + seed`. + """ + seed = int(seed) + # set_seed(seed) + return default_generator.manual_seed(seed) + + +def seed() -> int: + r"""Sets the seed for generating random numbers to a non-deterministic + random number on all devices. Returns a 64 bit number used to seed the RNG. + """ + seed = default_generator.seed() + + return seed + + +def initial_seed() -> int: + r"""Returns the initial seed for generating random numbers as a + Python `long`. + + .. note:: The returned seed is for the default generator on CPU only. + """ + return default_generator.initial_seed() + + +_fork_rng_warned_already = False + + +@contextlib.contextmanager +def fork_rng( + devices=None, + enabled=True, + _caller="fork_rng", + _devices_kw="devices", + device_type="cuda", +) -> Generator: + """ + Forks the RNG, so that when you return, the RNG is reset + to the state that it was previously in. + + Args: + devices (iterable of Device IDs): devices for which to fork + the RNG. CPU RNG state is always forked. By default, :meth:`fork_rng` operates + on all devices, but will emit a warning if your machine has a lot + of devices, since this function will run very slowly in that case. + If you explicitly specify devices, this warning will be suppressed + enabled (bool): if ``False``, the RNG is not forked. This is a convenience + argument for easily disabling the context manager without having + to delete it and unindent your Python code under it. + device_type (str): device type str, default is `cuda`. As for custom device, + see details in [Note: support the custom device with privateuse1] + """ + + if device_type == "meta": + yield + return + + device_type = core.device(device_type).type + device_mod = getattr(torch, device_type, None) + if device_mod is None: + raise RuntimeError( + f"torch has no module of `{device_type}`, you should register " + + "a module by `core._register_device_module`." + ) + global _fork_rng_warned_already + + # Internal arguments: + # _caller: the function which called fork_rng, which the user used + # _devices_kw: the devices keyword of _caller + + if not enabled: + yield + return + + if devices is None: + num_devices = device_mod.device_count() + if num_devices > 1 and not _fork_rng_warned_already: + message = ( + f"{device_type.upper()} reports that you have {num_devices} available devices, and " + f"you have used {_caller} without explicitly specifying which devices are being used. " + f"For safety, we initialize *every* {device_type.upper()} device by default, which can " + f"be quite slow if you have a lot of {device_type.upper()}s. If you know that you are only" + f" making use of a few {device_type.upper()} devices, set the environment variable " + f"{device_type.upper()}_VISIBLE_DEVICES or the '{_devices_kw}' keyword argument of {_caller} " + "with the set of devices you are actually using. For example, if you are using CPU only, " + "set device.upper()_VISIBLE_DEVICES= or devices=[]; if you are using device 0 only, " + f"set {device_type.upper()}_VISIBLE_DEVICES=0 or devices=[0]. To initialize all devices " + f"and suppress this warning, set the '{_devices_kw}' keyword argument to " + f"`range(core.{device_type}.device_count())`." + ) + warnings.warn(message) + _fork_rng_warned_already = True + devices = list(range(num_devices)) + else: + # Protect against user passing us a generator; we need to traverse this + # multiple times but a generator will be exhausted upon first traversal + devices = list(devices) + + cpu_rng_state = core.get_rng_state() + device_rng_states = [device_mod.get_rng_state(device) for device in devices] + + try: + yield + finally: + core.set_rng_state(cpu_rng_state) + for device, device_rng_state in zip(devices, device_rng_states): + device_mod.set_rng_state(device_rng_state, device) diff --git a/mindnlp/core/return_types.py b/mindnlp/core/return_types.py new file mode 100644 index 000000000..52f422152 --- /dev/null +++ b/mindnlp/core/return_types.py @@ -0,0 +1,50 @@ +import inspect + +from .utils._pytree import register_pytree_node, SequenceKey + + +__all__ = ["pytree_register_structseq", "all_return_types"] + +all_return_types = [] + +# error: Module has no attribute "_return_types" +return_types = {} # type: ignore[attr-defined] + + +def pytree_register_structseq(cls): + def structseq_flatten(structseq): + return list(structseq), None + + def structseq_flatten_with_keys(structseq): + values, context = structseq_flatten(structseq) + return [(SequenceKey(i), v) for i, v in enumerate(values)], context + + def structseq_unflatten(values, context): + return cls(values) + + register_pytree_node( + cls, + structseq_flatten, + structseq_unflatten, + flatten_with_keys_fn=structseq_flatten_with_keys, + ) + + +for name in dir(return_types): + if name.startswith("__"): + continue + + _attr = getattr(return_types, name) + globals()[name] = _attr + + if not name.startswith("_"): + __all__.append(name) + all_return_types.append(_attr) + + # Today everything in core.return_types is a structseq, aka a "namedtuple"-like + # thing defined by the Python C-API. We're going to need to modify this when that + # is no longer the case. + # NB: I don't know how to check that something is a "structseq" so we do a fuzzy + # check for tuple + if inspect.isclass(_attr) and issubclass(_attr, tuple): + pytree_register_structseq(_attr) \ No newline at end of file diff --git a/mindnlp/core/serialization.py b/mindnlp/core/serialization.py new file mode 100644 index 000000000..ac9dce42d --- /dev/null +++ b/mindnlp/core/serialization.py @@ -0,0 +1,1612 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Serialization utils +""" +import os +import io +import sys +import pickle +import shutil +import zipfile +import tarfile +import pathlib +import warnings +import tempfile +import operator +import struct +import mmap +import json +import math +import logging + +from contextlib import closing, contextmanager +from enum import Enum +from typing import Dict, Union, Optional, Any, OrderedDict +from functools import reduce +from dataclasses import dataclass + +import numpy as np +import mindspore + +from mindspore._c_expression import Tensor as MSTensor +from mindspore.train.serialization import _exec_save, _parse_ckpt_proto, tensor_to_np_type, tensor_to_ms_type + +import safetensors +import safetensors.numpy +from safetensors import deserialize + +from mindnlp import core +from .configs import SUPPORT_BF16 +from .nn import Module, Parameter + + +if SUPPORT_BF16: + from mindspore.common.np_dtype import bfloat16 # pylint: disable=import-error +else: + from ml_dtypes import bfloat16 + +MAGIC_NUMBER = 0x1950a86a20f9469cfc6c +PROTOCOL_VERSION = 1001 + +@contextmanager +def mkdtemp(): + """ + Context manager that creates a temporary directory and provides its path. + + Usage: + with mkdtemp() as path: + # Use the temporary directory at 'path' + + Args: + This function does not take any parameters. + + Returns: + None. + + Raises: + This function does not raise any exceptions. + """ + path = tempfile.mkdtemp() + try: + yield path + finally: + shutil.rmtree(path) + +class PyTorchFileReader: + """ + Class to allow PackageImporter to operate on unzipped packages. Methods + copy the behavior of the internal PyTorchFileReader class (which is used for + accessing packages in all other cases). + + N.B.: ScriptObjects are not depickleable or accessible via this DirectoryReader + class due to ScriptObjects requiring an actual PyTorchFileReader instance. + """ + def __init__(self, file): + """ + Initializes a new instance of PyTorchFileReader. + + Args: + self (PyTorchFileReader): The instance of the PyTorchFileReader class. + file (str): The path to the zip file to be read. + + Returns: + None. This method initializes the PyTorchFileReader instance with the provided file. + + Raises: + IOError: If the file specified by the 'file' parameter does not exist or cannot be opened. + zipfile.BadZipFile: If the file specified by the 'file' parameter is not a valid zip file. + IndexError: If the zip file does not contain any files. + """ + + self.file = zipfile.ZipFile(file) + if hasattr(file, 'offset'): + file.seek(0) + bytes = file.read(file.len) + bytes = io.BytesIO(bytes) + self.file = zipfile.ZipFile(bytes) + + self.directory = self.file.namelist()[0].split('/')[0] + + def open_record(self, name): + """ + Opens a record file from the PyTorchFileReader directory. + + Args: + self (PyTorchFileReader): The instance of the PyTorchFileReader class. + name (str): The name of the record file to open. + + Returns: + None: If the specified record file does not exist in the PyTorchFileReader directory. + + Raises: + None. + + This method checks if the specified record file exists in the PyTorchFileReader directory. If it does, the file is opened and returned. If the file does not exist, None is returned. + """ + filename = f"{self.directory}/{name}" + if filename in self.file.namelist(): + return self.file.open(filename) + return None + + def read_record(self, name): + """ + Reads a record from a PyTorch file. + + Args: + self (PyTorchFileReader): An instance of the PyTorchFileReader class. + name (str): The name of the record to read from the PyTorch file. + + Returns: + None: If the record with the specified name does not exist in the PyTorch file. + + Raises: + FileNotFoundError: If the PyTorch file does not exist in the specified directory. + IOError: If there is an error in reading the PyTorch file. + + """ + filename = f"{self.directory}/{name}" + if filename in self.file.namelist(): + return self.file.read(filename) + return None + + def has_record(self, name): + """ + This method checks if a record with the specified name exists in the PyTorchFileReader's directory. + + Args: + self (PyTorchFileReader): An instance of the PyTorchFileReader class. + name (str): The name of the record to be checked in the directory. + + Returns: + None: This method returns None. + + Raises: + None + """ + filename = f"{self.directory}/{name}" + return filename in self.file.namelist() + + def get_all_records( + self, + ): + """ + Retrieves a list of all records from the PyTorchFileReader object. + + Args: + self: The PyTorchFileReader object itself. + + Returns: + None. This method does not return any value. + + Raises: + None. + + This method iterates through the files in the PyTorchFileReader object's directory and retrieves the names of all records. The records are then returned as a list of file names. + + Note: + - The PyTorchFileReader object must be initialized with a valid directory. + - The list of file names returned only includes the names of the files, without the directory path. + """ + files = [name.replace(self.directory + '/' , '')for name in self.file.namelist()] + return files + + def get_record_offset(self, name): + """ + Returns the header offset of a specified record in a PyTorch file. + + Args: + self (PyTorchFileReader): An instance of the PyTorchFileReader class. + name (str): The name of the record for which the header offset is to be retrieved. + + Returns: + None: If the specified record does not exist in the PyTorch file. + + Raises: + None. + + This method takes in the self parameter, which is an instance of the PyTorchFileReader class. It also takes a name parameter, which represents the name of the record for which the header offset is to +be retrieved. The method checks if the specified record exists in the PyTorch file by creating the filename using the directory attribute of the PyTorchFileReader instance and the provided name. If the +filename exists in the file's namelist, the method returns the header offset of the file info associated with the filename. Otherwise, it returns None, indicating that the specified record does not exist in +the file. + """ + filename = f"{self.directory}/{name}" + if filename in self.file.namelist(): + return self.file.getinfo(filename).header_offset + return None + +class PyTorchFileWriter: + def __init__(self, file): + self.zipfile = zipfile.ZipFile(file, mode='w') + self.written_records = set() + + def write_record(self, name, data, offset=0): + if name in self.written_records: + raise RuntimeError(f"Record {name} already written") + self.written_records.add(name) + self.zipfile.writestr(name, data) + + def write_end_of_file(self): + pass + + def get_all_written_records(self): + return self.written_records + +class LoadEndianness(Enum): + + """ + Represents an enumeration for specifying the byte order (endianness) of a data load. + + This class inherits from the built-in Enum class in Python and provides a set of pre-defined constants for different byte orders. The byte order determines the arrangement of bytes in a multi-byte data +type, such as integers and floating-point numbers, when it is stored or transmitted. + + Attributes: + BIG_ENDIAN: Represents the big-endian byte order where the most significant byte is stored first. + LITTLE_ENDIAN: Represents the little-endian byte order where the least significant byte is stored first. + NATIVE: Represents the native byte order of the underlying platform. + NETWORK: Represents the byte order used in network byte order, which is big-endian. + + The LoadEndianness class allows you to easily specify the desired byte order when loading data, ensuring compatibility with the expected byte order. It provides a convenient and readable way to work with +different byte orders without the need for manual byte swapping or conversion. + + Usage: + The LoadEndianness class can be used to specify the byte order when loading data from a file, network, or any other data source. Simply import the class and use the desired constant to set the byte +order. + + Example: + >>> load_endianness = LoadEndianness.BIG_ENDIAN + >>> data = load_data(source_file, byte_order=load_endianness) + >>> print(data) + + Note: + It is important to ensure that the byte order specified matches the actual byte order of the data being loaded. Using the wrong byte order can lead to incorrect interpretation of the data and produce +unexpected results. + + """ + NATIVE = 1 + LITTLE = 2 + BIG = 3 + +_default_load_endian: Optional[LoadEndianness] = None + +def get_default_load_endianness() -> Optional[LoadEndianness]: + ''' + Get fallback byte order for loading files + + If byteorder mark is not present in saved checkpoint, + this byte order is used as fallback. + By default, it's "native" byte order. + + Returns: + default_load_endian: Optional[LoadEndianness] + ''' + return _default_load_endian + +def set_default_load_endianness(endianness): + ''' + Set fallback byte order for loading files + + If byteorder mark is not present in saved checkpoint, + this byte order is used as fallback. + By default, it's "native" byte order. + + Args: + endianness: the new fallback byte order + ''' + global _default_load_endian + if not isinstance(endianness, LoadEndianness) and endianness is not None: + raise TypeError("Invalid argument type in function set_default_load_endianness") + _default_load_endian = endianness + +def _is_zipfile(f) -> bool: + """ + Args: + f (file object): The file object to be checked for being a valid zip file. + It should be opened in binary mode and point to the beginning of the file. + + Returns: + bool: Returns True if the input file is a valid zip file, otherwise False. + + Raises: + No specific exceptions are raised by this function. + """ + # This is a stricter implementation than zipfile.is_zipfile(). + # zipfile.is_zipfile() is True if the magic number appears anywhere in the + # binary. Since we expect the files here to be generated by core.save or + # core.jit.save, it's safe to only check the start bytes and avoid + # collisions and assume the zip has only 1 file. + # See bugs.python.org/issue28494. + + # Read the first 4 bytes of the file + read_bytes = [] + start = f.tell() + + byte = f.read(1) + while byte != b"": + read_bytes.append(byte) + if len(read_bytes) == 4: + break + byte = f.read(1) + f.seek(start) + + local_header_magic_number = [b'P', b'K', b'\x03', b'\x04'] + return read_bytes == local_header_magic_number + +def _check_seekable(f) -> bool: + """ + Checks if the given file object is seekable. + + Args: + f (file object): The file object to be checked for seekability. + + Returns: + bool: True if the file object is seekable, False otherwise. + + Raises: + UnsupportedOperation: If the file object does not support seek or tell operations. + AttributeError: If the file object does not have the seek or tell attribute. + """ + def raise_err_msg(patterns, e): + for p in patterns: + if p in str(e): + msg = (str(e) + ". You can only core.load from a file that is seekable." + + " Please pre-load the data into a buffer like io.BytesIO and" + + " try to load from it instead.") + raise type(e)(msg) + raise e + + try: + f.seek(f.tell()) + return True + except (io.UnsupportedOperation, AttributeError) as e: + raise_err_msg(["seek", "tell"], e) + return False + +def _is_compressed_file(f) -> bool: + """ + Checks whether the given file is a compressed file. + + Args: + f (object): The file object to be checked. + + Returns: + bool: Returns True if the file is compressed, False otherwise. + + Raises: + None. + + """ + compress_modules = ['gzip'] + try: + return f.__module__ in compress_modules + except AttributeError: + return False + +def _should_read_directly(f): + """ + Checks if f is a file that should be read directly. It should be read + directly if it is backed by a real file (has a fileno) and is not a + a compressed file (e.g. gzip) + """ + if _is_compressed_file(f): + return False + try: + return f.fileno() >= 0 + except io.UnsupportedOperation: + return False + except AttributeError: + return False + + +def _is_path(name_or_buffer): + """ + Check if the input is a valid path. + + Args: + name_or_buffer (str or pathlib.Path): A string representing a file path or a pathlib.Path object. + + Returns: + None: This function does not return any value. + + Raises: + None + """ + return isinstance(name_or_buffer, (str, pathlib.Path)) + +def _is_torchscript_zip(zip_file): + """ + Checks if the given zip file contains a specific record. + + Args: + zip_file (object): The zip file to be checked for the presence of a specific record. + + Returns: + None: This function does not return any value. + + Raises: + None + """ + return 'constants.pkl' in zip_file.get_all_records() + +class _opener: + + """ + Class `_opener` represents a context manager for opening files in Python. It inherits from the built-in `object` class. + + This class provides a convenient way to work with file-like objects by allowing them to be used within a `with` statement. The `_opener` class implements the `__init__`, `__enter__`, and `__exit__` methods. + + __init__(self, file_like): + Initializes an instance of the `_opener` class. + + Parameters: + - file_like: A file-like object that will be used for reading or writing operations. + + __enter__(self): + Returns the file-like object passed during initialization. + + Returns: + The file-like object for use within the `with` statement. + + __exit__(self, *args): + Performs cleanup operations after the `with` statement block is executed. + + Parameters: + - *args: Any exception arguments passed by the Python interpreter. + + Note: + This method does not handle exceptions. It is designed to be used as a context manager and should be used in conjunction with a `try-except-finally` block to handle exceptions properly. + """ + def __init__(self, file_like): + """ + Initializes an instance of the '_opener' class. + + Args: + self (object): The instance of the '_opener' class. + file_like (object): A file-like object representing the file to be opened. + It can be a file object, a file path, or any object with a file-like interface. + The object must support the 'read' method. + + Returns: + None. This method does not return any value. + + Raises: + None. This method does not raise any exceptions. + """ + self.file_like = file_like + + def __enter__(self): + """ + The '__enter__' method is a special method in the '_opener' class that is used to set up the context for an object. It is called when using the 'with' statement in Python. + + Args: + self: An instance of the '_opener' class. + + Returns: + None. This method does not return any value. + + Raises: + This method does not raise any exceptions. + """ + return self.file_like + + def __exit__(self, *args): + """ + Method '__exit__' in the class '_opener'. + + Args: + self: The instance of the class. + Type: object + Purpose: Represents the instance of the class. + Restrictions: None + + Returns: + None: Indicates that the method does not return any value. + Type: None + Purpose: Signifies the absence of a return value. + + Raises: + No specific exceptions are raised by this method. + """ + + +class _open_file(_opener): + + """ + _open_file represents a class that inherits from _opener and provides methods for opening and closing files. + + This class initializes an instance of _open_file with the given name and mode and utilizes the super() function to call the __init__ method of the _opener class with the opened file. + + The __exit__ method is implemented to close the file-like object when the instance is exited. + + Attributes: + name (str): The name of the file to be opened. + mode (str): The mode in which the file should be opened. + + Methods: + __init__(name, mode): + Initializes an instance of _open_file with the given name and mode. + + __exit__(*args): + Closes the file-like object when the instance is exited. + """ + def __init__(self, name, mode): + """ + __init__ + + Initializes an instance of the _open_file class. + + Args: + self: _open_file instance + The instance of the _open_file class. + + name: str + The name of the file to be opened. + + mode: str + The mode in which the file should be opened. It should be a string that represents the mode in which the file is opened. It can be 'r' for reading, 'w' for writing, or 'a' for appending. Other +modes are also supported. + + Returns: + None + This method does not return any value. + + Raises: + OSError + If an error occurs while opening the file, an OSError is raised. + """ + super().__init__(open(name, mode)) + + def __exit__(self, *args): + """ + This method __exit__ is used in the class _open_file to handle the cleanup operations when exiting a context manager. + + Args: + - self (object): The instance of the _open_file class. Represents the context manager itself. + + Returns: + None. This method does not return any value explicitly. + + Raises: + This method does not raise any exceptions explicitly. However, it indirectly depends on the behavior of the 'close()' method of the file-like object it operates on, which may raise exceptions related +to file I/O operations. + """ + self.file_like.close() + + +class _open_buffer_reader(_opener): + + """ + A class representing an open buffer reader for reading files. + + This class is a subclass of _opener and provides functionality for reading files from a buffer. + The class's forwardor takes a buffer as input and initializes the buffer for reading. + It also performs a check to ensure that the buffer is seekable before proceeding with reading operations. + """ + def __init__(self, buffer): + """ + Initializes an instance of the '_open_buffer_reader' class. + + Args: + self: The instance of the '_open_buffer_reader' class. + buffer: The buffer object to be read. It should be a seekable object. + + Returns: + None + + Raises: + TypeError: If the 'buffer' parameter is not a seekable object. + """ + super().__init__(buffer) + _check_seekable(buffer) + + +class _open_buffer_writer(_opener): + + """ + _open_buffer_writer is a Python class that represents a buffered writer for file-like objects. This class inherits from the _opener class. + + Usage: + The _open_buffer_writer class provides a convenient way to write data to file-like objects with buffering capabilities. + + Attributes: + file_like (file-like object): The file-like object to which data will be written. + + Methods: + __init__(self, file_like): + Initializes a new instance of the _open_buffer_writer class. + + Args: + file_like (file-like object): The file-like object to which data will be written. + + write(self, data): + Writes the given data to the file-like object. + + Args: + data (str): The data to be written. + + Returns: + None + + flush(self): + Flushes the buffer and writes any remaining data to the file-like object. + + Returns: + None + + __enter__(self): + Enters the context manager and returns the _open_buffer_writer instance. + + Returns: + _open_buffer_writer: The current _open_buffer_writer instance. + + __exit__(self, *args): + Exits the context manager and performs necessary cleanup operations. + + Args: + *args: Variable length argument list. + + Returns: + None + """ + def __exit__(self, *args): + """ + __exit__ + + Args: + self: _open_buffer_writer + The instance of the _open_buffer_writer class. + + Returns: + None: + This method does not return any value. + + Raises: + N/A + """ + self.file_like.flush() + + +def _open_file_like(name_or_buffer, mode): + """ + Args: + name_or_buffer (str or buffer): The name of the file or a buffer object to be opened. If a string, it represents the file path. If a buffer, it represents a memory buffer. + mode (str): The mode in which the file or buffer should be opened. It should be either 'r' for reading or 'w' for writing. + + Returns: + None: This function does not return a value. + + Raises: + RuntimeError: If the mode is not 'r' or 'w'. + """ + if _is_path(name_or_buffer): + return _open_file(name_or_buffer, mode) + if 'w' in mode: + return _open_buffer_writer(name_or_buffer) + if 'r' in mode: + return _open_buffer_reader(name_or_buffer) + raise RuntimeError(f"Expected 'r' or 'w' in mode but got {mode}") + +class _open_zipfile_reader(_opener): + + """ + The _open_zipfile_reader class represents a reader for opening and reading zip files. + It inherits from the _opener class and provides functionality for reading zip files. + + Attributes: + name_or_buffer: The name or buffer of the file to be opened. + + Methods: + __init__: Initializes the _open_zipfile_reader instance, using the specified name_or_buffer to open a PyTorchFileReader. + """ + def __init__(self, name_or_buffer) -> None: + """ + Initializes the _open_zipfile_reader class. + + Args: + self (object): The instance of the _open_zipfile_reader class. + name_or_buffer (str or file-like object): The name of the file or a buffer object for reading the zipfile. + It can be a string representing the name of the file or a file-like object for reading the zipfile data. + + Returns: + None: This method does not return any value. + + Raises: + - TypeError: If the name_or_buffer parameter is not a string or file-like object. + - ValueError: If the name_or_buffer parameter is empty or invalid. + - IOError: If there is an error reading the zipfile from the provided name_or_buffer. + """ + super().__init__(PyTorchFileReader(name_or_buffer)) + +class _open_zipfile_writer_file(_opener): + def __init__(self, name): + self.file_stream = None + self.name = str(name) + try: + self.name.encode('ascii') + except UnicodeEncodeError: + self.file_stream = io.FileIO(self.name, mode='w') + super().__init__(PyTorchFileWriter(self.file_stream)) + else: + super().__init__(PyTorchFileWriter(self.name)) + + def __exit__(self, *args): + self.file_like.write_end_of_file() + if self.file_stream is not None: + self.file_stream.close() + +class _open_zipfile_writer_buffer(_opener): + def __init__(self, buffer): + if not callable(getattr(buffer, "write", None)): + msg = f"Buffer of {str(type(buffer)).strip('<>')} has no callable attribute 'write'" + if not hasattr(buffer, "write"): + raise AttributeError(msg) + raise TypeError(msg) + self.buffer = buffer + super().__init__(PyTorchFileWriter(buffer)) + + def __exit__(self, *args): + self.file_like.write_end_of_file() + self.buffer.flush() + +def _open_zipfile_writer(name_or_buffer): + if _is_path(name_or_buffer): + container = _open_zipfile_writer_file + else: + container = _open_zipfile_writer_buffer + return container(name_or_buffer) + +def _rebuild_parameter(data, requires_grad, backward_hooks): + param = Parameter(data, requires_grad=requires_grad) + # NB: This line exists only for backwards compatibility; the + # general expectation is that backward_hooks is an empty + # OrderedDict. See Note [Don't serialize hooks] + return param + +def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None): + '''Rebuilds a tensor based on the provided parameters. + + Args: + storage (ndarray): The storage array from which the tensor is created. + storage_offset (int): The offset in the storage array from where the tensor data starts. + size (tuple): The size of the tensor. + stride (tuple or None): The stride of the tensor, or None if not applicable. + requires_grad (bool): Indicates if the tensor requires gradient computation. + backward_hooks (list): A list of backward hooks for the tensor. + metadata (Any, optional): Additional metadata associated with the tensor. + + Returns: + None: This function does not have a return value. + + Raises: + None: This function does not raise any exceptions. + ''' + if size == (): + num_elemets = 1 + else: + num_elemets = reduce(operator.mul, size) + array = storage[storage_offset: storage_offset + num_elemets] + + if array.dtype == bfloat16 and not SUPPORT_BF16: + array = array.astype(np.float16) + + if stride is not None and len(stride) > 1 and stride[0] == 1: + # stride = tuple((s * 4 for s in stride)) + # # stride = tuple((s * 4 if s != 1 else s for s in stride)) + # array = np.lib.stride_tricks.as_strided(array, size, stride) + order = "F" + array = array.reshape(size, order=order) + else: + order = "C" + array = array.reshape(size, order=order) + param = core.from_numpy(array) + return param + +def _rebuild_from_type_v2(func, new_type, args, state): + ret = func(*args) + return ret + +@dataclass +class FakeParameter: + + """ + This class represents a fake parameter in Python. + + The 'FakeParameter' class inherits from [insert inherited class here]. + + Class Attributes: + [List any class attributes here, if applicable] + + Instance Attributes: + [List any instance attributes here, if applicable] + + Methods: + [List all the methods of the class here, along with their descriptions] + + - [method name]: [method description] + - [method name]: [method description] + - ... + + Usage: + [Explain how to use the 'FakeParameter' class, including any important details or considerations] + + Example: + [Provide an example usage of the 'FakeParameter' class] + + >>> [code example] + + """ + storage: np.ndarray = None + storage_offset: int = None + size: tuple = None + stride: tuple = None + requires_grad: bool = None + +@dataclass +class FakeStorage: + + """ + This class represents a fake storage system in Python. + + The 'FakeStorage' class is designed to mimic a real storage system but without any actual functionality. It serves as a placeholder or a testing tool for applications that require a storage system. + + Attributes: + None. + + Methods: + None. + + Inheritance: + This class does not inherit from any other class. + + Usage: + 1. Instantiate the 'FakeStorage' class to create a fake storage object. + 2. Use the object to simulate storage-related operations without actually interacting with a real storage system. + 3. Since this class does not have any attributes or methods, its usefulness lies in its ability to stand in for a real storage system during testing or development. + + Note: + - This class is not intended for production use and should only be used for testing or development purposes. + - It is recommended to replace instances of 'FakeStorage' with a real storage system before deploying the application. + + Example: + from fake_storage import FakeStorage + + storage = FakeStorage() + storage.upload_file('file.txt') + storage.download_file('file.txt') + storage.delete_file('file.txt') + + The above example demonstrates how to use the 'FakeStorage' class to simulate storage operations. However, since this is a fake storage system, no files are actually uploaded, downloaded, or deleted. + + """ + storage: np.ndarray = None + +def _rebuild_tensor_legacy(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None): + """ + This function rebuilds a tensor using legacy parameters. + + Args: + storage (Tensor): The storage for the tensor. + storage_offset (int): The offset within the storage. + size (tuple): The size of the tensor. + stride (tuple): The stride of the tensor. + requires_grad (bool): Indicates if gradients need to be computed for the tensor. + backward_hooks (dict): Dictionary containing backward hooks for the tensor. + metadata (optional): Additional metadata for the tensor. + + Returns: + None. This function does not return any value. + + Raises: + None. This function does not raise any exceptions. + """ + return FakeParameter(storage, storage_offset, size, stride, requires_grad) + +def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str: + """ + This function decodes a bytes string to ASCII if it is a bytes type, otherwise returns the input string. + + Args: + bytes_str (Union[bytes, str]): A bytes or string input to be decoded if it is a bytes type. If it is already a string, it will be returned as is. + + Returns: + str: The decoded ASCII string if the input is of bytes type, otherwise the original string. + + Raises: + None + """ + # When using encoding='bytes' in Py3, some **internal** keys stored as + # strings in Py2 are loaded as bytes. This function decodes them with + # ascii encoding, one that Py3 uses by default. + # + # NOTE: This should only be used on internal keys (e.g., `typename` and + # `location` in `persistent_load` below! + if isinstance(bytes_str, bytes): + return bytes_str.decode('ascii') + return bytes_str + + +dtype_map = { + "HalfStorage": np.float16, + "FloatStorage": np.float32, + 'BFloat16Storage': bfloat16, + 'LongStorage': np.int64, + 'ByteStorage': np.uint8, + 'BoolStorage': np.bool_ +} + +storage_map = { + mindspore.float16: "HalfStorage", + mindspore.float32: "FloatStorage", + mindspore.bfloat16: 'BFloat16Storage', + mindspore.int64: 'LongStorage', + mindspore.int32: 'IntStorage', + mindspore.uint8: 'ByteStorage', + mindspore.bool_: 'BoolStorage' +} + +element_size_map = { + "HalfStorage": 2, + "FloatStorage": 3, + 'BFloat16Storage': 2, + 'LongStorage': 4, + 'ByteStorage': 1, + 'BoolStorage': 1 +} + +def load(f, map_location=None, pickle_module=pickle, *, weights_only=False, mmap=None, **pickle_load_args): + """ + Load a file using pickle, optionally with memory mapping. + + Args: + f (file-like object or str): The file to load from. If a string is provided, it should be the filename. + pickle_module (module): The module to use for pickling. Defaults to the standard 'pickle' module. + + Returns: + None: This function does not return any value. + + Raises: + ValueError: Raised if 'f' is not a string filename when using mmap argument, or if torchscript is detected in a zipfile. + RuntimeError: Raised if mmap argument is used without files saved with `core.save(_use_new_zipfile_serialization=True)`. + """ + if pickle_module is None: + pickle_module = pickle + + # make flipping default BC-compatible + if mmap is None: + mmap = False + + if 'encoding' not in pickle_load_args: + pickle_load_args['encoding'] = 'utf-8' + + with _open_file_like(f, 'rb') as opened_file: + if _is_zipfile(opened_file): + # The zipfile reader is going to advance the current file position. + # If we want to actually tail call to core.jit.load, we need to + # reset back to the original position. + overall_storage = None + with _open_zipfile_reader(opened_file, ) as opened_zipfile: + if _is_torchscript_zip(opened_zipfile): + raise ValueError('do not support torchscript now') + if mmap: + if not isinstance(f, str): + raise ValueError("f must be a string filename in order to use mmap argument") + overall_storage = f + + return _load(opened_zipfile, + pickle_module, + overall_storage=overall_storage, + **pickle_load_args) + if mmap: + raise RuntimeError("mmap can only be used with files saved with ", + "`core.save(_use_new_zipfile_serialization=True), " + "please core.save your checkpoint with this option in order to use mmap.") + + return _legacy_load(opened_file, pickle_module, **pickle_load_args) + +def _legacy_load(f, pickle_module, **pickle_load_args): + """ + Args: + f (file-like object): The file-like object containing the serialized data to be loaded. + pickle_module (module): The module used for unpickling the serialized data. + + Returns: + None. This function does not return any value. + + Raises: + ValueError: Raised if legacy load for MindSpore is not supported. + RuntimeError: Raised if an unknown saved id type is encountered during deserialization. + RuntimeError: Raised if the magic number in the file does not match the expected value. + RuntimeError: Raised if the protocol version in the file does not match the expected value. + RuntimeError: Raised if there is an issue with the file-like object compatibility with core.load. + """ + deserialized_objects: Dict[int, Any] = {} + + class UnpicklerWrapper(pickle_module.Unpickler): # type: ignore[name-defined] + + def find_class(self, mod_name, name): + if name == '_rebuild_tensor_v2': + name = '_rebuild_tensor_legacy' + if mod_name == 'core._utils': + return eval(name) + if mod_name == 'torch': + return str(name) + return super().find_class(mod_name, name) + + def legacy_load(f): + deserialized_objects: Dict[int, Any] = {} + + def persistent_load(saved_id): + if isinstance(saved_id, tuple): + # Ignore containers that don't have any sources saved + return saved_id[0] + return deserialized_objects[int(saved_id)] + + with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \ + mkdtemp() as tmpdir: + raise ValueError('do not support legacy load for MindSpore.') + + deserialized_objects = {} + + def persistent_load(saved_id): + assert isinstance(saved_id, tuple) + typename = _maybe_decode_ascii(saved_id[0]) + data = saved_id[1:] + + if typename == 'module': + # Ignore containers that don't have any sources saved + return data[0] + if typename == 'storage': + storage_type, root_key, location, numel, view_metadata = data + location = _maybe_decode_ascii(location) + + if root_key not in deserialized_objects: + typed_storage = FakeStorage(np.empty(numel, dtype_map[storage_type])) + deserialized_objects[root_key] = typed_storage + else: + typed_storage = deserialized_objects[root_key] + + if view_metadata is not None: + view_key, offset, view_size = view_metadata + if view_key not in deserialized_objects: + # TODO: Once we decide to break serialization FC, we can + # stop wrapping with TypedStorage + deserialized_objects[view_key] = typed_storage[offset: offset + view_size] + res = deserialized_objects[view_key] + else: + res = typed_storage + return res + raise RuntimeError(f"Unknown saved id type: {saved_id[0]}") + + _check_seekable(f) + f_should_read_directly = _should_read_directly(f) + + if f_should_read_directly and f.tell() == 0: + # legacy_load requires that f has fileno() + # only if offset is zero we can attempt the legacy tar file loader + try: + return legacy_load(f) + except tarfile.TarError: + if _is_zipfile(f): + # .zip is used for core.jit.save and will throw an un-pickling error here + raise RuntimeError( + f"{f.name} is a zip archive (did you mean to use core.jit.load()?)") from None + # if not a tarfile, reset file offset and proceed + f.seek(0) + + if not hasattr(f, 'readinto') and (3, 8, 0) <= sys.version_info < (3, 8, 2): + raise RuntimeError( + "core.load does not work with file-like objects that do not implement readinto on Python 3.8.0 and 3.8.1. " + f"Received object of type \"{type(f)}\". Please update to Python 3.8.2 or newer to restore this " + "functionality.") + + magic_number = pickle_module.load(f, **pickle_load_args) + if magic_number != MAGIC_NUMBER: + raise RuntimeError("Invalid magic number; corrupt file?") + protocol_version = pickle_module.load(f, **pickle_load_args) + if protocol_version != PROTOCOL_VERSION: + raise RuntimeError(f"Invalid protocol version: {protocol_version}") + + _sys_info = pickle_module.load(f, **pickle_load_args) + unpickler = UnpicklerWrapper(f, **pickle_load_args) + unpickler.persistent_load = persistent_load + result = unpickler.load() + deserialized_storage_keys = pickle_module.load(f, **pickle_load_args) + + offset = f.tell() if f_should_read_directly else None + for key in deserialized_storage_keys: + assert key in deserialized_objects + typed_storage = deserialized_objects[key].storage + f.read(8) # trick for read + array = np.frombuffer(f.read(typed_storage.nbytes), typed_storage.dtype) + deserialized_objects[key].storage = array + if offset is not None: + offset = f.tell() + + new_result = {} + for k, v in result.items(): + num_elemets = reduce(operator.mul, v.size) + array = v.storage.storage[v.storage_offset: v.storage_offset + num_elemets] + stride = v.stride + size = v.size + if stride is not None and len(stride) > 1 and stride[0] == 1 and stride[1] > 1: + stride = tuple((s * 4 for s in stride)) + array = np.lib.stride_tricks.as_strided(array, size, stride) + else: + order = "C" + array = array.reshape(size, order=order) + if array.dtype == bfloat16 and not SUPPORT_BF16: + array = array.astype(np.float16) + new_result[k] = core.from_numpy(array) + + return new_result + +def _load(zip_file, pickle_module, overall_storage=None, pickle_file='data.pkl', **pickle_load_args): + """ + Loads data from a zip file using pickle serialization. + + Args: + zip_file (zipfile.ZipFile): The zip file containing the data. + pickle_module (module): The pickle module to use for deserialization. + overall_storage (numpy.memmap, optional): The overall storage for loading the data. + pickle_file (str, optional): The name of the pickle file within the zip file. Default is 'data.pkl'. + **pickle_load_args: Additional keyword arguments to pass to the pickle module's load function. + + Returns: + None + + Raises: + ValueError: If an unknown endianness type is encountered. + ValueError: If an invalid load endianness type is encountered. + UserWarning: If the default load endianness is changed on big endian machines. + + """ + loaded_storages = {} + # check if byteswapping is needed + byteordername = 'byteorder' + byteorderdata = None + if zip_file.has_record(byteordername): + byteorderdata = zip_file.read_record(byteordername) + if byteorderdata not in [b'little', b'big']: + raise ValueError('Unknown endianness type: ' + byteorderdata.decode()) + elif get_default_load_endianness() == LoadEndianness.LITTLE or \ + get_default_load_endianness() is None: + byteorderdata = b'little' + elif get_default_load_endianness() == LoadEndianness.BIG: + byteorderdata = b'big' + elif get_default_load_endianness() == LoadEndianness.NATIVE: + pass + else: + raise ValueError('Invalid load endianness type') + + if not zip_file.has_record(byteordername) and \ + get_default_load_endianness() is None and \ + sys.byteorder == 'big': + # Default behaviour was changed + # See https://github.com/pytorch/pytorch/issues/101688 + warnings.warn("The default load endianness for checkpoints without a byteorder mark " + "on big endian machines was changed from 'native' to 'little' endian, " + "to avoid this behavior please use " + "core.serialization.set_default_load_endianness to set " + "the desired default load endianness", + UserWarning) + + def persistent_load(saved_id): + assert isinstance(saved_id, tuple) + typename = _maybe_decode_ascii(saved_id[0]) + data = saved_id[1:] + + assert typename == 'storage', \ + f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'" + storage_type, key, location, numel = data + + name = f'data/{key}' + if name in loaded_storages: + return loaded_storages[name] + + if overall_storage is not None: + array = np.memmap(overall_storage, dtype=dtype_map[storage_type], offset=zip_file.open_record(name)._fileobj.tell(), shape=(numel,)) + else: + array = np.frombuffer(zip_file.read_record(name), dtype_map[storage_type]) + loaded_storages[name] = array + return array + + load_module_mapping: Dict[str, str] = { + # See https://github.com/pytorch/pytorch/pull/51633 + 'core.tensor': 'core._tensor' + } + + # Need to subclass Unpickler instead of directly monkey-patching the find_class method + # because it's marked readonly in pickle. + # The type: ignore is because mypy can't statically determine the type of this class. + class UnpicklerWrapper(pickle_module.Unpickler): # type: ignore[name-defined] + # from https://stackoverflow.com/questions/13398462/unpickling-python-objects-with-a-changed-module-path/13405732 + # Lets us override the imports that pickle uses when unpickling an object. + # This is useful for maintaining BC if we change a module path that tensor instantiation relies on. + def find_class(self, mod_name, name): + if mod_name == 'core._utils': + return eval(name) + if mod_name == 'torch': + return str(name) + if mod_name == 'core._tensor': + return eval(name) + mod_name = load_module_mapping.get(mod_name, mod_name) + return super().find_class(mod_name, name) + + # Load the data (which may in turn use `persistent_load` to load tensors) + data_file = zip_file.open_record(pickle_file) + + unpickler = UnpicklerWrapper(data_file, **pickle_load_args) + unpickler.persistent_load = persistent_load + result = unpickler.load() + + return result + +def convert_torch_to_mindspore(pth_file): + """convert torch checkpoint to mindspore""" + try: + from mindnlp import core # pylint: disable=import-error + except Exception as exc: + raise ImportError("'from mindnlp import core' failed, please install torch by " + "`pip install torch` or instructions from 'https://pycore.org'") \ + from exc + if pth_file.endswith(".safetensors"): + from safetensors.torch import load_file + state_dict = load_file(pth_file) + ms_ckpt_path = pth_file.replace('model-', 'mindspore-') + ms_ckpt_path = ms_ckpt_path.replace('.safetensors', '.ckpt') + + else: + ms_ckpt_path = pth_file.replace('pytorch_model', 'mindspore') + ms_ckpt_path = ms_ckpt_path.replace('.bin', '.ckpt') + + state_dict = core.load(pth_file, map_location='cpu') + + if os.path.exists(ms_ckpt_path): + return ms_ckpt_path + + ms_ckpt = [] + logging.info('Starting checkpoint conversion.') + + has_bf16 = False + for key, value in state_dict.items(): + if value.dtype == core.bfloat16: + data = core.from_numpy(value.to(core.float).numpy().astype(np.float16)) + if not has_bf16: + has_bf16 = True + else: + data = core.from_numpy(value.numpy()) + ms_ckpt.append({'name': key, 'data': data}) + + if has_bf16: + logging.warning("MindSpore do not support bfloat16 dtype, we will automaticlly convert to float16") + + try: + mindspore.save_checkpoint(ms_ckpt, ms_ckpt_path) + except Exception as exc: + raise RuntimeError(f'Save checkpoint to {ms_ckpt_path} failed, ' + f'please checkout the path.') from exc + + return ms_ckpt_path + +def _check_save_filelike(f): + if not isinstance(f, (str, os.PathLike)) and not hasattr(f, 'write'): + raise AttributeError( + "expected 'f' to be string, path, or a file-like object with " + "a 'write' attribute") + +def save(obj, f, pickle_module = pickle, pickle_protocol = 2, _disable_byteorder_record: bool = False): + _check_save_filelike(f) + with _open_zipfile_writer(f) as opened_zipfile: + _save( + obj, + opened_zipfile, + pickle_module, + pickle_protocol, + _disable_byteorder_record, + ) + return + +def _save( + obj, + zip_file, + pickle_module, + pickle_protocol, + _disable_byteorder_record, +): + serialized_storages = {} + id_map: Dict[int, str] = {} + + # Since loading storages that view the same data with different dtypes is + # not supported, we need to keep track of the dtype associated with each + # storage data_ptr and throw an error if the dtype is ever different. + # TODO: This feature could be added in the future + storage_dtypes = {} + + def persistent_id(obj): + # FIXME: the docs say that persistent_id should only return a string + # but torch store returns tuples. This works only in the binary protocol + # see + # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects + # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537 + + if isinstance(obj, MSTensor): + storage_type = storage_map[obj.dtype] + storage_numel = obj._size + storage_key = id_map.setdefault(id(obj), str(len(id_map))) + serialized_storages[storage_key] = obj + location = 'cpu' + + return ("storage", storage_type, storage_key, location, storage_numel) + + return None + + # Write the pickle data for `obj` + data_buf = io.BytesIO() + pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol) + pickler.persistent_id = persistent_id + pickler.dump(obj) + data_value = data_buf.getvalue() + zip_file.write_record("archive/data.pkl", data_value, len(data_value)) + zip_file.write_record("archive/version", bytes(str(3), encoding='utf-8')) + + # Write byte order marker + if not _disable_byteorder_record: + if sys.byteorder not in ["little", "big"]: + raise ValueError("Unknown endianness type: " + sys.byteorder) + + zip_file.write_record("archive/byteorder", sys.byteorder, len(sys.byteorder)) + + # Write each tensor to a file named tensor/the_tensor_key in the zip archive + for key in sorted(serialized_storages.keys()): + name = f"archive/data/{key}" + storage = serialized_storages[key].get_bytes() + num_bytes = len(storage) + zip_file.write_record(name, storage, num_bytes) + + +_MS_TYPES = { + "F64": mindspore.float64, + "F32": mindspore.float32, + "F16": mindspore.float16, + "BF16": mindspore.bfloat16, + "I64": mindspore.int64, + "U64": mindspore.uint64, + "I32": mindspore.int32, + "U32": mindspore.uint32, + "I16": mindspore.int16, + "U16": mindspore.uint16, + "I8": mindspore.int8, + "U8": mindspore.uint8, + "BOOL": mindspore.bool_, +} + +_NP_TYPES = { + "F64": np.float64, + "F32": np.float32, + "F16": np.float16, + "BF16": bfloat16, + "I64": np.int64, + "U64": np.uint64, + "I32": np.int32, + "U32": np.uint32, + "I16": np.int16, + "U16": np.uint16, + "I8": np.int8, + "U8": np.uint8, + "BOOL": bool, +} + +def legacy_safe_load_file(filename): + """ + This function safely loads a file containing state dictionary data and converts it into a dictionary of MindSpore Parameters. + + Args: + filename (str): The path to the file containing the state dictionary data to be loaded. + + Returns: + dict: A dictionary where keys are parameter names and values are MindSpore Parameters. + + Raises: + FileNotFoundError: If the specified file 'filename' does not exist. + ValueError: If the data in the file is not in the correct format to create MindSpore Parameters. + """ + with open(filename, "rb") as f: + data = f.read() + + safeview = deserialize(data) + + result = {} + try: + for k, v in safeview: + dtype = _MS_TYPES[v["dtype"]] + if (not SUPPORT_BF16 and dtype != mindspore.bfloat16) or SUPPORT_BF16: + arr = MSTensor.convert_bytes_to_tensor(bytes(v["data"]), tuple(v["shape"]), dtype) + result[k] = core.Tensor(arr) + else: + raise TypeError('Do not support bfloat16 on current device, use numpy as convert buffer to boost load.') + return result + + except Exception as e: + for k, v in safeview: + dtype = _NP_TYPES[v["dtype"]] + arr = np.frombuffer(v["data"], dtype=dtype).reshape(v["shape"]) + + if (not SUPPORT_BF16 and dtype != bfloat16) or SUPPORT_BF16: + result[k] = core.from_numpy(arr) + else: + result[k] = core.from_numpy(arr.astype(np.float16)) + return result + + +def safe_load_file(filename): + """ + This function safely loads a file containing state dictionary data and converts it into a dictionary of MindSpore Parameters. + + Args: + filename (str): The path to the file containing the state dictionary data to be loaded. + + Returns: + dict: A dictionary where keys are parameter names and values are MindSpore Parameters. + + Raises: + FileNotFoundError: If the specified file 'filename' does not exist. + ValueError: If the data in the file is not in the correct format to create MindSpore Parameters. + """ + def convert(info: dict[str, Any]): + numpy_dtype = _NP_TYPES[info['dtype']] + ms_dtype = _MS_TYPES[info['dtype']] + shape: list[int] = info['shape'] + begin, end = info['data_offsets'] + assert 0 <= begin <= end <= len(byte_buf) + assert end - begin == math.prod(shape) * np.dtype(numpy_dtype).itemsize + buf = byte_buf[begin:end] + + try: + if info['dtype'] == 'BF16' and not SUPPORT_BF16: + ms_dtype = mindspore.float16 + out = MSTensor.convert_bytes_to_tensor(buf, tuple(shape), ms_dtype) + except: + array = np.frombuffer(buf, dtype=numpy_dtype).reshape(shape) + + if array.dtype == bfloat16 and not SUPPORT_BF16: + array = array.astype(np.float16) + array = array.astype(array.dtype) + out = core.from_numpy(array) + return out + + with open(filename, "rb") as fp: + header_size, = struct.unpack(' _int: + raise NotImplementedError + + def __getitem__(self, idx): + raise NotImplementedError + + def __setitem__(self, *args, **kwargs): + raise NotImplementedError + + def copy_(self, source: T, non_blocking: _Optional[_bool] = None) -> T: + raise NotImplementedError + + def new(self) -> Union[_StorageBase, TypedStorage]: + raise NotImplementedError + + def nbytes(self) -> _int: + raise NotImplementedError + + def size(self) -> _int: + return self.nbytes() + + def type( + self, dtype: _Optional[str] = None, non_blocking: _bool = False + ) -> Union[_StorageBase, TypedStorage]: + return _type(self, dtype, non_blocking) + + def cuda( + self, device=None, non_blocking=False + ) -> Union[_StorageBase, TypedStorage]: + """Returns a copy of this object in CUDA memory. + + If this object is already in CUDA memory and on the correct device, then + no copy is performed and the original object is returned. + + Args: + device (int): The destination GPU id. Defaults to the current device. + non_blocking (bool): If ``True`` and the source is in pinned memory, + the copy will be asynchronous with respect to the host. Otherwise, + the argument has no effect. + """ + device2 = core.device("cuda", device) if device else core.device("cuda") + return self.to(device=device2, non_blocking=non_blocking) + + def hpu(self, device=None, non_blocking=False) -> Union[_StorageBase, TypedStorage]: + """Returns a copy of this object in HPU memory. + + If this object is already in HPU memory and on the correct device, then + no copy is performed and the original object is returned. + + Args: + device (int): The destination HPU id. Defaults to the current device. + non_blocking (bool): If ``True`` and the source is in pinned memory, + the copy will be asynchronous with respect to the host. Otherwise, + the argument has no effect. + """ + device2 = core.device("hpu", device) if device else core.device("hpu") + return self.to(device=device2, non_blocking=non_blocking) + + def element_size(self) -> _int: + raise NotImplementedError + + def get_device(self) -> _int: + return self.device.index + + def data_ptr(self) -> _int: + raise NotImplementedError + + def resizable(self) -> _bool: + raise NotImplementedError + + # Defined in torch/csrc/generic/StorageSharing.cpp + def _share_filename_cpu_(self, *args, **kwargs): + raise NotImplementedError + + def _share_fd_cpu_(self, *args, **kwargs): + raise NotImplementedError + + @classmethod + def _new_using_filename_cpu(cls, size: _int) -> Self: + raise NotImplementedError + + @classmethod + def _new_using_fd_cpu(cls, size: _int) -> Self: + raise NotImplementedError + + @classmethod + def from_buffer(cls, *args, **kwargs) -> Self: + raise NotImplementedError + + @classmethod + def _new_shared_filename_cpu( + cls, + manager, + obj, + size, + *, + device=None, + dtype=None, + ) -> Self: + raise NotImplementedError + + @classmethod + def _release_ipc_counter_cuda(cls, *args, **kwargs) -> Self: + raise NotImplementedError + + @classmethod + def _new_with_weak_ptr(cls, *args, **kwargs) -> Self: + raise NotImplementedError + + def _shared_decref(self) -> Union[_StorageBase, TypedStorage]: + raise NotImplementedError + + def _write_file(self, *args, **kwargs): + raise NotImplementedError + + def resize_(self, size: _int): + raise NotImplementedError + + def _weak_ref(self, *args, **kwargs) -> Union[_StorageBase, TypedStorage]: + raise NotImplementedError + + def _set_from_file(self, *args, **kwargs): + raise NotImplementedError + + def _set_cdata(self, *args, **kwargs): + raise NotImplementedError + + def _share_cuda_(self, *args, **kwargs): + raise NotImplementedError + + def is_shared(self) -> _bool: + raise NotImplementedError + + @classmethod + def _new_shared_cuda(cls, *args, **kwargs) -> Self: + raise NotImplementedError + + def _shared_incref(self, *args, **kwargs): + raise NotImplementedError + + @classmethod + def _free_weak_ref(cls, *args, **kwargs): + raise NotImplementedError + + @property + def is_cuda(self): + raise NotImplementedError + + @property + def is_hpu(self): + raise NotImplementedError + + @classmethod + def from_file(cls, filename, shared, nbytes) -> Union[_StorageBase, TypedStorage]: + raise NotImplementedError + + @classmethod + def _expired(cls, *args, **kwargs) -> Union[_StorageBase, TypedStorage]: + raise NotImplementedError + + def _byteswap(self, *args, **kwargs): + raise NotImplementedError + + def _get_filename(self, *args, **kwargs) -> _Optional[str]: + raise NotImplementedError + + def __repr__(self): + info_str = f"[{core.typename(self)}(device={self.device}) of size {len(self)}]" + if self.device.type == "meta": + return "...\n" + info_str + data_str = " " + "\n ".join(str(self[i]) for i in range(self.size())) + return data_str + "\n" + info_str + + def __iter__(self): + return iter(self[i] for i in range(self.size())) + + def __copy__(self): + return self.clone() + + def __deepcopy__(self, memo): + memo = memo.setdefault("torch", {}) + if self._cdata in memo: + return memo[self._cdata] + new_storage = self.clone() + memo[self._cdata] = new_storage + return new_storage + + def __reduce__(self): + b = io.BytesIO() + core.save(self, b, _use_new_zipfile_serialization=False) + return (_load_from_bytes, (b.getvalue(),)) + + def __sizeof__(self): + return super().__sizeof__() + self.size() + + def clone(self): + """Return a copy of this storage.""" + return type(self)(self.nbytes(), device=self.device).copy_(self) + + def tolist(self): + """Return a list containing the elements of this storage.""" + return list(self) + + def cpu(self): + """Return a CPU copy of this storage if it's not already on the CPU.""" + if self.device.type != "cpu": + return core.UntypedStorage(self.size()).copy_(self, False) + return self + + def mps(self): + """Return a MPS copy of this storage if it's not already on the MPS.""" + if self.device.type != "mps": + return core.UntypedStorage(self.size(), device="mps").copy_(self, False) + return self + + def _to(self, dtype): + if not isinstance(dtype, core.dtype): + raise TypeError(f"Argument 'dtype' must be core.dtype, not {type(dtype)}") + storage = ( + core.tensor([], dtype=core.uint8, device=self.device) + .set_(cast(Storage, self)) + .to(dtype) + ._typed_storage() + ) + if storage.data_ptr() == self.data_ptr(): + storage = storage.clone() + return storage + + def to(self, *, device: DeviceLikeType, non_blocking: _bool = False): + if not isinstance(device, core.device): + device = core.device(device) + return _to(self, device, non_blocking) + + def double(self): + """Casts this storage to double type.""" + return self._to(core.double) + + def float(self): + """Casts this storage to float type.""" + return self._to(core.float) + + def half(self): + """Casts this storage to half type.""" + return self._to(core.half) + + def long(self): + """Casts this storage to long type.""" + return self._to(core.long) + + def int(self): + """Casts this storage to int type.""" + return self._to(core.int) + + def short(self): + """Casts this storage to short type.""" + return self._to(core.short) + + def char(self): + """Casts this storage to char type.""" + return self._to(core.int8) + + def byte(self): + """Casts this storage to byte type.""" + return self._to(core.uint8) + + def bool(self): + """Casts this storage to bool type.""" + return self._to(core.bool) + + def bfloat16(self): + """Casts this storage to bfloat16 type.""" + return self._to(core.bfloat16) + + def complex_double(self): + """Casts this storage to complex double type.""" + return self._to(core.cdouble) + + def complex_float(self): + """Casts this storage to complex float type.""" + return self._to(core.cfloat) + + def float8_e5m2(self): + """Casts this storage to float8_e5m2 type""" + return self._to(core.float8_e5m2) + + def float8_e4m3fn(self): + """Casts this storage to float8_e4m3fn type""" + return self._to(core.float8_e4m3fn) + + def float8_e5m2fnuz(self): + """Casts this storage to float8_e5m2fnuz type""" + return self._to(core.float8_e5m2fnuz) + + def float8_e4m3fnuz(self): + """Casts this storage to float8_e4m3fnuz type""" + return self._to(core.float8_e4m3fnuz) + + def is_pinned(self, device: Union[str, core.device] = "cuda"): + r"""Determine whether the CPU storage is already pinned on device. + + Args: + device (str or core.device): The device to pin memory on (default: ``'cuda'``). + This argument is discouraged and subject to deprecated. + + Returns: + A boolean variable. + """ + return ( + core.tensor([], dtype=core.uint8, device=self.device) + .set_(cast(Storage, self)) + .is_pinned(device) + ) + + def pin_memory(self, device: Union[str, core.device] = "cuda"): + r"""Copy the CPU storage to pinned memory, if it's not already pinned. + + Args: + device (str or core.device): The device to pin memory on (default: ``'cuda'``). + This argument is discouraged and subject to deprecated. + + Returns: + A pinned CPU storage. + """ + if self.device.type != "cpu": + raise TypeError(f"cannot pin '{self.type()}' only CPU memory can be pinned") + + pinned_tensor = ( + core.tensor([], dtype=core.uint8, device=self.device) + .set_(cast(Storage, self)) + .pin_memory(device) + ) + return pinned_tensor.untyped_storage() + + def share_memory_(self): + """See :meth:`core.UntypedStorage.share_memory_`""" + from core.multiprocessing import get_sharing_strategy + + if self.device.type in ["cuda", core._C._get_privateuse1_backend_name()]: + pass # CUDA or PrivateUse1 doesn't use POSIX shared memory + elif get_sharing_strategy() == "file_system": + self._share_filename_cpu_() + else: + self._share_fd_cpu_() + return self + + @classmethod + def _new_shared(cls, size, *, device="cpu"): + """Create a new storage in shared memory with the same data type.""" + from core.multiprocessing import get_sharing_strategy + + device = core.device(device) + if device.type in ["cuda", core._C._get_privateuse1_backend_name(), "hpu"]: + return cls(size, device=device) + elif get_sharing_strategy() == "file_system": + return cls._new_using_filename_cpu(size) + else: + return cls._new_using_fd_cpu(size) + + def untyped(self): + return self + + def byteswap(self, dtype): + """Swap bytes in underlying data.""" + elem_size = core._utils._element_size(dtype) + # for complex types, don't swap first and second numbers + if dtype.is_complex: + elem_size = max(int(elem_size / 2), 1) + self._byteswap(elem_size) + + +def _share_memory_lock_protected(fn): + @functools.wraps(fn) + def wrapper(self, *args, **kwargs): + to_free = None + to_wait = None + with _share_memory_lock: + key = self._cdata + if key in _share_memory_map: + to_wait = _share_memory_map[key] + else: + _share_memory_map[key] = threading.RLock() + _share_memory_map[key].acquire() + to_free = key + + # If we're already in the process of sharing the storage, wait + # for it to be done. + if to_wait is not None: + with to_wait: + pass + + try: + return fn(self, *args, **kwargs) + finally: + # If we acquired the storage lock here and we're done working on it + # we can now release it and free the entry. + if to_free is not None: + # Ensure that the cdata from the storage didn't change and only + # the data_ptr did. + assert self._cdata == to_free + with _share_memory_lock: + _share_memory_map[to_free].release() + del _share_memory_map[to_free] + + return wrapper + + +class UntypedStorage(_StorageBase): + def __init__(self, data, device=None): + self.data = data # np array as storage + self.device = device + + @property + def is_cuda(self): + return self.device.type == "cuda" + + @property + def is_hpu(self): + return self.device.type == "hpu" + + @property + def filename(self) -> _Optional[str]: + """Returns the file name associated with this storage. + + The file name will be a string if the storage is on CPU and was created via + :meth:`~core.from_file()` with ``shared`` as ``True``. This attribute is ``None`` otherwise. + """ + return self._get_filename() + + @_share_memory_lock_protected + def share_memory_(self, *args, **kwargs): + """ + Moves the storage to shared memory. + + This is a no-op for storages already in shared memory and for CUDA + storages, which do not need to be moved for sharing across processes. + Storages in shared memory cannot be resized. + + Note that to mitigate issues like `this `_ + it is thread safe to call this function from multiple threads on the same object. + It is NOT thread safe though to call any other function on self without proper + synchronization. Please see :doc:`/notes/multiprocessing` for more details. + + .. note:: + When all references to a storage in shared memory are deleted, the associated shared memory + object will also be deleted. PyTorch has a special cleanup process to ensure that this happens + even if the current process exits unexpectedly. + + It is worth noting the difference between :meth:`share_memory_` and :meth:`from_file` with ``shared = True`` + + #. ``share_memory_`` uses `shm_open(3) `_ to create a + POSIX shared memory object while :meth:`from_file` uses + `open(2) `_ to open the filename passed by the user. + #. Both use an `mmap(2) call `_ with ``MAP_SHARED`` + to map the file/object into the current virtual address space + #. ``share_memory_`` will call ``shm_unlink(3)`` on the object after mapping it to make sure the shared memory + object is freed when no process has the object open. ``core.from_file(shared=True)`` does not unlink the + file. This file is persistent and will remain until it is deleted by the user. + + Returns: + ``self`` + """ + return super().share_memory_(*args, **kwargs) + + @_share_memory_lock_protected + def _share_fd_cpu_(self, *args, **kwargs): + return super()._share_fd_cpu_(*args, **kwargs) + + @_share_memory_lock_protected + def _share_filename_cpu_(self, *args, **kwargs): + return super()._share_filename_cpu_(*args, **kwargs) + + def data_ptr(self): + return self.data.ctypes.data + + def nbytes(self): + return self.data.nbytes + + @classmethod + def from_file(cls, filename, shared, nbytes): + data = np.memmap(filename) + return cls(data, device=core.device('cpu')) + + def __getitem__(self, slice): + if self.device.type == "meta": + raise NotImplementedError("Not available for 'meta' device type") + + return UntypedStorage(self.data[slice]) + +def _load_from_bytes(b): + return core.load(io.BytesIO(b), weights_only=False) + + +@functools.cache +def _new_dtypes(): + # These are dtypes serialized as UntypedStorage unlike those in + # _dtype_to_storage_type_map + return { + core.float8_e5m2, + core.float8_e4m3fn, + core.float8_e5m2fnuz, + core.float8_e4m3fnuz, + core.float8_e8m0fnu, + core.float4_e2m1fn_x2, + core.bits8, + core.bits16, + core.bits1x8, + core.bits2x4, + core.bits4x2, + core.complex32, + core.uint16, + core.uint32, + core.uint64, + } + + +@functools.cache +def _dtype_to_storage_type_map(): + # NOTE: We should no longer add dtypes to this map. This map + # is only used for BC/FC with older PyTorch versions. Going forward, + # new dtypes of TypedStorage should not translate to a legacy + # Storage class. Instead, new dtypes of TypedStorage should + # be serialized as an UntypedStorage paired with a core.dtype + return { + core.double: "DoubleStorage", + core.float: "FloatStorage", + core.half: "HalfStorage", + core.long: "LongStorage", + core.int: "IntStorage", + core.int16: "ShortStorage", + core.int8: "CharStorage", + core.uint8: "ByteStorage", + core.bool: "BoolStorage", + core.bfloat16: "BFloat16Storage", + core.cdouble: "ComplexDoubleStorage", + core.cfloat: "ComplexFloatStorage", + # core.qint8: "QInt8Storage", + # core.qint32: "QInt32Storage", + # core.quint8: "QUInt8Storage", + # core.quint4x2: "QUInt4x2Storage", + # core.quint2x4: "QUInt2x4Storage", + } + + +@functools.cache +def _storage_type_to_dtype_map(): + dtype_map = {val: key for key, val in _dtype_to_storage_type_map().items()} + return dtype_map + + +def _get_storage_from_sequence(sequence, dtype, device): + if dtype in [ + core.quint8, + core.quint4x2, + core.quint2x4, + core.qint32, + core.qint8, + ]: + interpret_dtypes = { + core.quint8: core.uint8, + core.quint4x2: core.uint8, + core.quint2x4: core.uint8, + core.qint32: core.int32, + core.qint8: core.int8, + } + tmp_tensor = core.tensor( + sequence, dtype=interpret_dtypes[dtype], device=device + ) + + else: + tmp_tensor = core.tensor(sequence, dtype=dtype, device=device) + + return tmp_tensor._typed_storage()._untyped_storage + + +def _isint(x): + if HAS_NUMPY: + return isinstance(x, (int, np.integer)) + else: + return isinstance(x, int) + + +_always_warn_typed_storage_removal = False + + +def _get_always_warn_typed_storage_removal(): + return _always_warn_typed_storage_removal + + +def _set_always_warn_typed_storage_removal(always_warn): + global _always_warn_typed_storage_removal + assert isinstance(always_warn, bool) + _always_warn_typed_storage_removal = always_warn + + +def _warn_typed_storage_removal(stacklevel=2): + global _always_warn_typed_storage_removal + + def is_first_time(): + if not hasattr(_warn_typed_storage_removal, "has_warned"): + return True + else: + return not _warn_typed_storage_removal.__dict__["has_warned"] + + if _get_always_warn_typed_storage_removal() or is_first_time(): + message = ( + "TypedStorage is deprecated. It will be removed in the future and " + "UntypedStorage will be the only storage class. This should only matter " + "to you if you are using storages directly. To access UntypedStorage " + "directly, use tensor.untyped_storage() instead of tensor.storage()" + ) + warnings.warn(message, UserWarning, stacklevel=stacklevel + 1) + _warn_typed_storage_removal.__dict__["has_warned"] = True + + +def _reset_warn_typed_storage_removal(): + _warn_typed_storage_removal.__dict__["has_warned"] = False + + +def _get_device_from_module(module: str): + last_part = module.rsplit(".", 1)[-1] + if last_part in ["cuda", core._C._get_privateuse1_backend_name(), "hpu"]: + return last_part + else: + return "cpu" + + +class TypedStorage: + is_sparse: _bool = False + # Used when stashing FakeTensor device onto storage in core.save(metadata_only=True) + _fake_device: _Optional[core.device] = None + + dtype: core.dtype + + @property + def _dtype(self): + return self.dtype + + @property + def filename(self) -> _Optional[str]: + """Returns the file name associated with this storage if the storage was memory mapped from a file. + or ``None`` if the storage was not created by memory mapping a file.""" + return self._untyped_storage.filename + + def fill_(self, value): + _warn_typed_storage_removal() + self._setitem(slice(0, self._size()), value) + return self + + def __new__( + cls, + *args, + wrap_storage=None, + dtype=None, + device=None, + _internal=False, + ): + if not _internal: + _warn_typed_storage_removal() + + if cls == core.storage._LegacyStorage: + raise RuntimeError( + "Only child classes of _LegacyStorage can be instantiated" + ) + + if cls == TypedStorage: + return super().__new__(cls) + + else: + arg_error_msg = ( + f"{cls}.__new__ received an invalid combination " + f"of arguments. Expected one of:\n" + " * no arguments\n" + " * (int size)\n" + " * (Sequence data)\n" + " * (*, UntypedStorage wrap_storage)" + ) + + if device is not None: + raise RuntimeError( + arg_error_msg + "\nKeyword argument 'device' cannot be specified" + ) + + if dtype is not None: + raise RuntimeError( + arg_error_msg + "\nKeyword argument 'dtype' cannot be specified" + ) + + if wrap_storage is None: + if len(args) > 1: + raise RuntimeError( + arg_error_msg + "\nToo many positional arguments" + ) + + if ( + len(args) == 1 + and not _isint(args[0]) + and not isinstance(args[0], collections.abc.Sequence) + ): + raise TypeError( + arg_error_msg + + f"\nArgument type not recognized: {type(args[0])}" + ) + + return TypedStorage( + *args, + dtype=cls._dtype, + device=_get_device_from_module(cls.__module__), + _internal=True, + ) + + else: + if len(args) != 0: + raise RuntimeError( + arg_error_msg + + "\nNo positional arguments should be given when using " + "'wrap_storage'" + ) + + if not isinstance(wrap_storage, core.UntypedStorage): + raise TypeError( + arg_error_msg + + f"\nArgument 'wrap_storage' must be UntypedStorage, but got {type(wrap_storage)}" + ) + + cls_device = _get_device_from_module(cls.__module__) + + if wrap_storage.device.type != cls_device: + raise RuntimeError( + arg_error_msg + + f"\nDevice of 'wrap_storage' must be {cls_device}" + f", but got {wrap_storage.device.type}" + ) + + return TypedStorage( + *args, + wrap_storage=wrap_storage, + dtype=cls.dtype, + _internal=True, + ) + + def __init__( + self, + *args, + device=None, + dtype=None, + wrap_storage=None, + _internal=False, + ): + if not _internal: + _warn_typed_storage_removal() + arg_error_msg = ( + "TypedStorage.__init__ received an invalid combination " + "of arguments. Expected one of:\n" + " * (*, core.device device, core.dtype dtype)\n" + " * (int size, *, core.device device, core.dtype dtype)\n" + " * (Sequence data, *, core.device device, core.dtype dtype)\n" + " * (*, UntypedStorage wrap_storage, core.dtype dtype)" + ) + + if wrap_storage is not None: + if len(args) != 0: + raise RuntimeError( + arg_error_msg + + "\nNo positional arguments should be given when using " + "'wrap_storage'" + ) + + if dtype is None: + raise RuntimeError( + arg_error_msg + "\nArgument 'dtype' must be specified" + ) + + if not isinstance(dtype, core.dtype): + raise TypeError( + arg_error_msg + + f"\nArgument 'dtype' must be core.dtype, not {type(dtype)}" + ) + + if device is not None: + raise RuntimeError( + arg_error_msg + + "\nArgument 'device' should not be specified when 'wrap_storage' is given" + ) + + self.dtype = dtype + + if not isinstance(wrap_storage, core.UntypedStorage): + raise TypeError( + arg_error_msg + + f"\nArgument 'wrap_storage' must be UntypedStorage, but got {type(wrap_storage)}" + ) + + self._untyped_storage = wrap_storage + + else: + self.dtype = core.get_default_dtype() if dtype is None else dtype + device = core.device("cpu" if device is None else device) + + # if self.dtype in [ + # core.quint8, + # core.quint4x2, + # core.quint2x4, + # core.qint32, + # core.qint8, + # ]: + # if device.type == "cuda": + # raise RuntimeError( + # "Cannot create CUDA storage with quantized dtype" + # ) + + if len(args) == 0: + self._untyped_storage = core.UntypedStorage(device=device) + + elif len(args) == 1: + if _isint(args[0]): + self._untyped_storage = core.UntypedStorage( + int(args[0]) * self._element_size(), device=device + ) + elif isinstance(args[0], collections.abc.Sequence): + self._untyped_storage = _get_storage_from_sequence( + args[0], self.dtype, device + ) + else: + raise TypeError( + arg_error_msg + + f"\nArgument type not recognized: {type(args[0])}" + ) + + else: + raise RuntimeError(arg_error_msg + "\nToo many positional arguments") + + @property + def is_cuda(self): + _warn_typed_storage_removal() + return self._untyped_storage.device.type == "cuda" + + @property + def is_hpu(self): + _warn_typed_storage_removal() + return self._untyped_storage.device.type == "hpu" + + def untyped(self): + """Return the internal :class:`core.UntypedStorage`.""" + _warn_typed_storage_removal() + return self._untyped_storage + + def _new_wrapped_storage(self, untyped_storage) -> Self: + assert type(untyped_storage) == core.UntypedStorage + + if type(self) == TypedStorage: + return cast( + Self, + TypedStorage( + wrap_storage=untyped_storage, dtype=self.dtype, _internal=True + ), + ) + else: + return type(self)(wrap_storage=untyped_storage) + + def __len__(self): + _warn_typed_storage_removal() + return self._size() + + def _maybe_wrap_index(self, idx, is_stop=False): + if idx is None: + if is_stop: + return self._size() + else: + return 0 + + else: + if type(idx) != int: + raise TypeError(f"can't index a {type(self)} with {type(idx)}") + if is_stop: + if (idx > self._size()) or (idx < -self._size()): + raise IndexError( + f"index {idx} out of range for storage of size {self.size()}" + ) + if idx > 0: + return idx + else: + return idx % self._size() + else: + if (idx >= self._size()) or (idx < -self._size()): + raise IndexError( + f"index {idx} out of range for storage of size {self.size()}" + ) + return idx % self._size() + + def __setitem__(self, idx, value): + _warn_typed_storage_removal() + return self._setitem(idx, value) + + def _setitem(self, idx, value): + if not isinstance(idx, (int, slice)): + raise RuntimeError(f"can't index a {type(self)} with {type(idx)}") + if core.is_storage(value): + raise RuntimeError(f"cannot set item with value type {type(value)}") + if self.dtype in [ + core.quint8, + core.quint4x2, + core.quint2x4, + core.qint32, + core.qint8, + ]: + interpret_dtypes = { + core.quint8: core.uint8, + core.quint4x2: core.uint8, + core.quint2x4: core.uint8, + core.qint32: core.int32, + core.qint8: core.int8, + } + tmp_dtype = interpret_dtypes[self.dtype] + tmp_tensor = core.tensor( + [], dtype=tmp_dtype, device=self._untyped_storage.device + ) + tmp_tensor.set_( + TypedStorage( + wrap_storage=self._untyped_storage, dtype=tmp_dtype, _internal=True + ) + ) + else: + tmp_tensor = core.tensor( + [], dtype=self.dtype, device=self._untyped_storage.device + ).set_(self) + + tmp_tensor[idx] = value + + def __getitem__(self, idx): + _warn_typed_storage_removal() + return self._getitem(idx) + + def _getitem(self, idx): + print(idx) + if self._untyped_storage.device.type == "meta": + raise NotImplementedError("Not available for 'meta' device type") + + # NOTE: Before TypedStorage existed, indexing with a slice used to be + # possible for Storage objects. However, it would return + # a storage view, which would be a hassle to implement in TypedStorage, + # so it was disabled + if isinstance(idx, slice): + raise RuntimeError( + "slices are only supported in UntypedStorage.__getitem__" + ) + elif not isinstance(idx, int): + raise RuntimeError(f"can't index a {type(self)} with {type(idx)}") + + # if self.dtype in [ + # core.quint8, + # core.quint4x2, + # core.quint2x4, + # core.qint32, + # core.qint8, + # ]: + # interpret_dtypes = { + # core.quint8: core.uint8, + # core.quint4x2: core.uint8, + # core.quint2x4: core.uint8, + # core.qint32: core.int32, + # core.qint8: core.int8, + # } + # return TypedStorage( + # wrap_storage=self._untyped_storage, + # dtype=interpret_dtypes[self.dtype], + # _internal=True, + # )._getitem(idx) + + idx_wrapped = self._maybe_wrap_index(idx) + + tmp_tensor = core.tensor( + [], dtype=self.dtype, device=self._untyped_storage.device + ).set_(self) + return tmp_tensor[idx_wrapped].item() + + def copy_(self, source: T, non_blocking: _Optional[bool] = None): + _warn_typed_storage_removal() + if isinstance(source, TypedStorage): + self._untyped_storage.copy_(source._untyped_storage, non_blocking) + else: + self._untyped_storage.copy_(source, non_blocking) + return self + + def nbytes(self): + _warn_typed_storage_removal() + return self._nbytes() + + # For internal use only, to avoid deprecation warning + def _nbytes(self): + return self._untyped_storage.nbytes() + + def type( + self, + dtype: _Optional[str] = None, + non_blocking: bool = False, + ) -> Union[_StorageBase, TypedStorage, str]: + _warn_typed_storage_removal() + if dtype is None: + legacy_class = self._get_legacy_storage_class() + + if legacy_class is not None: + return legacy_class.__module__ + "." + legacy_class.__name__ + + return ".".join([self.__module__, type(self).__name__]) + + else: + return self._untyped_storage.type(dtype, non_blocking) + + def cuda(self, device=None, non_blocking=False) -> Self: + _warn_typed_storage_removal() + if self.dtype in [ + core.quint8, + core.quint4x2, + core.quint2x4, + core.qint32, + core.qint8, + ]: + raise RuntimeError("Cannot create CUDA storage with quantized dtype") + cuda_storage = self._untyped_storage.cuda(device, non_blocking) + return self._new_wrapped_storage(cuda_storage) + + def hpu(self, device=None, non_blocking=False) -> Self: + _warn_typed_storage_removal() + if self.dtype in [ + core.quint8, + core.quint4x2, + core.quint2x4, + core.qint32, + core.qint8, + ]: + raise RuntimeError("Cannot create HPU storage with quantized dtype") + hpu_storage = self._untyped_storage.hpu(device, non_blocking) + return self._new_wrapped_storage(hpu_storage) + + def to(self, *, device: DeviceLikeType, non_blocking: bool = False) -> Self: + _warn_typed_storage_removal() + if not isinstance(device, core.device): + device = core.device(device) + if self.dtype in [ + core.quint8, + core.quint4x2, + core.quint2x4, + core.qint32, + core.qint8, + ]: + raise RuntimeError( + f"Cannot create {device.type.upper()} storage with quantized dtype" + ) + to_storage = self._untyped_storage.to(device=device, non_blocking=non_blocking) + return self._new_wrapped_storage(to_storage) + + def element_size(self): + _warn_typed_storage_removal() + return self._element_size() + + # For internal use only, to avoid deprecation warning + def _element_size(self): + return core._utils._element_size(self.dtype) + + def get_device(self) -> _int: + _warn_typed_storage_removal() + return self._untyped_storage.get_device() + + def __str__(self): + _warn_typed_storage_removal() + info_str = ( + f"[{core.typename(self)}(dtype={self.dtype}, " + f"device={self.device}) of size {len(self)}]" + ) + if self.device.type == "meta": + return "...\n" + info_str + else: + data_str = " " + "\n ".join(str(self[i]) for i in range(self.size())) + return data_str + "\n" + info_str + + def __repr__(self): + _warn_typed_storage_removal() + return str(self) + + def __iter__(self): + _warn_typed_storage_removal() + return iter(self[i] for i in range(self.size())) + + def __copy__(self): + _warn_typed_storage_removal() + return self._new_wrapped_storage(copy.copy(self._untyped_storage)) + + def __deepcopy__(self, memo): + _warn_typed_storage_removal() + return self._deepcopy(memo) + + # For internal use only, to avoid deprecation warning + def _deepcopy(self, memo): + return self._new_wrapped_storage(copy.deepcopy(self._untyped_storage, memo)) + + def __sizeof__(self): + _warn_typed_storage_removal() + return super().__sizeof__() + self.nbytes() + + def clone(self): + """Return a copy of this storage.""" + _warn_typed_storage_removal() + return self._new_wrapped_storage(self._untyped_storage.clone()) + + def tolist(self): + """Return a list containing the elements of this storage.""" + _warn_typed_storage_removal() + return list(self) + + def cpu(self): + """Return a CPU copy of this storage if it's not already on the CPU.""" + _warn_typed_storage_removal() + return self._new_wrapped_storage(self._untyped_storage.cpu()) + + def is_pinned(self, device: Union[str, core.device] = "cuda"): + r"""Determine whether the CPU TypedStorage is already pinned on device. + + Args: + device (str or core.device): The device to pin memory on (default: ``'cuda'``). + This argument is discouraged and subject to deprecated. + + Returns: + A boolean variable. + """ + _warn_typed_storage_removal() + return self._untyped_storage.is_pinned(device) + + def pin_memory(self, device: Union[str, core.device] = "cuda"): + r"""Copy the CPU TypedStorage to pinned memory, if it's not already pinned. + + Args: + device (str or core.device): The device to pin memory on (default: ``'cuda'``). + This argument is discouraged and subject to deprecated. + + Returns: + A pinned CPU storage. + """ + _warn_typed_storage_removal() + return self._new_wrapped_storage( + self._untyped_storage.pin_memory(device=device) + ) + + def share_memory_(self): + """See :meth:`core.UntypedStorage.share_memory_`""" + _warn_typed_storage_removal() + return self._share_memory_() + + # For internal use only, to avoid deprecation warning + def _share_memory_(self): + self._untyped_storage.share_memory_() + return self + + def _new_shared(self, size, *, device=None): + """Create a new storage in shared memory with the same data type.""" + if device is None: + device = "cpu" + device = core.device(device) + untyped_storage = core.UntypedStorage._new_shared( + size * self._element_size(), device=device + ) + return TypedStorage( + wrap_storage=untyped_storage, dtype=self.dtype, _internal=True + ) + + @property + def _cdata(self): + return self._untyped_storage._cdata + + @property + def device(self): + _warn_typed_storage_removal() + return self._untyped_storage.device + + def size(self): + _warn_typed_storage_removal() + return self._size() + + # For internal use only, to avoid deprecation warning + def _size(self): + # NB: don't indirect through __len__, as that requires + # an int to be returned + return self._untyped_storage.nbytes() // self._element_size() + + def pickle_storage_type(self): + _warn_typed_storage_removal() + return self._pickle_storage_type() + + # For internal use only, to avoid deprecation warning + def _pickle_storage_type(self): + try: + return _dtype_to_storage_type_map()[self.dtype] + except KeyError as e: + raise KeyError(f"dtype {self.dtype} is not recognized") from e + + def __reduce__(self): + b = io.BytesIO() + core.save(self, b, _use_new_zipfile_serialization=False) + return (_load_from_bytes, (b.getvalue(),)) + + def data_ptr(self): + _warn_typed_storage_removal() + return self._data_ptr() + + # For internal use only, to avoid deprecation warning + def _data_ptr(self): + return self._untyped_storage.data_ptr() + + def resizable(self): + _warn_typed_storage_removal() + return self._untyped_storage.resizable() + + def resize_(self, size): + _warn_typed_storage_removal() + self._resize_(size) + + # For internal use only, to avoid deprecation warning + def _resize_(self, size): + self._untyped_storage.resize_(size * self._element_size()) + + @classmethod + def _free_weak_ref(cls, *args, **kwargs): + return UntypedStorage._free_weak_ref(*args, **kwargs) + + def _weak_ref(self, *args, **kwargs): + return self._untyped_storage._weak_ref(*args, **kwargs) + + @classmethod + def from_buffer(cls, *args, **kwargs): + _warn_typed_storage_removal() + return cls._from_buffer(*args, **kwargs) + + @classmethod + def _from_buffer(cls, *args, dtype=None, device=None, **kwargs): + if cls == TypedStorage: + dtype = core.get_default_dtype() if dtype is None else dtype + device = core.device("cpu" if device is None else device) + if device.type != "cpu": + raise RuntimeError( + f"TypedStorage.from_buffer: Not available for device {device.type}" + ) + untyped_storage: core.UntypedStorage = core.UntypedStorage.from_buffer( + *args, dtype=dtype, **kwargs + ) + + else: + if dtype is not None or len(args) == 5: + raise RuntimeError( + "from_buffer: 'dtype' can only be specified in " + "UntypedStorage.from_buffer and TypedStorage.from_buffer" + ) + if device is not None: + raise RuntimeError( + "from_buffer: 'device' can only be specified in " + "UntypedStorage.from_buffer and TypedStorage.from_buffer" + ) + + dtype = cls._dtype + untyped_storage = core.UntypedStorage.from_buffer( + *args, dtype=dtype, **kwargs + ) + + return TypedStorage(wrap_storage=untyped_storage, dtype=dtype, _internal=True) + + def _to(self, dtype): + if not isinstance(dtype, core.dtype): + raise TypeError(f"Argument 'dtype' must be core.dtype, not {type(dtype)}") + storage = ( + core.tensor([], dtype=self.dtype, device=self.device) + .set_(self) + .to(dtype) + ._typed_storage() + ) + if storage.data_ptr() == self.data_ptr(): + storage = storage.clone() + return storage + + def double(self): + """Casts this storage to double type.""" + _warn_typed_storage_removal() + return self._to(core.double) + + def float(self): + """Casts this storage to float type.""" + _warn_typed_storage_removal() + return self._to(core.float) + + def half(self): + """Casts this storage to half type.""" + _warn_typed_storage_removal() + return self._to(core.half) + + def long(self): + """Casts this storage to long type.""" + _warn_typed_storage_removal() + return self._to(core.long) + + def int(self): + """Casts this storage to int type.""" + _warn_typed_storage_removal() + return self._to(core.int) + + def short(self): + """Casts this storage to short type.""" + _warn_typed_storage_removal() + return self._to(core.short) + + def char(self): + """Casts this storage to char type.""" + _warn_typed_storage_removal() + return self._to(core.int8) + + def byte(self): + """Casts this storage to byte type.""" + _warn_typed_storage_removal() + return self._to(core.uint8) + + def bool(self): + """Casts this storage to bool type.""" + _warn_typed_storage_removal() + return self._to(core.bool) + + def bfloat16(self): + """Casts this storage to bfloat16 type.""" + _warn_typed_storage_removal() + return self._to(core.bfloat16) + + def complex_double(self): + """Casts this storage to complex double type.""" + _warn_typed_storage_removal() + return self._to(core.cdouble) + + def complex_float(self): + """Casts this storage to complex float type.""" + _warn_typed_storage_removal() + return self._to(core.cfloat) + + def float8_e5m2(self): + """Casts this storage to float8_e5m2 type""" + _warn_typed_storage_removal() + return self._to(core.float8_e5m2) + + def float8_e4m3fn(self): + """Casts this storage to float8_e4m3fn type""" + _warn_typed_storage_removal() + return self._to(core.float8_e4m3fn) + + def float8_e5m2fnuz(self): + """Casts this storage to float8_e5m2fnuz type""" + _warn_typed_storage_removal() + return self._to(core.float8_e5m2fnuz) + + def float8_e4m3fnuz(self): + """Casts this storage to float8_e4m3fnuz type""" + _warn_typed_storage_removal() + return self._to(core.float8_e4m3fnuz) + + @classmethod + def from_file(cls, filename, shared, size): + """from_file(filename, shared=False, size=0) -> Storage + + Creates a CPU storage backed by a memory-mapped file. + + If ``shared`` is ``True``, then memory is shared between all processes. + All changes are written to the file. If ``shared`` is ``False``, then the changes on + the storage do not affect the file. + + ``size`` is the number of elements in the storage. If ``shared`` is ``False``, + then the file must contain at least ``size * sizeof(Type)`` bytes + (``Type`` is the type of storage). If ``shared`` is ``True`` the file will be created if needed. + + Args: + filename (str): file name to map + shared (bool): whether to share memory (whether ``MAP_SHARED`` or ``MAP_PRIVATE`` is passed to the + underlying `mmap(2) call `_) + size (int): number of elements in the storage + """ + _warn_typed_storage_removal() + if cls == TypedStorage: + raise RuntimeError("from_file can only be called on derived classes") + untyped_storage = UntypedStorage.from_file( + filename, shared, size * core._utils._element_size(cls.dtype) + ) + storage = cls(wrap_storage=untyped_storage) + return storage + + @classmethod + def _expired(cls, *args, **kwargs): + return UntypedStorage._expired(*args, **kwargs) + + def _write_file(self, *args, **kwargs): + return self._untyped_storage._write_file(*args, **kwargs) + + def _set_from_file(self, *args, **kwargs): + return self._untyped_storage._set_from_file(*args, **kwargs) + + def _set_cdata(self, *args, **kwargs): + return self._untyped_storage._set_cdata(*args, **kwargs) + + def _share_cuda_(self, *args, **kwargs): + return self._untyped_storage._share_cuda_(*args, **kwargs) + + def is_shared(self): + _warn_typed_storage_removal() + return self._is_shared() + + # For internal use only, to avoid deprecation warning + def _is_shared(self): + return self._untyped_storage.is_shared() + + @classmethod + def _new_shared_cuda(cls, *args, **kwargs): + return core.UntypedStorage._new_shared_cuda(*args, **kwargs) + + def _share_filename_cpu_(self, *args, **kwargs): + ( + manager_handle, + storage_handle, + size, + ) = self._untyped_storage._share_filename_cpu_(*args, **kwargs) + return manager_handle, storage_handle, size // self._element_size() + + def _shared_decref(self): + self._untyped_storage._shared_decref() + return self + + @classmethod + def _release_ipc_counter(cls, *args, device=None, **kwargs): + return core.UntypedStorage._release_ipc_counter_cuda(*args, **kwargs) + + def _shared_incref(self, *args, **kwargs): + return self._untyped_storage._shared_incref(*args, **kwargs) + + def _share_fd_cpu_(self, *args, **kwargs): + fd, size = self._untyped_storage._share_fd_cpu_(*args, **kwargs) + return fd, size // self._element_size() + + def _get_legacy_storage_class(self): + if self.dtype not in _dtype_to_storage_type_map(): + return None + + storage_name = _dtype_to_storage_type_map()[self.dtype] + + if self.device.type not in [ + "cpu", + "cuda", + "hpu", + core._C._get_privateuse1_backend_name(), + ]: + return None + + module = ( + torch if self.device.type == "cpu" else getattr(torch, self.device.type) + ) + + try: + return getattr(module, storage_name) + except AttributeError: + return None + + +TypedStorage.type.__doc__ = _type.__doc__ +TypedStorage.cuda.__doc__ = _StorageBase.cuda.__doc__ +TypedStorage.hpu.__doc__ = _StorageBase.hpu.__doc__ +TypedStorage.to.__doc__ = _to.__doc__ + + +class _LegacyStorageMeta(type): + dtype: core.dtype + + def __instancecheck__(cls, instance): + if type(instance) == TypedStorage: + cls_device = _get_device_from_module(cls.__module__) + return (cls_device == instance.device.type) and ( + cls.dtype == instance.dtype + ) + return False + + +class _LegacyStorage(TypedStorage, metaclass=_LegacyStorageMeta): + @classmethod + def _new_shared(cls, size): + """Create a new storage in shared memory with the same data type.""" + untyped_storage = core.UntypedStorage._new_shared(size * cls()._element_size()) + return cls(wrap_storage=untyped_storage) + + @classmethod + def _release_ipc_counter(cls, *args, **kwargs): + return core.UntypedStorage._release_ipc_counter_cuda(*args, **kwargs) + + @classmethod + def _new_shared_filename(cls, manager, obj, size): + bytes_size = size * core._utils._element_size(cls.dtype) + return cls( + wrap_storage=core.UntypedStorage._new_shared_filename_cpu( + manager, obj, bytes_size + ) + ) + + +def _get_dtype_from_pickle_storage_type(pickle_storage_type: str): + try: + return _storage_type_to_dtype_map()[pickle_storage_type] + except KeyError as e: + raise KeyError( + f'pickle storage type "{pickle_storage_type}" is not recognized' + ) from e \ No newline at end of file diff --git a/mindnlp/core/testing/__init__.py b/mindnlp/core/testing/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/mindnlp/core/testing/_internal/__init__.py b/mindnlp/core/testing/_internal/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/mindnlp/core/types.py b/mindnlp/core/types.py new file mode 100644 index 000000000..2ab5afc11 --- /dev/null +++ b/mindnlp/core/types.py @@ -0,0 +1,116 @@ +from builtins import ( # noqa: F401 + bool as _bool, + bytes as _bytes, + complex as _complex, + float as _float, + int as _int, + str as _str, +) +from typing import Any, IO, TYPE_CHECKING, Union, Dict +from typing_extensions import Self, TypeAlias + +from ._dtype import dtype + +class device(): + def __init__(self, type=None, index=None): + if type is not None: + if isinstance(type, str): + if ':' in type: + if index is not None: + raise ValueError("`type` must not include an index because index was " + f"passed explicitly: {type}") + _target, _id = type.split(':') + _id = int(_id) + else: + _target = type + _id = None if _target == 'cpu' else 0 + elif isinstance(type, device): + if index is not None: + raise ValueError("core.device(): When input is core.device, `index` can not be set.") + _target = type.type + _id = type.index + else: + raise TypeError("core.device(): `type` must be type of 'str' or 'core.device'.") + else: + raise ValueError("core.device(): `type` can not be None") + + self.type = _target + self.index = _id + + def __repr__(self): + if self.index is None: + return f"device(type={self.type})" + return f"device(type={self.type}, index={self.index})" + + def __eq__(self, __value): + if not isinstance(__value, device): + return False + return hash(self) == hash(__value) + + def __hash__(self): + return hash(self.type) ^ hash(self.index) + +# Meta-type for "numeric" things; matches our docs +Number: TypeAlias = Union[int, float, bool] +# tuple for isinstance(x, Number) checks. +# FIXME: refactor once python 3.9 support is dropped. +_Number = (int, float, bool) + +# Storage protocol implemented by ${Type}StorageBase classes +class Storage: + _cdata: int + device: device + dtype: dtype + _torch_load_uninitialized: bool + + def __deepcopy__(self, memo: Dict[int, Any]) -> "Storage": + raise NotImplementedError + + def _new_shared(self, size: int) -> "Storage": + raise NotImplementedError + + def _write_file( + self, + f: Any, + is_real_file: bool, + save_size: bool, + element_size: int, + ) -> None: + raise NotImplementedError + + def element_size(self) -> int: + raise NotImplementedError + + def is_shared(self) -> bool: + raise NotImplementedError + + def share_memory_(self) -> "Storage": + raise NotImplementedError + + def nbytes(self) -> int: + raise NotImplementedError + + def cpu(self) -> "Storage": + raise NotImplementedError + + def data_ptr(self) -> int: + raise NotImplementedError + + def from_file( + self, + filename: str, + shared: bool = False, + nbytes: int = 0, + ) -> "Storage": + raise NotImplementedError + + def _new_with_file( + self, + f: Any, + element_size: int, + ) -> "Storage": + raise NotImplementedError + +_device = device +_dtype = dtype +_size = tuple \ No newline at end of file diff --git a/mindnlp/core/utils/__init__.py b/mindnlp/core/utils/__init__.py new file mode 100644 index 000000000..c0f240662 --- /dev/null +++ b/mindnlp/core/utils/__init__.py @@ -0,0 +1,2 @@ +"""utils""" +from . import data \ No newline at end of file diff --git a/mindnlp/core/utils/_contextlib.py b/mindnlp/core/utils/_contextlib.py new file mode 100644 index 000000000..dc794a587 --- /dev/null +++ b/mindnlp/core/utils/_contextlib.py @@ -0,0 +1,158 @@ +"""context lib""" +# mypy: allow-untyped-defs +# Extra utilities for working with context managers that should have been +# in the standard library but are not + +import functools +import inspect +import warnings +import sys +from typing import Any, Callable, TypeVar, cast + +# Used for annotating the decorator usage of _DecoratorContextManager (e.g., +# 'no_grad' and 'enable_grad'). +# See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators +FuncType = Callable[..., Any] +F = TypeVar('F', bound=FuncType) + + +def _wrap_generator(ctx_factory, func): + """ + Wrap each generator invocation with the context manager factory. + + The input should be a function that returns a context manager, + not a context manager itself, to handle one-shot context managers. + """ + @functools.wraps(func) + def generator_context(*args, **kwargs): + gen = func(*args, **kwargs) + + # Generators are suspended and unsuspended at `yield`, hence we + # make sure the grad mode is properly set every time the execution + # flow returns into the wrapped generator and restored when it + # returns through our `yield` to our caller (see PR #49017). + try: + # Issuing `None` to a generator fires it up + with ctx_factory(): + response = gen.send(None) + + while True: + try: + # Forward the response to our caller and get its next request + request = yield response + + except GeneratorExit: + # Inform the still active generator about its imminent closure + with ctx_factory(): + gen.close() + raise + + except BaseException: + # Propagate the exception thrown at us by the caller + with ctx_factory(): + response = gen.throw(*sys.exc_info()) + + else: + # Pass the last request to the generator and get its response + with ctx_factory(): + response = gen.send(request) + + # We let the exceptions raised above by the generator's `.throw` or + # `.send` methods bubble up to our caller, except for StopIteration + except StopIteration as e: + # The generator informed us that it is done: take whatever its + # returned value (if any) was and indicate that we're done too + # by returning it (see docs for python's return-statement). + return e.value + + return generator_context + + +def context_decorator(ctx, func): + """ + Like contextlib.ContextDecorator. + + But with the following differences: + 1. Is done by wrapping, rather than inheritance, so it works with context + managers that are implemented from C and thus cannot easily inherit from + Python classes + 2. Wraps generators in the intuitive way (c.f. https://bugs.python.org/issue37743) + 3. Errors out if you try to wrap a class, because it is ambiguous whether + or not you intended to wrap only the constructor + + The input argument can either be a context manager (in which case it must + be a multi-shot context manager that can be directly invoked multiple times) + or a callable that produces a context manager. + """ + assert not (callable(ctx) and hasattr(ctx, '__enter__')), ( + f"Passed in {ctx} is both callable and also a valid context manager " + "(has __enter__), making it ambiguous which interface to use. If you " + "intended to pass a context manager factory, rewrite your call as " + "context_decorator(lambda: ctx()); if you intended to pass a context " + "manager directly, rewrite your call as context_decorator(lambda: ctx)" + ) + + if not callable(ctx): + def ctx_factory(): + return ctx + else: + ctx_factory = ctx + + if inspect.isclass(func): + raise RuntimeError( + "Cannot decorate classes; it is ambiguous whether or not only the " + "constructor or all methods should have the context manager applied; " + "additionally, decorating a class at definition-site will prevent " + "use of the identifier as a conventional type. " + "To specify which methods to decorate, decorate each of them " + "individually." + ) + + if inspect.isgeneratorfunction(func): + return _wrap_generator(ctx_factory, func) + + @functools.wraps(func) + def decorate_context(*args, **kwargs): + with ctx_factory(): + return func(*args, **kwargs) + + return decorate_context + + +class _DecoratorContextManager: + """Allow a context manager to be used as a decorator.""" + + def __call__(self, orig_func: F) -> F: + if inspect.isclass(orig_func): + warnings.warn( + "Decorating classes is deprecated and will be disabled in " + "future versions. You should only decorate functions or methods. " + "To preserve the current behavior of class decoration, you can " + "directly decorate the `__init__` method and nothing else.", + FutureWarning, + stacklevel=2, + ) + func = cast(F, lambda *args, **kwargs: orig_func(*args, **kwargs)) + else: + func = orig_func + + return cast(F, context_decorator(self.clone, func)) + + def __enter__(self) -> None: + raise NotImplementedError + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + raise NotImplementedError + + def clone(self): + # override this method if your children class takes __init__ parameters + return self.__class__() + + +class _NoParamDecoratorContextManager(_DecoratorContextManager): + """Allow a context manager to be used as a decorator without parentheses.""" + + def __new__(cls, orig_func=None): + if orig_func is None: + return super().__new__(cls) + return cls()(orig_func) diff --git a/mindnlp/core/utils/_pytree.py b/mindnlp/core/utils/_pytree.py new file mode 100644 index 000000000..2cc6551fe --- /dev/null +++ b/mindnlp/core/utils/_pytree.py @@ -0,0 +1,1620 @@ +""" +Contains utility functions for working with nested python data structures. + +A *pytree* is Python nested data structure. It is a tree in the sense that +nodes are Python collections (e.g., list, tuple, dict) and the leaves are +Python values. Furthermore, a pytree should not contain reference cycles. + +pytrees are useful for working with nested collections of Tensors. For example, +one can use `tree_map` to map a function over all Tensors inside some nested +collection of Tensors and `tree_leaves` to get a flat list of all Tensors +inside some nested collection. pytrees are helpful for implementing nested +collection support for PyTorch APIs. + +This pytree implementation is not very performant due to Python overhead +To improve the performance we can move parts of the implementation to C++. +""" + +import dataclasses +import functools +import importlib +import json +import sys +import threading +import types +import warnings +from collections import defaultdict, deque, namedtuple, OrderedDict +from typing import ( + Any, + Callable, + cast, + DefaultDict, + Deque, + Dict, + FrozenSet, + Generic, + Hashable, + Iterable, + List, + Mapping, + NamedTuple, + Optional, + OrderedDict as GenericOrderedDict, + overload, + Protocol, + Sequence, + Tuple, + Type, + TypeVar, + Union, +) +from typing_extensions import deprecated + + +__all__ = [ + "PyTree", + "Context", + "FlattenFunc", + "UnflattenFunc", + "DumpableContext", + "ToDumpableContextFn", + "FromDumpableContextFn", + "TreeSpec", + "LeafSpec", + "keystr", + "key_get", + "register_pytree_node", + "tree_flatten", + "tree_flatten_with_path", + "tree_unflatten", + "tree_iter", + "tree_leaves", + "tree_leaves_with_path", + "tree_structure", + "tree_map", + "tree_map_with_path", + "tree_map_", + "tree_map_only", + "tree_map_only_", + "tree_all", + "tree_any", + "tree_all_only", + "tree_any_only", + "treespec_dumps", + "treespec_loads", + "treespec_pprint", +] + + +T = TypeVar("T") +S = TypeVar("S") +U = TypeVar("U") +R = TypeVar("R") + + +DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL = 1 +NO_SERIALIZED_TYPE_NAME_FOUND = "NO_SERIALIZED_TYPE_NAME_FOUND" + + +class KeyEntry(Protocol): + def __hash__(self) -> int: + ... + + def __eq__(self, other: object) -> bool: + ... + + def __str__(self) -> str: + ... + + def get(self, parent: Any) -> Any: + ... + + +Context = Any +PyTree = Any +FlattenFunc = Callable[[PyTree], Tuple[List[Any], Context]] +UnflattenFunc = Callable[[Iterable[Any], Context], PyTree] +DumpableContext = Any # Any json dumpable text +ToDumpableContextFn = Callable[[Context], DumpableContext] +FromDumpableContextFn = Callable[[DumpableContext], Context] +ToStrFunc = Callable[["TreeSpec", List[str]], str] +MaybeFromStrFunc = Callable[[str], Optional[Tuple[Any, Context, str]]] +KeyPath = Tuple[KeyEntry, ...] +FlattenWithKeysFunc = Callable[[PyTree], Tuple[List[Tuple[KeyEntry, Any]], Any]] + + +# A NodeDef holds two callables: +# - flatten_fn should take the collection and return a flat list of values. +# It can also return some context that is used in reconstructing the +# collection. +# - unflatten_fn should take a flat list of values and some context +# (returned by flatten_fn). It returns the collection by reconstructing +# it from the list and the context. +# - flatten_with_keys_fn, which is a callable that takes a +# pytree and returns a list of (keypath, value) pairs and a context. +class NodeDef(NamedTuple): + type: Type[Any] + flatten_fn: FlattenFunc + unflatten_fn: UnflattenFunc + flatten_with_keys_fn: Optional[FlattenWithKeysFunc] + + +_NODE_REGISTRY_LOCK = threading.Lock() +SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {} + + +# _SerializeNodeDef holds the following: +# - typ: the type of the node (e.g., "Dict", "List", etc) +# - serialized_type_name: the fully qualified name of the type, e.g. "collections.OrderedDict" +# - to_dumpable_context takes a TreeSpec, and returns a serialized string format of the +# context, and the version number +# - from_dumpable_context takes in a string representation of the context, and the +# version, and returns the deserialized context +class _SerializeNodeDef(NamedTuple): + typ: Type[Any] + serialized_type_name: str + to_dumpable_context: Optional[ToDumpableContextFn] + from_dumpable_context: Optional[FromDumpableContextFn] + + +SUPPORTED_SERIALIZED_TYPES: Dict[Type[Any], _SerializeNodeDef] = {} +SERIALIZED_TYPE_TO_PYTHON_TYPE: Dict[str, Type[Any]] = {} + +# NB: we try really hard to not import _cxx_pytree (which depends on optree) +# as much as possible. This is for isolation: a user who is not using C++ pytree +# shouldn't pay for it, and it helps makes things like cpython upgrades easier. +_cxx_pytree_exists = importlib.util.find_spec("optree") # type: ignore[attr-defined] +_cxx_pytree_imported = False +_cxx_pytree_pending_imports: List[Any] = [] + + +def register_pytree_node( + cls: Type[Any], + flatten_fn: FlattenFunc, + unflatten_fn: UnflattenFunc, + *, + serialized_type_name: Optional[str] = None, + to_dumpable_context: Optional[ToDumpableContextFn] = None, + from_dumpable_context: Optional[FromDumpableContextFn] = None, + flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None, +) -> None: + """Register a container-like type as pytree node. + + Args: + cls: the type to register + flatten_fn: A callable that takes a pytree and returns a flattened + representation of the pytree and additional context to represent the + flattened pytree. + unflatten_fn: A callable that takes a flattened version of the pytree, + additional context, and returns an unflattened pytree. + serialized_type_name: A keyword argument used to specify the fully qualified + name used when serializing the tree spec. + to_dumpable_context: An optional keyword argument to custom specify how + to convert the context of the pytree to a custom json dumpable + representation. This is used for json serialization, which is being + used in core.export right now. + from_dumpable_context: An optional keyword argument to custom specify how + to convert the custom json dumpable representation of the context + back to the original context. This is used for json deserialization, + which is being used in core.export right now. + flatten_with_keys_fn: An optional keyword argument to specify how to + access each pytree leaf's keypath when flattening and tree-mapping. + Like ``flatten_fn``, but in place of a List[leaf], it should return + a List[(keypath, leaf)]. + """ + with _NODE_REGISTRY_LOCK: + if cls in SUPPORTED_NODES: + raise ValueError(f"{cls} is already registered as pytree node.") + + _private_register_pytree_node( + cls, + flatten_fn, + unflatten_fn, + serialized_type_name=serialized_type_name, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, + flatten_with_keys_fn=flatten_with_keys_fn, + ) + + if not _cxx_pytree_exists: + return + + if _cxx_pytree_imported: + from . import _cxx_pytree as cxx + + cxx._private_register_pytree_node( + cls, + flatten_fn, + unflatten_fn, + serialized_type_name=serialized_type_name, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, + ) + else: + args = (cls, flatten_fn, unflatten_fn) + kwargs = { + "serialized_type_name": serialized_type_name, + "to_dumpable_context": to_dumpable_context, + "from_dumpable_context": from_dumpable_context, + } + _cxx_pytree_pending_imports.append((args, kwargs)) + + +def _register_namedtuple( + cls: Type[Any], + *, + serialized_type_name: str, +) -> None: + """ + Registers a namedtuple as a valid pytree node. By default namedtuples are + valid pytree nodes, but they are not serializable. This API provides the + argument `serialized_type_name` which allows these namedtuples to be + serialized. + + Args: + cls: the dataclass type to register + serialized_type_name: The serialized name for the dataclass. This is + required if you want to serialize the pytree TreeSpec containing this + namedtuple. + """ + _private_register_pytree_node( + cls, + _namedtuple_flatten, + _namedtuple_unflatten, + serialized_type_name=serialized_type_name, + to_dumpable_context=_namedtuple_serialize, + from_dumpable_context=_namedtuple_deserialize, + flatten_with_keys_fn=_namedtuple_flatten_with_keys, + ) + + +@deprecated( + "`core.utils._pytree._register_pytree_node` is deprecated. " + "Please use `core.utils._pytree.register_pytree_node` instead.", + category=FutureWarning, +) +def _register_pytree_node( + cls: Type[Any], + flatten_fn: FlattenFunc, + unflatten_fn: UnflattenFunc, + to_str_fn: Optional[ToStrFunc] = None, # deprecated + maybe_from_str_fn: Optional[MaybeFromStrFunc] = None, # deprecated + *, + serialized_type_name: Optional[str] = None, + to_dumpable_context: Optional[ToDumpableContextFn] = None, + from_dumpable_context: Optional[FromDumpableContextFn] = None, + flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None, +) -> None: + """Register a container-like type as pytree node for the Python pytree only. + + Args: + cls: the type to register + flatten_fn: A callable that takes a pytree and returns a flattened + representation of the pytree and additional context to represent the + flattened pytree. + unflatten_fn: A callable that takes a flattened version of the pytree, + additional context, and returns an unflattened pytree. + serialized_type_name: A keyword argument used to specify the fully qualified + name used when serializing the tree spec. + to_dumpable_context: An optional keyword argument to custom specify how + to convert the context of the pytree to a custom json dumpable + representation. This is used for json serialization, which is being + used in core.export right now. + from_dumpable_context: An optional keyword argument to custom specify how + to convert the custom json dumpable representation of the context + back to the original context. This is used for json deserialization, + which is being used in core.export right now. + flatten_with_keys_fn: An optional keyword argument to specify how to + access each pytree leaf's keypath when flattening and tree-mapping. + Like ``flatten_fn``, but in place of a List[leaf], it should return + a List[(keypath, leaf)]. + """ + if to_str_fn is not None or maybe_from_str_fn is not None: + warnings.warn( + "`to_str_fn` and `maybe_from_str_fn` is deprecated. " + "Please use `to_dumpable_context` and `from_dumpable_context` instead.", + FutureWarning, + stacklevel=2, + ) + + _private_register_pytree_node( + cls, + flatten_fn, + unflatten_fn, + serialized_type_name=serialized_type_name, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, + flatten_with_keys_fn=flatten_with_keys_fn, + ) + + +def _private_register_pytree_node( + cls: Type[Any], + flatten_fn: FlattenFunc, + unflatten_fn: UnflattenFunc, + *, + serialized_type_name: Optional[str] = None, + to_dumpable_context: Optional[ToDumpableContextFn] = None, + from_dumpable_context: Optional[FromDumpableContextFn] = None, + flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None, +) -> None: + """This is an internal function that is used to register a pytree node type + for the Python pytree only. End-users should use :func:`register_pytree_node` + instead. + """ + with _NODE_REGISTRY_LOCK: + if cls in SUPPORTED_NODES: + # TODO: change this warning to an error after OSS/internal stabilize + warnings.warn( + f"{cls} is already registered as pytree node. " + "Overwriting the previous registration.", + ) + + node_def = NodeDef(cls, flatten_fn, unflatten_fn, flatten_with_keys_fn) + SUPPORTED_NODES[cls] = node_def + + if (to_dumpable_context is None) ^ (from_dumpable_context is None): + raise ValueError( + f"Both to_dumpable_context and from_dumpable_context for {cls} must " + "be None or registered." + ) + + if serialized_type_name is None: + serialized_type_name = NO_SERIALIZED_TYPE_NAME_FOUND + + serialize_node_def = _SerializeNodeDef( + cls, + serialized_type_name, + to_dumpable_context, + from_dumpable_context, + ) + SUPPORTED_SERIALIZED_TYPES[cls] = serialize_node_def + SERIALIZED_TYPE_TO_PYTHON_TYPE[serialized_type_name] = cls + + +@dataclasses.dataclass(frozen=True) +class SequenceKey(Generic[T]): + idx: int + + def __str__(self) -> str: + return f"[{self.idx!r}]" + + def get(self, sequence: Sequence[T]) -> T: + return sequence[self.idx] + + +K = TypeVar("K", bound=Hashable) + + +@dataclasses.dataclass(frozen=True) +class MappingKey(Generic[K, T]): + key: K + + def __str__(self) -> str: + return f"[{self.key!r}]" + + def get(self, mapping: Mapping[K, T]) -> T: + return mapping[self.key] + + +@dataclasses.dataclass(frozen=True) +class GetAttrKey: + name: str + + def __str__(self) -> str: + return f".{self.name}" + + def get(self, obj: Any) -> Any: + return getattr(obj, self.name) + + +def _tuple_flatten(d: Tuple[Any, ...]) -> Tuple[List[Any], Context]: + return list(d), None + + +def _tuple_flatten_with_keys( + d: Tuple[Any, ...] +) -> Tuple[List[Tuple[KeyEntry, Any]], Context]: + values, context = _tuple_flatten(d) + return [(SequenceKey(i), v) for i, v in enumerate(values)], context + + +def _tuple_unflatten(values: Iterable[Any], context: Context) -> Tuple[Any, ...]: + return tuple(values) + + +def _list_flatten(d: List[Any]) -> Tuple[List[Any], Context]: + return d, None + + +def _list_flatten_with_keys(d: List[Any]) -> Tuple[List[Tuple[KeyEntry, Any]], Context]: + values, context = _list_flatten(d) + return [(SequenceKey(i), v) for i, v in enumerate(values)], context + + +def _list_unflatten(values: Iterable[Any], context: Context) -> List[Any]: + return list(values) + + +def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]: + return list(d.values()), list(d.keys()) + + +def _dict_flatten_with_keys( + d: Dict[Any, Any] +) -> Tuple[List[Tuple[KeyEntry, Any]], Context]: + values, context = _dict_flatten(d) + return [(MappingKey(k), v) for k, v in zip(context, values)], context + + +def _dict_unflatten(values: Iterable[Any], context: Context) -> Dict[Any, Any]: + return dict(zip(context, values)) + + +def _namedtuple_flatten(d: NamedTuple) -> Tuple[List[Any], Context]: + return list(d), type(d) + + +def _namedtuple_flatten_with_keys( + d: NamedTuple, +) -> Tuple[List[Tuple[KeyEntry, Any]], Context]: + values, context = _namedtuple_flatten(d) + return ( + [(GetAttrKey(field), v) for field, v in zip(context._fields, values)], + context, + ) + + +def _namedtuple_unflatten(values: Iterable[Any], context: Context) -> NamedTuple: + return cast(NamedTuple, context(*values)) + + +def _namedtuple_serialize(context: Context) -> DumpableContext: + if context not in SUPPORTED_SERIALIZED_TYPES: + raise NotImplementedError( + f"Can't serialize TreeSpec of namedtuple class {context} because we " + "didn't register a serializated_type_name. Please register using " + "`_register_namedtuple`." + ) + + serialize_node_def = SUPPORTED_SERIALIZED_TYPES[context] + serialized_type_name = serialize_node_def.serialized_type_name + + if serialized_type_name == NO_SERIALIZED_TYPE_NAME_FOUND: + raise NotImplementedError( + f"Can't serialize TreeSpec of namedtuple class {context} because we " + "couldn't find a serializated_type_name. Please register using " + "`_register_namedtuple`." + ) + return serialized_type_name + + +def _namedtuple_deserialize(dumpable_context: DumpableContext) -> Context: + if dumpable_context not in SERIALIZED_TYPE_TO_PYTHON_TYPE: + raise NotImplementedError( + f"Can't deserialize TreeSpec of namedtuple class {dumpable_context} " + "because we couldn't find a serializated name." + ) + + typ = SERIALIZED_TYPE_TO_PYTHON_TYPE[dumpable_context] + return typ + + +def _ordereddict_flatten(d: GenericOrderedDict[Any, Any]) -> Tuple[List[Any], Context]: + return list(d.values()), list(d.keys()) + + +def _ordereddict_flatten_with_keys( + d: GenericOrderedDict[Any, Any] +) -> Tuple[List[Tuple[KeyEntry, Any]], Context]: + values, context = _ordereddict_flatten(d) + return [(MappingKey(k), v) for k, v in zip(context, values)], context + + +def _ordereddict_unflatten( + values: Iterable[Any], + context: Context, +) -> GenericOrderedDict[Any, Any]: + return OrderedDict((key, value) for key, value in zip(context, values)) + + +_odict_flatten = _ordereddict_flatten +_odict_unflatten = _ordereddict_unflatten + + +def _defaultdict_flatten(d: DefaultDict[Any, Any]) -> Tuple[List[Any], Context]: + values, dict_context = _dict_flatten(d) + return values, [d.default_factory, dict_context] + + +def _defaultdict_flatten_with_keys( + d: DefaultDict[Any, Any] +) -> Tuple[List[Tuple[KeyEntry, Any]], Context]: + values, context = _defaultdict_flatten(d) + _, dict_context = context + return [(MappingKey(k), v) for k, v in zip(dict_context, values)], context + + +def _defaultdict_unflatten( + values: Iterable[Any], + context: Context, +) -> DefaultDict[Any, Any]: + default_factory, dict_context = context + return defaultdict(default_factory, _dict_unflatten(values, dict_context)) + + +def _defaultdict_serialize(context: Context) -> DumpableContext: + default_factory, dict_context = context + json_defaultdict = { + "default_factory_module": default_factory.__module__, + "default_factory_name": default_factory.__qualname__, + "dict_context": dict_context, + } + return json_defaultdict + + +def _defaultdict_deserialize(dumpable_context: DumpableContext) -> Context: + assert isinstance(dumpable_context, dict) + assert set(dumpable_context) == { + "default_factory_module", + "default_factory_name", + "dict_context", + } + + default_factory_module = dumpable_context["default_factory_module"] + default_factory_name = dumpable_context["default_factory_name"] + assert isinstance(default_factory_module, str) + assert isinstance(default_factory_name, str) + module = importlib.import_module(default_factory_module) + default_factory = getattr(module, default_factory_name) + + dict_context = dumpable_context["dict_context"] + return [default_factory, dict_context] + + +def _deque_flatten(d: Deque[Any]) -> Tuple[List[Any], Context]: + return list(d), d.maxlen + + +def _deque_flatten_with_keys( + d: Deque[Any], +) -> Tuple[List[Tuple[KeyEntry, Any]], Context]: + values, context = _deque_flatten(d) + return [(SequenceKey(i), v) for i, v in enumerate(values)], context + + +def _deque_unflatten(values: Iterable[Any], context: Context) -> Deque[Any]: + return deque(values, maxlen=context) + + +_private_register_pytree_node( + tuple, + _tuple_flatten, + _tuple_unflatten, + serialized_type_name="builtins.tuple", + flatten_with_keys_fn=_tuple_flatten_with_keys, +) +_private_register_pytree_node( + list, + _list_flatten, + _list_unflatten, + serialized_type_name="builtins.list", + flatten_with_keys_fn=_list_flatten_with_keys, +) +_private_register_pytree_node( + dict, + _dict_flatten, + _dict_unflatten, + serialized_type_name="builtins.dict", + flatten_with_keys_fn=_dict_flatten_with_keys, +) +_private_register_pytree_node( + namedtuple, # type: ignore[arg-type] + _namedtuple_flatten, + _namedtuple_unflatten, + serialized_type_name="collections.namedtuple", + to_dumpable_context=_namedtuple_serialize, + from_dumpable_context=_namedtuple_deserialize, + flatten_with_keys_fn=_namedtuple_flatten_with_keys, +) +_private_register_pytree_node( + OrderedDict, + _ordereddict_flatten, + _ordereddict_unflatten, + serialized_type_name="collections.OrderedDict", + flatten_with_keys_fn=_ordereddict_flatten_with_keys, +) +_private_register_pytree_node( + defaultdict, + _defaultdict_flatten, + _defaultdict_unflatten, + serialized_type_name="collections.defaultdict", + to_dumpable_context=_defaultdict_serialize, + from_dumpable_context=_defaultdict_deserialize, + flatten_with_keys_fn=_defaultdict_flatten_with_keys, +) +_private_register_pytree_node( + deque, + _deque_flatten, + _deque_unflatten, + serialized_type_name="collections.deque", + flatten_with_keys_fn=_deque_flatten_with_keys, +) + + +STANDARD_DICT_TYPES: FrozenSet[type] = frozenset( + {dict, OrderedDict, defaultdict}, +) +BUILTIN_TYPES: FrozenSet[type] = frozenset( + {tuple, list, dict, namedtuple, OrderedDict, defaultdict, deque}, # type: ignore[arg-type] +) + + +# h/t https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple +def _is_namedtuple_instance(tree: Any) -> bool: + typ = type(tree) + bases = typ.__bases__ + if len(bases) != 1 or bases[0] != tuple: + return False + fields = getattr(typ, "_fields", None) + if not isinstance(fields, tuple): + return False + return all(type(entry) == str for entry in fields) + + +def _get_node_type(tree: Any) -> Any: + if _is_namedtuple_instance(tree): + return namedtuple + return type(tree) + + +# A leaf is defined as anything that is not a Node. +def _is_leaf(tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None) -> bool: + return (is_leaf is not None and is_leaf(tree)) or _get_node_type( + tree + ) not in SUPPORTED_NODES + + +# A TreeSpec represents the structure of a pytree. It holds: +# "type": the type of root Node of the pytree +# context: some context that is useful in unflattening the pytree +# children_specs: specs for each child of the root Node +# num_leaves: the number of leaves +@dataclasses.dataclass(init=True, frozen=True, eq=True, repr=False) +class TreeSpec: + type: Any + context: Context + children_specs: List["TreeSpec"] + + num_nodes: int = dataclasses.field(init=False) + num_leaves: int = dataclasses.field(init=False) + num_children: int = dataclasses.field(init=False) + + def __post_init__(self) -> None: + num_nodes = sum((spec.num_nodes for spec in self.children_specs), start=1) + num_leaves = sum(spec.num_leaves for spec in self.children_specs) + num_children = len(self.children_specs) + object.__setattr__(self, "num_nodes", num_nodes) + object.__setattr__(self, "num_leaves", num_leaves) + object.__setattr__(self, "num_children", num_children) + + def __repr__(self, indent: int = 0) -> str: + repr_prefix: str = f"TreeSpec({self.type.__name__}, {self.context}, [" + children_specs_str: str = "" + if self.num_children > 0: + indent += 2 + children_specs_str += self.children_specs[0].__repr__(indent) + children_specs_str += "," if self.num_children > 1 else "" + children_specs_str += ",".join( + [ + "\n" + " " * indent + child.__repr__(indent) + for child in self.children_specs[1:] + ] + ) + repr_suffix: str = f"{children_specs_str}])" + return repr_prefix + repr_suffix + + def is_leaf(self) -> bool: + return self.num_nodes == 1 and self.num_leaves == 1 + + def _flatten_up_to_helper(self, tree: PyTree, subtrees: List[PyTree]) -> None: + if self.is_leaf(): + subtrees.append(tree) + return + + node_type = _get_node_type(tree) + if self.type not in BUILTIN_TYPES: + # Always require custom node types to match exactly + if node_type != self.type: + raise ValueError( + f"Type mismatch; " + f"expected {self.type!r}, but got {node_type!r}.", + ) + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + child_pytrees, context = flatten_fn(tree) + if len(child_pytrees) != self.num_children: + raise ValueError( + f"Node arity mismatch; " + f"expected {self.num_children}, but got {len(child_pytrees)}.", + ) + if context != self.context: + raise ValueError( + f"Node context mismatch for custom node type {self.type!r}.", + ) + else: + # For builtin dictionary types, we allow some flexibility + # Otherwise, we require exact matches + both_standard_dict = ( + self.type in STANDARD_DICT_TYPES and node_type in STANDARD_DICT_TYPES + ) + if node_type != self.type and not both_standard_dict: + raise ValueError( + f"Node type mismatch; " + f"expected {self.type!r}, but got {node_type!r}.", + ) + if len(tree) != self.num_children: + raise ValueError( + f"Node arity mismatch; " + f"expected {self.num_children}, but got {len(tree)}.", + ) + + if both_standard_dict: # dictionary types are compatible with each other + dict_context = ( + self.context + if self.type is not defaultdict + # ignore mismatch of `default_factory` for defaultdict + else self.context[1] + ) + expected_keys = dict_context + got_key_set = set(tree) + expected_key_set = set(expected_keys) + if got_key_set != expected_key_set: + missing_keys = expected_key_set.difference(got_key_set) + extra_keys = got_key_set.difference(expected_key_set) + message = "" + if missing_keys: + message += f"; missing key(s): {missing_keys}" + if extra_keys: + message += f"; extra key(s): {extra_keys}" + raise ValueError(f"Node keys mismatch{message}.") + child_pytrees = [tree[key] for key in expected_keys] + else: + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + child_pytrees, context = flatten_fn(tree) + if ( + context != self.context + and self.type is not deque # ignore mismatch of `maxlen` for deque + ): + raise ValueError( + f"Node context mismatch for node type {self.type!r}; " + f"expected {self.context!r}, but got {context!r}.", # namedtuple type mismatch + ) + + for child_pytree, child_spec in zip(child_pytrees, self.children_specs): + child_spec._flatten_up_to_helper(child_pytree, subtrees) + + def flatten_up_to(self, tree: PyTree) -> List[PyTree]: + subtrees: List[PyTree] = [] + self._flatten_up_to_helper(tree, subtrees) + return subtrees + + def unflatten(self, leaves: Iterable[Any]) -> PyTree: + if not isinstance(leaves, (list, tuple)): + leaves = list(leaves) + if len(leaves) != self.num_leaves: + raise ValueError( + f"treespec.unflatten(leaves): `leaves` has length {len(leaves)} " + f"but the spec refers to a pytree that holds {self.num_leaves} " + f"items ({self}).", + ) + if self.is_leaf(): + return leaves[0] + + unflatten_fn = SUPPORTED_NODES[self.type].unflatten_fn + + # Recursively unflatten the children + start = 0 + end = 0 + child_pytrees = [] + for child_spec in self.children_specs: + end += child_spec.num_leaves + child_pytrees.append(child_spec.unflatten(leaves[start:end])) + start = end + + return unflatten_fn(child_pytrees, self.context) + + +class LeafSpec(TreeSpec): + def __init__(self) -> None: + super().__init__(None, None, []) + + def __post_init__(self) -> None: + object.__setattr__(self, "num_nodes", 1) + object.__setattr__(self, "num_leaves", 1) + object.__setattr__(self, "num_children", 0) + + def __repr__(self, indent: int = 0) -> str: + return "*" + + +# All leaves are equivalent, so represent with a single object to save on +# object construction time +_LEAF_SPEC = LeafSpec() + + +def _tree_flatten_helper( + tree: PyTree, + leaves: List[Any], + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> TreeSpec: + if _is_leaf(tree, is_leaf=is_leaf): + leaves.append(tree) + return _LEAF_SPEC + + node_type = _get_node_type(tree) + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + child_pytrees, context = flatten_fn(tree) + + # Recursively flatten the children + children_specs = [ + _tree_flatten_helper(child, leaves, is_leaf=is_leaf) for child in child_pytrees + ] + + return TreeSpec(node_type, context, children_specs) + + +def tree_flatten( + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> Tuple[List[Any], TreeSpec]: + """Flattens a pytree into a list of values and a TreeSpec that can be used + to reconstruct the pytree. + """ + leaves: List[Any] = [] + spec = _tree_flatten_helper(tree, leaves, is_leaf=is_leaf) + return leaves, spec + + +def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree: + """Given a list of values and a TreeSpec, builds a pytree. + This is the inverse operation of `tree_flatten`. + """ + if not isinstance(treespec, TreeSpec): + raise TypeError( + f"tree_unflatten(leaves, treespec): Expected `treespec` to be " + f"instance of TreeSpec but got item of type {type(treespec)}.", + ) + return treespec.unflatten(leaves) + + +def tree_iter( + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> Iterable[Any]: + """Get an iterator over the leaves of a pytree.""" + if _is_leaf(tree, is_leaf=is_leaf): + yield tree + else: + node_type = _get_node_type(tree) + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + child_pytrees, _ = flatten_fn(tree) + + # Recursively flatten the children + for child in child_pytrees: + yield from tree_iter(child, is_leaf=is_leaf) + + +def tree_leaves( + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> List[Any]: + """Get a list of leaves of a pytree.""" + return list(tree_iter(tree, is_leaf=is_leaf)) + + +def tree_structure( + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> TreeSpec: + """Get the TreeSpec for a pytree.""" + return tree_flatten(tree, is_leaf=is_leaf)[1] + + +def tree_map( + func: Callable[..., Any], + tree: PyTree, + *rests: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + """Map a multi-input function over pytree args to produce a new pytree. + + See also :func:`tree_map_`. + + >>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)}) + {'x': 8, 'y': (43, 65)} + >>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None}) + {'x': False, 'y': (False, False), 'z': True} + + If multiple inputs are given, the structure of the tree is taken from the first input; + subsequent inputs need only have ``tree`` as a prefix: + + >>> tree_map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]]) + [[5, 7, 9], [6, 1, 2]] + + Args: + func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the + corresponding leaves of the pytrees. + tree (pytree): A pytree to be mapped over, with each leaf providing the first positional + argument to function ``func``. + rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as + ``tree`` or has ``tree`` as a prefix. + is_leaf (callable, optional): An extra leaf predicate function that will be called at each + flattening step. The function should have a single argument with signature + ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated + as a leaf. Otherwise, the default pytree registry will be used to determine a node is a + leaf or not. If the function is not specified, the default pytree registry will be used. + + Returns: + A new pytree with the same structure as ``tree`` but with the value at each leaf given by + ``func(x, *xs)`` where ``x`` is the value at the corresponding leaf in ``tree`` and ``xs`` + is the tuple of values at corresponding nodes in ``rests``. + """ + leaves, treespec = tree_flatten(tree, is_leaf=is_leaf) + flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests] + return treespec.unflatten(map(func, *flat_args)) + + +def tree_map_( + func: Callable[..., Any], + tree: PyTree, + *rests: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + """Like :func:`tree_map`, but do an inplace call on each leaf and return the original tree. + + See also :func:`tree_map`. + + Args: + func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the + corresponding leaves of the pytrees. + tree (pytree): A pytree to be mapped over, with each leaf providing the first positional + argument to function ``func``. + rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as + ``tree`` or has ``tree`` as a prefix. + is_leaf (callable, optional): An extra leaf predicate function that will be called at each + flattening step. The function should have a single argument with signature + ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated + as a leaf. Otherwise, the default pytree registry will be used to determine a node is a + leaf or not. If the function is not specified, the default pytree registry will be used. + + Returns: + The original ``tree`` with the value at each leaf is given by the side-effect of function + ``func(x, *xs)`` (not the return value) where ``x`` is the value at the corresponding leaf + in ``tree`` and ``xs`` is the tuple of values at values at corresponding nodes in ``rests``. + """ + leaves, treespec = tree_flatten(tree, is_leaf=is_leaf) + flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests] + deque(map(func, *flat_args), maxlen=0) # consume and exhaust the iterable + return tree + + +Type2 = Tuple[Type[T], Type[S]] +Type3 = Tuple[Type[T], Type[S], Type[U]] +if sys.version_info >= (3, 10): + TypeAny = Union[Type[Any], Tuple[Type[Any], ...], types.UnionType] +else: + TypeAny = Union[Type[Any], Tuple[Type[Any], ...]] + +Fn2 = Callable[[Union[T, S]], R] +Fn3 = Callable[[Union[T, S, U]], R] +Fn = Callable[[T], R] +FnAny = Callable[[Any], R] + +MapOnlyFn = Callable[[T], Callable[[Any], Any]] + + +# These specializations help with type inference on the lambda passed to this +# function +@overload +def map_only(__type_or_types_or_pred: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]: + ... + + +@overload +def map_only(__type_or_types_or_pred: Type3[T, S, U]) -> MapOnlyFn[Fn3[T, S, U, Any]]: + ... + + +@overload +def map_only(__type_or_types_or_pred: Type[T]) -> MapOnlyFn[Fn[T, Any]]: + ... + + +# This specialization is needed for the implementations below that call +@overload +def map_only(__type_or_types_or_pred: TypeAny) -> MapOnlyFn[FnAny[Any]]: + ... + + +@overload +def map_only(__type_or_types_or_pred: Callable[[Any], bool]) -> MapOnlyFn[FnAny[Any]]: + ... + + +def map_only( + __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]] +) -> MapOnlyFn[FnAny[Any]]: + """ + Suppose you are writing a tree_map over tensors, leaving everything + else unchanged. Ordinarily you would have to write: + + def go(t): + if isinstance(t, Tensor): + return ... + else: + return t + + With this function, you only need to write: + + @map_only(Tensor) + def go(t): + return ... + + You can also directly use 'tree_map_only' + """ + if isinstance(__type_or_types_or_pred, (type, tuple)) or ( + sys.version_info >= (3, 10) + and isinstance(__type_or_types_or_pred, types.UnionType) + ): + + def pred(x: Any) -> bool: + return isinstance(x, __type_or_types_or_pred) # type: ignore[arg-type] + + elif callable(__type_or_types_or_pred): + pred = __type_or_types_or_pred # type: ignore[assignment] + else: + raise TypeError("Argument must be a type, a tuple of types, or a callable.") + + def wrapper(func: Callable[[T], Any]) -> Callable[[Any], Any]: + @functools.wraps(func) + def wrapped(x: T) -> Any: + if pred(x): + return func(x) + return x + + return wrapped + + return wrapper + + +@overload +def tree_map_only( + __type_or_types_or_pred: Type[T], + func: Fn[T, Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +@overload +def tree_map_only( + __type_or_types_or_pred: Type2[T, S], + func: Fn2[T, S, Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +@overload +def tree_map_only( + __type_or_types_or_pred: Type3[T, S, U], + func: Fn3[T, S, U, Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +@overload +def tree_map_only( + __type_or_types_or_pred: Callable[[Any], bool], + func: FnAny[Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +def tree_map_only( + __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]], + func: FnAny[Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + return tree_map(map_only(__type_or_types_or_pred)(func), tree, is_leaf=is_leaf) + + +@overload +def tree_map_only_( + __type_or_types_or_pred: Type[T], + func: Fn[T, Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +@overload +def tree_map_only_( + __type_or_types_or_pred: Type2[T, S], + func: Fn2[T, S, Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +@overload +def tree_map_only_( + __type_or_types_or_pred: Type3[T, S, U], + func: Fn3[T, S, U, Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +@overload +def tree_map_only_( + __type_or_types_or_pred: Callable[[Any], bool], + func: FnAny[Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +def tree_map_only_( + __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]], + func: FnAny[Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + return tree_map_(map_only(__type_or_types_or_pred)(func), tree, is_leaf=is_leaf) + + +def tree_all( + pred: Callable[[Any], bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + flat_args = tree_iter(tree, is_leaf=is_leaf) + return all(map(pred, flat_args)) + + +def tree_any( + pred: Callable[[Any], bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + flat_args = tree_iter(tree, is_leaf=is_leaf) + return any(map(pred, flat_args)) + + +@overload +def tree_all_only( + __type_or_types: Type[T], + pred: Fn[T, bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + ... + + +@overload +def tree_all_only( + __type_or_types: Type2[T, S], + pred: Fn2[T, S, bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + ... + + +@overload +def tree_all_only( + __type_or_types: Type3[T, S, U], + pred: Fn3[T, S, U, bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + ... + + +def tree_all_only( + __type_or_types: TypeAny, + pred: FnAny[bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + flat_args = tree_iter(tree, is_leaf=is_leaf) + return all(pred(x) for x in flat_args if isinstance(x, __type_or_types)) + + +@overload +def tree_any_only( + __type_or_types: Type[T], + pred: Fn[T, bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + ... + + +@overload +def tree_any_only( + __type_or_types: Type2[T, S], + pred: Fn2[T, S, bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + ... + + +@overload +def tree_any_only( + __type_or_types: Type3[T, S, U], + pred: Fn3[T, S, U, bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + ... + + +def tree_any_only( + __type_or_types: TypeAny, + pred: FnAny[bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + flat_args = tree_iter(tree, is_leaf=is_leaf) + return any(pred(x) for x in flat_args if isinstance(x, __type_or_types)) + + +# Broadcasts a pytree to the provided TreeSpec and returns the flattened +# values. If this is not possible, then this function returns None. +# +# For example, given pytree=0 and spec=TreeSpec(list, None, [LeafSpec(), LeafSpec()]), +# would return [0, 0]. This is useful for part of the vmap implementation: +# a user can pass in vmap(fn, in_dims)(*inputs). `in_dims` should be +# broadcastable to the tree structure of `inputs` and we use +# _broadcast_to_and_flatten to check this. +def _broadcast_to_and_flatten( + tree: PyTree, + treespec: TreeSpec, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> Optional[List[Any]]: + assert isinstance(treespec, TreeSpec) + + if _is_leaf(tree, is_leaf=is_leaf): + return [tree] * treespec.num_leaves + if treespec.is_leaf(): + return None + node_type = _get_node_type(tree) + if node_type != treespec.type: + return None + + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + child_pytrees, ctx = flatten_fn(tree) + + # Check if the Node is different from the spec + if len(child_pytrees) != treespec.num_children or ctx != treespec.context: + return None + + # Recursively flatten the children + result: List[Any] = [] + for child, child_spec in zip(child_pytrees, treespec.children_specs): + flat = _broadcast_to_and_flatten(child, child_spec, is_leaf=is_leaf) + if flat is not None: + result += flat + else: + return None + + return result + + +@dataclasses.dataclass +class _TreeSpecSchema: + """ + _TreeSpecSchema is the schema used to serialize the TreeSpec + It contains the following fields: + - type: A string name of the type. null for the case of a LeafSpec. + - context: Any format which is json dumpable + - children_spec: A list of children serialized specs. + """ + + type: Optional[str] + context: DumpableContext + children_spec: List["_TreeSpecSchema"] + + +class _ProtocolFn(NamedTuple): + treespec_to_json: Callable[[TreeSpec], DumpableContext] + json_to_treespec: Callable[[DumpableContext], TreeSpec] + + +_SUPPORTED_PROTOCOLS: Dict[int, _ProtocolFn] = {} + + +def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema: + if treespec.is_leaf(): + return _TreeSpecSchema(None, None, []) + + if treespec.type not in SUPPORTED_SERIALIZED_TYPES: + raise NotImplementedError( + f"Serializing {treespec.type} in pytree is not registered.", + ) + + serialize_node_def = SUPPORTED_SERIALIZED_TYPES[treespec.type] + + serialized_type_name = serialize_node_def.serialized_type_name + + if serialized_type_name == NO_SERIALIZED_TYPE_NAME_FOUND: + raise NotImplementedError( + f"No registered serialization name for {treespec.type} found. " + "Please update your _register_pytree_node call with a `serialized_type_name` kwarg." + ) + + if serialize_node_def.to_dumpable_context is None: + try: + serialized_context = json.dumps(treespec.context) + except TypeError as e: + raise TypeError( + "Unable to serialize context. " + "Please make the context json dump-able, or register a " + "custom serializer using _register_pytree_node." + ) from e + else: + serialized_context = serialize_node_def.to_dumpable_context(treespec.context) + + child_schemas = [_treespec_to_json(child) for child in treespec.children_specs] + + return _TreeSpecSchema(serialized_type_name, serialized_context, child_schemas) + + +def _json_to_treespec(json_schema: DumpableContext) -> TreeSpec: + if ( + json_schema["type"] is None + and json_schema["context"] is None + and len(json_schema["children_spec"]) == 0 + ): + return _LEAF_SPEC + + if json_schema["type"] not in SERIALIZED_TYPE_TO_PYTHON_TYPE: + raise NotImplementedError( + f'Deserializing {json_schema["type"]} in pytree is not registered.', + ) + + typ = SERIALIZED_TYPE_TO_PYTHON_TYPE[json_schema["type"]] + serialize_node_def = SUPPORTED_SERIALIZED_TYPES[typ] + + if serialize_node_def.from_dumpable_context is None: + try: + context = json.loads(json_schema["context"]) + except TypeError as ex: + raise TypeError( + "Unable to deserialize context. " + "Please make the context json load-able, or register a " + "custom serializer using _register_pytree_node.", + ) from ex + else: + context = serialize_node_def.from_dumpable_context(json_schema["context"]) + + children_specs = [ + _json_to_treespec(child_string) for child_string in json_schema["children_spec"] + ] + + return TreeSpec(typ, context, children_specs) + + +_SUPPORTED_PROTOCOLS[1] = _ProtocolFn(_treespec_to_json, _json_to_treespec) + + +def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str: + if not isinstance(treespec, TreeSpec): + raise TypeError( + f"treespec_dumps(treespec, protocol): Expected `treespec` to be instance of " + f"TreeSpec but got item of type {type(treespec)}.", + ) + + if protocol is None: + protocol = DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL + + if protocol in _SUPPORTED_PROTOCOLS: + json_spec = _SUPPORTED_PROTOCOLS[protocol].treespec_to_json(treespec) + else: + raise ValueError( + f"Unknown protocol {protocol}. " + f"Available protocols: {list(_SUPPORTED_PROTOCOLS.keys())}", + ) + + str_spec = json.dumps((protocol, dataclasses.asdict(json_spec))) + return str_spec + + +def treespec_loads(serialized: str) -> TreeSpec: + protocol, json_schema = json.loads(serialized) + + if protocol in _SUPPORTED_PROTOCOLS: + return _SUPPORTED_PROTOCOLS[protocol].json_to_treespec(json_schema) + raise ValueError( + f"Unknown protocol {protocol}. " + f"Available protocols: {list(_SUPPORTED_PROTOCOLS.keys())}", + ) + + +class _DummyLeaf: + def __repr__(self) -> str: + return "*" + + +def treespec_pprint(treespec: TreeSpec) -> str: + dummy_tree = tree_unflatten( + [_DummyLeaf() for _ in range(treespec.num_leaves)], + treespec, + ) + return repr(dummy_tree) + + +# TODO(angelayi): remove this function after OSS/internal stabilize +@deprecated( + "`pytree_to_str` is deprecated. Please use `treespec_dumps` instead.", + category=FutureWarning, +) +def pytree_to_str(treespec: TreeSpec) -> str: + return treespec_dumps(treespec) + + +# TODO(angelayi): remove this function after OSS/internal stabilize +@deprecated( + "`str_to_pytree` is deprecated. Please use `treespec_loads` instead.", + category=FutureWarning, +) +def str_to_pytree(json: str) -> TreeSpec: + return treespec_loads(json) + + +def arg_tree_leaves(*args: PyTree, **kwargs: PyTree) -> List[Any]: + """Get a flat list of arguments to this function + + A slightly faster version of tree_leaves((args, kwargs)) + """ + leaves: List[Any] = [] + for a in args: + leaves.extend(tree_iter(a)) + for a in kwargs.values(): + leaves.extend(tree_iter(a)) + return leaves + + +def tree_flatten_with_path( + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> Tuple[List[Tuple[KeyPath, Any]], TreeSpec]: + """Flattens a pytree like :func:`tree_flatten`, but also returns each leaf's key path. + + Args: + tree: a pytree to flatten. If it contains a custom type, that type must be + registered with an appropriate `tree_flatten_with_path_fn` when registered + with :func:`register_pytree_node`. + is_leaf: An extra leaf predicate function that will be called at each + flattening step. The function should have a single argument with signature + ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated + as a leaf. Otherwise, the default pytree registry will be used to determine a node is a + leaf or not. If the function is not specified, the default pytree registry will be used. + Returns: + A tuple where the first element is a list of (key path, leaf) pairs, and the + second element is a :class:`TreeSpec` representing the structure of the flattened + tree. + """ + _, treespec = tree_flatten(tree, is_leaf) + return list(_generate_key_paths((), tree, is_leaf)), treespec + + +def tree_leaves_with_path( + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> List[Tuple[KeyPath, Any]]: + """Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path. + + Args: + tree: a pytree. If it contains a custom type, that type must be + registered with an appropriate `tree_flatten_with_path_fn` when registered + with :func:`register_pytree_node`. + is_leaf: An extra leaf predicate function that will be called at each + flattening step. The function should have a single argument with signature + ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated + as a leaf. Otherwise, the default pytree registry will be used to determine a node is a + leaf or not. If the function is not specified, the default pytree registry will be used. + Returns: + A list of (key path, leaf) pairs. + """ + return list(_generate_key_paths((), tree, is_leaf)) + + +def _generate_key_paths( + key_path: KeyPath, + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> Iterable[Tuple[KeyPath, Any]]: + if is_leaf and is_leaf(tree): + yield key_path, tree + return + + node_type = _get_node_type(tree) + handler = SUPPORTED_NODES.get(node_type) + if not handler: + # This is a leaf + yield key_path, tree + return + + flatten_with_keys = handler.flatten_with_keys_fn + if flatten_with_keys: + key_children, _ = flatten_with_keys(tree) + for k, c in key_children: + yield from _generate_key_paths((*key_path, k), c, is_leaf) + else: + # We registered this pytree but didn't add a flatten_with_keys_fn, complain. + raise ValueError( + f"Did not find a flatten_with_keys_fn for type: {node_type}. " + "Please pass a flatten_with_keys_fn argument to register_pytree_node." + ) + + +def tree_map_with_path( + func: Callable[..., Any], + tree: PyTree, + *rests: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + """Like :func:`tree_map`, but the provided callable takes an additional key path argument. + + Args: + func: A function that takes ``2 + len(rests)`` arguments, to be applied at the + corresponding leaves of the pytrees. The first positional argument + to ``func`` is the key path of the leaf in question. The second + positional argument is the value of the leaf. + tree: A pytree to be mapped over, with each leaf providing the first positional + argument to function ``func``. + rests: A tuple of pytrees, each of which has the same structure as + ``tree`` or has ``tree`` as a prefix. + is_leaf: An extra leaf predicate function that will be called at each + flattening step. The function should have a single argument with signature + ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated + as a leaf. Otherwise, the default pytree registry will be used to determine a node is a + leaf or not. If the function is not specified, the default pytree registry will be used. + + Returns + A new pytree with the same structure as ``tree`` but with the value at each leaf given by + ``func(keypath, x, *xs)`` where ``keypath`` is the key path at the + corresponding leaf in ``tree``, ``x`` is the value at that leaf, and + ``xs`` is the tuple of values at corresponding nodes in ``rests``. + """ + keypath_leaves, treespec = tree_flatten_with_path(tree, is_leaf) + keypath_leaves = list(zip(*keypath_leaves)) + all_keypath_leaves = keypath_leaves + [treespec.flatten_up_to(r) for r in rests] + return treespec.unflatten(func(*xs) for xs in zip(*all_keypath_leaves)) + + +def keystr(kp: KeyPath) -> str: + """Given a key path, return a pretty-printed representation.""" + return "".join([str(k) for k in kp]) + + +def key_get(obj: Any, kp: KeyPath) -> Any: + """Given an object and a key path, return the value at the key path.""" + for k in kp: + obj = k.get(obj) + return obj \ No newline at end of file diff --git a/mindnlp/core/utils/_typing_utils.py b/mindnlp/core/utils/_typing_utils.py new file mode 100644 index 000000000..5e5234b35 --- /dev/null +++ b/mindnlp/core/utils/_typing_utils.py @@ -0,0 +1,14 @@ +"""Miscellaneous utilities to aid with typing.""" + +from typing import Optional, TypeVar + + +# Helper to turn Optional[T] into T when we know None either isn't +# possible or should trigger an exception. +T = TypeVar("T") + + +def not_none(obj: Optional[T]) -> T: + if obj is None: + raise TypeError("Invariant encountered: value was None when it should not be") + return obj \ No newline at end of file diff --git a/mindnlp/core/utils/checkpoint.py b/mindnlp/core/utils/checkpoint.py new file mode 100644 index 000000000..0a0834389 --- /dev/null +++ b/mindnlp/core/utils/checkpoint.py @@ -0,0 +1,31 @@ +from typing import * # noqa: F403 +from mindnlp import core + +def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[core.Tensor, ...]: + if isinstance(inputs, tuple): + out = [] + for inp in inputs: + if not isinstance(inp, core.Tensor): + out.append(inp) + continue + + x = inp.detach() + x.requires_grad = inp.requires_grad + out.append(x) + return tuple(out) + else: + raise RuntimeError( + "Only tuple of tensors is supported. Got Unsupported input type: ", + type(inputs).__name__, + ) + +def checkpoint( + function, + *args, + use_reentrant = None, + context_fn = None, + determinism_check = None, + debug = None, + **kwargs +): + return function(*args, **kwargs) \ No newline at end of file diff --git a/mindnlp/core/utils/cpp_extension.py b/mindnlp/core/utils/cpp_extension.py new file mode 100644 index 000000000..286c40f76 --- /dev/null +++ b/mindnlp/core/utils/cpp_extension.py @@ -0,0 +1,15 @@ +def load( + name, + sources, + extra_cflags=None, + extra_cuda_cflags=None, + extra_ldflags=None, + extra_include_paths=None, + build_directory=None, + verbose=False, + with_cuda=None, + is_python_module=True, + is_standalone=False, + keep_intermediates=True, +): + pass diff --git a/mindnlp/core/utils/hooks.py b/mindnlp/core/utils/hooks.py new file mode 100644 index 000000000..ca834c5cc --- /dev/null +++ b/mindnlp/core/utils/hooks.py @@ -0,0 +1,254 @@ +"""hooks""" +from collections import OrderedDict +import weakref +import warnings +from typing import Any, Tuple +from mindnlp import core + +__all__ = ["RemovableHandle", "unserializable_hook", "warn_if_has_hooks", "BackwardHook"] + +class RemovableHandle: + r""" + A handle which provides the capability to remove a hook. + + Args: + hooks_dict (dict): A dictionary of hooks, indexed by hook ``id``. + extra_dict (Union[dict, List[dict]]): An additional dictionary or list of + dictionaries whose keys will be deleted when the same keys are + removed from ``hooks_dict``. + """ + + id: int + next_id: int = 0 + + def __init__(self, hooks_dict: Any, *, extra_dict: Any = None) -> None: + self.hooks_dict_ref = weakref.ref(hooks_dict) + self.id = RemovableHandle.next_id + RemovableHandle.next_id += 1 + + self.extra_dict_ref: Tuple = () + if isinstance(extra_dict, dict): + self.extra_dict_ref = (weakref.ref(extra_dict),) + elif isinstance(extra_dict, list): + self.extra_dict_ref = tuple(weakref.ref(d) for d in extra_dict) + + def remove(self) -> None: + hooks_dict = self.hooks_dict_ref() + if hooks_dict is not None and self.id in hooks_dict: + del hooks_dict[self.id] + + for ref in self.extra_dict_ref: + extra_dict = ref() + if extra_dict is not None and self.id in extra_dict: + del extra_dict[self.id] + + def __getstate__(self): + if self.extra_dict_ref is None: + return (self.hooks_dict_ref(), self.id) + else: + return (self.hooks_dict_ref(), self.id, tuple(ref() for ref in self.extra_dict_ref)) + + def __setstate__(self, state) -> None: + if state[0] is None: + # create a dead reference + self.hooks_dict_ref = weakref.ref(OrderedDict()) + else: + self.hooks_dict_ref = weakref.ref(state[0]) + self.id = state[1] + RemovableHandle.next_id = max(RemovableHandle.next_id, self.id + 1) + + if len(state) < 3 or state[2] is None: + self.extra_dict_ref = () + else: + self.extra_dict_ref = tuple(weakref.ref(d) for d in state[2]) + + def __enter__(self) -> "RemovableHandle": + return self + + def __exit__(self, type: Any, value: Any, tb: Any) -> None: + self.remove() + + +def unserializable_hook(f): + """ + Mark a function as an unserializable hook with this decorator. + + This suppresses warnings that would otherwise arise if you attempt + to serialize a tensor that has a hook. + """ + f.__torch_unserializable__ = True + return f + + +def warn_if_has_hooks(tensor): + if tensor._backward_hooks: + for k in tensor._backward_hooks: + hook = tensor._backward_hooks[k] + if not hasattr(hook, "__torch_unserializable__"): + warnings.warn(f"backward hook {repr(hook)} on tensor will not be " + "serialized. If this is expected, you can " + "decorate the function with @core.utils.hooks.unserializable_hook " + "to suppress this warning") + +class BackwardHook: + """ + A wrapper class to implement nn.Module backward hooks. + + It handles: + - Ignoring non-Tensor inputs and replacing them by None before calling the user hook + - Generating the proper Node to capture a set of Tensor's gradients + - Linking the gradients captures for the outputs with the gradients captured for the input + - Calling the user hook once both output and input gradients are available + """ + + def __init__(self, module, user_hooks, user_pre_hooks): + self.user_hooks = user_hooks + self.user_pre_hooks = user_pre_hooks + self.module = module + + self.grad_outputs = None + self.n_outputs = -1 + self.output_tensors_index = None + self.n_inputs = -1 + self.input_tensors_index = None + + def _pack_with_none(self, indices, values, size): + res = [None] * size + for idx, val in zip(indices, values): + res[idx] = val + + return tuple(res) + + def _unpack_none(self, indices, values): + res = [] + for idx in indices: + res.append(values[idx]) + + return tuple(res) + + def _set_user_hook(self, grad_fn): + def hook(grad_input, _): + if self.grad_outputs is None: + # This happens because the gradient in your nn.Module flows to + # the Module's input without " passing through the Module's + # output, e.g. when you're doing double backward. + return + res = self._pack_with_none(self.input_tensors_index, grad_input, self.n_inputs) + + for hook in self.user_hooks: + out = hook(self.module, res, self.grad_outputs) + + if out is None: + continue + + if len(out) != len(res): + raise RuntimeError("Backward hook returned an invalid number of grad_input, " + f"got {len(out)}, but expected {len(res)}") + + res = out + + self.grad_outputs = None + + return self._unpack_none(self.input_tensors_index, res) + + grad_fn.register_hook(hook) + + def _apply_on_tensors(self, fn, args): + # Can be used to apply the given function to the tensors contained in the + # args. Will return updated args and the tensors indices + tensors_idx = [] + tensors = [] + + requires_grad = False + for i, arg in enumerate(args): + if isinstance(arg, core.Tensor): + tensors_idx.append(i) + tensors.append(arg) + requires_grad |= arg.requires_grad + + # if not (requires_grad and core.is_grad_enabled()): + # return args, None + + # new_tensors = core.nn.modules._functions.BackwardHookFunction.apply(*tensors) + new_tensors = tensors + if len(new_tensors) == 0: + raise RuntimeError("Cannot set Module backward hook for a Module with no input Tensors.") + + grad_fns = [t.grad_fn for t in new_tensors if t.grad_fn is not None and t.grad_fn.name() == "BackwardHookFunctionBackward"] + if len(grad_fns) == 0: + raise RuntimeError("Error while setting up backward hooks. Please open " + "an issue with a code sample to reproduce this.") + + fn(grad_fns[0]) + + arg_list = list(args) + for idx, val in zip(tensors_idx, new_tensors): + arg_list[idx] = val + + if type(args) is tuple: + out = tuple(arg_list) + else: + out = type(args)(*arg_list) + return out, tensors_idx + + def setup_input_hook(self, args): + def fn(grad_fn): + self._set_user_hook(grad_fn) + + res, input_idx = self._apply_on_tensors(fn, args) + self.n_inputs = len(args) + self.input_tensors_index = input_idx + return res + + def setup_output_hook(self, args): + def fn(grad_fn): + def hook(_, grad_output): + self.grad_outputs = self._pack_with_none(self.output_tensors_index, + grad_output, + self.n_outputs) + + if self.user_pre_hooks: + expected_len = len(self.grad_outputs) + for user_pre_hook in self.user_pre_hooks: + hook_grad_outputs = user_pre_hook(self.module, self.grad_outputs) + if hook_grad_outputs is None: + continue + + actual_len = len(hook_grad_outputs) + if actual_len != expected_len: + raise RuntimeError("Backward pre hook returned an invalid number of grad_output, " + f"got {actual_len}, but expected {expected_len}") + self.grad_outputs = hook_grad_outputs + + # We need to be able to clear self.grad_outputs but also return it + local_grad_outputs = self.grad_outputs + + # Special case if no input required gradients, this hook should call the user + # hook directly + if self.input_tensors_index is None: + grad_inputs = self._pack_with_none([], [], self.n_inputs) + for user_hook in self.user_hooks: + res = user_hook(self.module, grad_inputs, self.grad_outputs) + if res is not None and not (isinstance(res, tuple) and all(el is None for el in res)): + raise RuntimeError("Backward hook for Modules where no input requires " + "gradient should always return None or None for all gradients.") + self.grad_outputs = None + + if local_grad_outputs is not None: + assert self.output_tensors_index is not None # mypy + return tuple(local_grad_outputs[i] for i in self.output_tensors_index) + + grad_fn.register_hook(hook) + + is_tuple = True + if not isinstance(args, tuple): + args = (args,) + is_tuple = False + + res, output_idx = self._apply_on_tensors(fn, args) + self.n_outputs = len(args) + self.output_tensors_index = output_idx + + if not is_tuple: + res = res[0] + return res diff --git a/mindnlp/core/utils/model_zoo.py b/mindnlp/core/utils/model_zoo.py new file mode 100644 index 000000000..d7c082651 --- /dev/null +++ b/mindnlp/core/utils/model_zoo.py @@ -0,0 +1 @@ +from tqdm import tqdm \ No newline at end of file diff --git a/mindnlp/core/version.py b/mindnlp/core/version.py new file mode 100644 index 000000000..358b18b88 --- /dev/null +++ b/mindnlp/core/version.py @@ -0,0 +1,5 @@ +import mindspore + +hip = None +cuda = mindspore.get_context('device_target') == 'GPU' +npu = mindspore.get_context('device_target') == 'Ascend' diff --git a/mindnlp/dataset/load.py b/mindnlp/dataset/load.py index 4a5afc396..6a8c9424d 100644 --- a/mindnlp/dataset/load.py +++ b/mindnlp/dataset/load.py @@ -23,9 +23,6 @@ DownloadConfig, DownloadMode, VerificationMode, Version from mindspore.dataset import GeneratorDataset from mindspore.communication import get_rank, get_group_size -from mindnlp.configs import DEFAULT_ROOT -from ..accelerate import DistributedType -from ..accelerate.utils import accelerate_distributed_type class TransferIterableDataset(): @@ -301,8 +298,6 @@ def load_dataset( ``` """ shuffle = config_kwargs.get('shuffle', False) - if cache_dir is None: - cache_dir = os.path.join(DEFAULT_ROOT, "datasets", path) ds_ret = hf_load(path, name=name, @@ -335,19 +330,12 @@ def load_dataset( column_names = list(raw_ds.features.keys()) source = TransferDataset(raw_ds, column_names) if isinstance(raw_ds, Dataset) \ else TransferIterableDataset(raw_ds, column_names) - if accelerate_distributed_type == DistributedType.MULTI_NPU: - ms_ds = GeneratorDataset(source=source, - column_names=column_names, - shuffle=shuffle, - num_parallel_workers=num_proc if num_proc else 1, - num_shards=get_group_size(), shard_id=get_rank()) - datasets_dict[key] = ms_ds - else: - ms_ds = GeneratorDataset(source=source, - column_names=column_names, - shuffle=shuffle, - num_parallel_workers=num_proc if num_proc else 1) - datasets_dict[key] = ms_ds + + ms_ds = GeneratorDataset(source=source, + column_names=column_names, + shuffle=shuffle, + num_parallel_workers=num_proc if num_proc else 1) + datasets_dict[key] = ms_ds if len(datasets_dict) == 1: return datasets_dict.popitem()[1] diff --git a/mindnlp/engine/utils.py b/mindnlp/engine/utils.py index 782ce40fe..0ed550cb8 100644 --- a/mindnlp/engine/utils.py +++ b/mindnlp/engine/utils.py @@ -29,7 +29,7 @@ import mindspore from mindnlp.core import ops -from mindnlp.core.nn import functional as F +from core.nn import functional as F from mindnlp.configs import GENERATOR_SEED from mindnlp.utils import is_mindspore_available, ExplicitEnum diff --git a/mindnlp/evaluate.py b/mindnlp/evaluate.py new file mode 100644 index 000000000..02c955cf8 --- /dev/null +++ b/mindnlp/evaluate.py @@ -0,0 +1 @@ +from evaluate import * \ No newline at end of file diff --git a/mindnlp/experimental/rwkv6/modeling_rwkv6.py b/mindnlp/experimental/rwkv6/modeling_rwkv6.py index 912f188b1..8c505434b 100644 --- a/mindnlp/experimental/rwkv6/modeling_rwkv6.py +++ b/mindnlp/experimental/rwkv6/modeling_rwkv6.py @@ -19,8 +19,8 @@ # ============================================================================ import mindspore import mindnlp -import mindnlp.core.nn as nn -import mindnlp.core.ops as ops +import core.nn as nn +import core.ops as ops from typing import Tuple @@ -188,7 +188,7 @@ def __init__(self, args: dict): self.set_train(False) # 加载权重 - w = mindnlp.core.serialization.load(args['MODEL_NAME'] + '.pth') + w = core.serialization.load(args['MODEL_NAME'] + '.pth') # 将所有权重转换为float32 self.num_layer = 0 diff --git a/mindnlp/integrations/safetensors.py b/mindnlp/integrations/safetensors.py index 86f2f0794..2e1db82a9 100644 --- a/mindnlp/integrations/safetensors.py +++ b/mindnlp/integrations/safetensors.py @@ -5,9 +5,9 @@ import safetensors -import mindtorch +from mindnlp import core -from mindtorch.configs import SUPPORT_BF16 +from core.configs import SUPPORT_BF16 if SUPPORT_BF16: from mindspore.common.np_dtype import bfloat16 # pylint: disable=import-error @@ -20,19 +20,19 @@ _MS_TYPES = { - "F64": mindtorch.float64, - "F32": mindtorch.float32, - "F16": mindtorch.float16, - "BF16": mindtorch.bfloat16, - "I64": mindtorch.int64, - "U64": mindtorch.uint64, - "I32": mindtorch.int32, - "U32": mindtorch.uint32, - "I16": mindtorch.int16, - "U16": mindtorch.uint16, - "I8": mindtorch.int8, - "U8": mindtorch.uint8, - "BOOL": mindtorch.bool, + "F64": core.float64, + "F32": core.float32, + "F16": core.float16, + "BF16": core.bfloat16, + "I64": core.int64, + "U64": core.uint64, + "I32": core.int32, + "U32": core.uint32, + "I16": core.int16, + "U16": core.uint16, + "I8": core.int8, + "U8": core.uint8, + "BOOL": core.bool, } _NP_TYPES = { @@ -93,7 +93,7 @@ def get(self, *args, **kwargs): tensor = tensor.reshape(self.shape) if not SUPPORT_BF16 and self.info["dtype"] == 'BF16': tensor = tensor.astype(np.float16) - tensor = mindtorch.from_numpy(tensor) + tensor = core.from_numpy(tensor) return tensor @property diff --git a/mindnlp/peft.py b/mindnlp/peft.py new file mode 100644 index 000000000..06c839a9c --- /dev/null +++ b/mindnlp/peft.py @@ -0,0 +1 @@ +from peft import * \ No newline at end of file diff --git a/mindnlp/quant/mindbnb/bitsandbytes/nn/modules.py b/mindnlp/quant/mindbnb/bitsandbytes/nn/modules.py index 98e121d16..447cac288 100644 --- a/mindnlp/quant/mindbnb/bitsandbytes/nn/modules.py +++ b/mindnlp/quant/mindbnb/bitsandbytes/nn/modules.py @@ -33,7 +33,7 @@ ) from mindnlp.core import nn -from mindnlp.core.nn import Parameter +from core.nn import Parameter def empty(*size, dtype=None): diff --git a/mindnlp/quant/mindbnb/integrations/replace_modules.py b/mindnlp/quant/mindbnb/integrations/replace_modules.py index 6d58c3948..8af349806 100644 --- a/mindnlp/quant/mindbnb/integrations/replace_modules.py +++ b/mindnlp/quant/mindbnb/integrations/replace_modules.py @@ -22,7 +22,7 @@ import bitsandbytes as bnb from mindnlp.core import nn -from mindnlp.core.nn import Parameter +from core.nn import Parameter logger = logging.getLogger(__name__) diff --git a/mindnlp/quant/smooth_quant/quant.py b/mindnlp/quant/smooth_quant/quant.py index 12c5fc82b..0dcbb67e5 100644 --- a/mindnlp/quant/smooth_quant/quant.py +++ b/mindnlp/quant/smooth_quant/quant.py @@ -4,7 +4,7 @@ from mindspore import Tensor from mindspore.ops.primitive import PrimitiveWithInfer, prim_attr_register from mindnlp.core import nn, ops -from mindnlp.core.serialization import load +from core.serialization import load from mindnlp.configs import ON_ORANGE_PI from .smooth import smooth_lm diff --git a/mindnlp/transformers/__init__.py b/mindnlp/transformers/__init__.py new file mode 100644 index 000000000..aed4fa323 --- /dev/null +++ b/mindnlp/transformers/__init__.py @@ -0,0 +1 @@ +from .models import * diff --git a/mindnlp/transformers/models/__init__.py b/mindnlp/transformers/models/__init__.py new file mode 100644 index 000000000..079dc6b93 --- /dev/null +++ b/mindnlp/transformers/models/__init__.py @@ -0,0 +1,2 @@ +from . import auto +from .auto import * diff --git a/mindnlp/transformers/models/auto.py b/mindnlp/transformers/models/auto.py new file mode 100644 index 000000000..500610755 --- /dev/null +++ b/mindnlp/transformers/models/auto.py @@ -0,0 +1,16 @@ + +from transformers.models.auto import modeling_auto +from transformers.models.auto import configuration_auto +from transformers.models.auto import feature_extraction_auto +from transformers.models.auto import image_processing_auto +from transformers.models.auto import processing_auto +from transformers.models.auto import tokenization_auto +from transformers.models.auto import auto_factory + +from transformers.models.auto.modeling_auto import * +from transformers.models.auto.configuration_auto import * +from transformers.models.auto.feature_extraction_auto import * +from transformers.models.auto.image_processing_auto import * +from transformers.models.auto.processing_auto import * +from transformers.models.auto.tokenization_auto import * +from transformers.models.auto.auto_factory import * diff --git a/mindnlp/transformers/models/auto_bk/__init__.py b/mindnlp/transformers/models/auto_bk/__init__.py new file mode 100644 index 000000000..b481626c3 --- /dev/null +++ b/mindnlp/transformers/models/auto_bk/__init__.py @@ -0,0 +1,21 @@ +from transformers.models import auto +from transformers.models.auto import configuration_auto +from transformers.models.auto import feature_extraction_auto +from transformers.models.auto import image_processing_auto +from transformers.models.auto import processing_auto +from transformers.models.auto import tokenization_auto + +from transformers.models.auto.configuration_auto import * +from transformers.models.auto.feature_extraction_auto import * +from transformers.models.auto.image_processing_auto import * +from transformers.models.auto.processing_auto import * +from transformers.models.auto.tokenization_auto import * + +from . import modeling_auto + +from .auto_factory import * +from .modeling_auto import * + +__all__ = [] +__all__.extend(auto.__all__) +__all__.extend(modeling_auto.__all__) diff --git a/mindnlp/transformers/models/auto_bk/auto_factory.py b/mindnlp/transformers/models/auto_bk/auto_factory.py new file mode 100644 index 000000000..4ec822205 --- /dev/null +++ b/mindnlp/transformers/models/auto_bk/auto_factory.py @@ -0,0 +1,772 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Factory function to build auto-model classes.""" + +import copy +import importlib +import json +import os +import warnings +from typing import Any, TypeVar, Union + +from transformers.configuration_utils import PretrainedConfig +from transformers.dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code +from transformers.utils import ( + CONFIG_NAME, + cached_file, + copy_func, + extract_commit_hash, + find_adapter_config_file, + is_peft_available, + is_torch_available, + logging, + requires_backends, +) +from transformers.models.auto.configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings + + +if is_torch_available(): + from transformers.generation import GenerationMixin + + +logger = logging.get_logger(__name__) + +_T = TypeVar("_T") +# Tokenizers will depend on packages installed, too much variance and there are no common base or Protocol +_LazyAutoMappingValue = tuple[Union[type[Any], None], Union[type[Any], None]] + +CLASS_DOCSTRING = """ + This is a generic model class that will be instantiated as one of the model classes of the library when created + with the [`~BaseAutoModelClass.from_pretrained`] class method or the [`~BaseAutoModelClass.from_config`] class + method. + + This class cannot be instantiated directly using `__init__()` (throws an error). +""" + +FROM_CONFIG_DOCSTRING = """ + Instantiates one of the model classes of the library from a configuration. + + Note: + Loading a model from its configuration file does **not** load the model weights. It only affects the + model's configuration. Use [`~BaseAutoModelClass.from_pretrained`] to load the model weights. + + Args: + config ([`PretrainedConfig`]): + The model class to instantiate is selected based on the configuration class: + + List options + attn_implementation (`str`, *optional*): + The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation. + + Examples: + + ```python + >>> from transformers import AutoConfig, BaseAutoModelClass + + >>> # Download configuration from huggingface.co and cache. + >>> config = AutoConfig.from_pretrained("checkpoint_placeholder") + >>> model = BaseAutoModelClass.from_config(config) + ``` +""" + +FROM_PRETRAINED_TORCH_DOCSTRING = """ + Instantiate one of the model classes of the library from a pretrained model. + + The model class to instantiate is selected based on the `model_type` property of the config object (either + passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by + falling back to using pattern matching on `pretrained_model_name_or_path`: + + List options + + The model is set in evaluation mode by default using `model.eval()` (so for instance, dropout modules are + deactivated). To train the model, you should first set it back in training mode with `model.train()` + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + - A path to a *directory* containing model weights saved using + [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In + this case, `from_tf` should be set to `True` and a configuration object should be provided as + `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a + PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. + model_args (additional positional arguments, *optional*): + Will be passed along to the underlying model `__init__()` method. + config ([`PretrainedConfig`], *optional*): + Configuration for the model to use instead of an automatically loaded configuration. Configuration can + be automatically loaded when: + + - The model is a model provided by the library (loaded with the *model id* string of a pretrained + model). + - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the + save directory. + - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a + configuration JSON file named *config.json* is found in the directory. + state_dict (*dict[str, torch.Tensor]*, *optional*): + A state dictionary to use instead of a state dictionary loaded from saved weights file. + + This option can be used if you want to create a model from a pretrained configuration but load your own + weights. In this case though, you should check if using [`~PreTrainedModel.save_pretrained`] and + [`~PreTrainedModel.from_pretrained`] is not a simpler option. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + from_tf (`bool`, *optional*, defaults to `False`): + Load the model weights from a TensorFlow checkpoint save file (see docstring of + `pretrained_model_name_or_path` argument). + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. + Will be removed in v5 of Transformers. + proxies (`dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (e.g., not try downloading the model). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + trust_remote_code (`bool`, *optional*, defaults to `False`): + Whether or not to allow for custom models defined on the Hub in their own modeling files. This option + should only be set to `True` for repositories you trust and in which you have read the code, as it will + execute code present on the Hub on your local machine. + code_revision (`str`, *optional*, defaults to `"main"`): + The specific revision to use for the code on the Hub, if the code leaves in a different repository than + the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based + system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier + allowed by git. + kwargs (additional keyword arguments, *optional*): + Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., + `output_attentions=True`). Behaves differently depending on whether a `config` is provided or + automatically loaded: + + - If a configuration is provided with `config`, `**kwargs` will be directly passed to the + underlying model's `__init__` method (we assume all relevant updates to the configuration have + already been done) + - If a configuration is not provided, `kwargs` will be first passed to the configuration class + initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that + corresponds to a configuration attribute will be used to override said attribute with the + supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute + will be passed to the underlying model's `__init__` function. + + Examples: + + ```python + >>> from transformers import AutoConfig, BaseAutoModelClass + + >>> # Download model and configuration from huggingface.co and cache. + >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder") + + >>> # Update configuration during loading + >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder", output_attentions=True) + >>> model.config.output_attentions + True + + >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower) + >>> config = AutoConfig.from_pretrained("./tf_model/shortcut_placeholder_tf_model_config.json") + >>> model = BaseAutoModelClass.from_pretrained( + ... "./tf_model/shortcut_placeholder_tf_checkpoint.ckpt.index", from_tf=True, config=config + ... ) + ``` +""" + +FROM_PRETRAINED_TF_DOCSTRING = """ + Instantiate one of the model classes of the library from a pretrained model. + + The model class to instantiate is selected based on the `model_type` property of the config object (either + passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by + falling back to using pattern matching on `pretrained_model_name_or_path`: + + List options + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + - A path to a *directory* containing model weights saved using + [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + - A path or url to a *PyTorch state_dict save file* (e.g, `./pt_model/pytorch_model.bin`). In this + case, `from_pt` should be set to `True` and a configuration object should be provided as `config` + argument. This loading path is slower than converting the PyTorch model in a TensorFlow model + using the provided conversion scripts and loading the TensorFlow model afterwards. + model_args (additional positional arguments, *optional*): + Will be passed along to the underlying model `__init__()` method. + config ([`PretrainedConfig`], *optional*): + Configuration for the model to use instead of an automatically loaded configuration. Configuration can + be automatically loaded when: + + - The model is a model provided by the library (loaded with the *model id* string of a pretrained + model). + - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the + save directory. + - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a + configuration JSON file named *config.json* is found in the directory. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + from_pt (`bool`, *optional*, defaults to `False`): + Load the model weights from a PyTorch checkpoint save file (see docstring of + `pretrained_model_name_or_path` argument). + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. + Will be removed in v5 of Transformers. + proxies (`dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (e.g., not try downloading the model). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + trust_remote_code (`bool`, *optional*, defaults to `False`): + Whether or not to allow for custom models defined on the Hub in their own modeling files. This option + should only be set to `True` for repositories you trust and in which you have read the code, as it will + execute code present on the Hub on your local machine. + code_revision (`str`, *optional*, defaults to `"main"`): + The specific revision to use for the code on the Hub, if the code leaves in a different repository than + the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based + system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier + allowed by git. + kwargs (additional keyword arguments, *optional*): + Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., + `output_attentions=True`). Behaves differently depending on whether a `config` is provided or + automatically loaded: + + - If a configuration is provided with `config`, `**kwargs` will be directly passed to the + underlying model's `__init__` method (we assume all relevant updates to the configuration have + already been done) + - If a configuration is not provided, `kwargs` will be first passed to the configuration class + initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that + corresponds to a configuration attribute will be used to override said attribute with the + supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute + will be passed to the underlying model's `__init__` function. + + Examples: + + ```python + >>> from transformers import AutoConfig, BaseAutoModelClass + + >>> # Download model and configuration from huggingface.co and cache. + >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder") + + >>> # Update configuration during loading + >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder", output_attentions=True) + >>> model.config.output_attentions + True + + >>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower) + >>> config = AutoConfig.from_pretrained("./pt_model/shortcut_placeholder_pt_model_config.json") + >>> model = BaseAutoModelClass.from_pretrained( + ... "./pt_model/shortcut_placeholder_pytorch_model.bin", from_pt=True, config=config + ... ) + ``` +""" + +FROM_PRETRAINED_FLAX_DOCSTRING = """ + Instantiate one of the model classes of the library from a pretrained model. + + The model class to instantiate is selected based on the `model_type` property of the config object (either + passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by + falling back to using pattern matching on `pretrained_model_name_or_path`: + + List options + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + - A path to a *directory* containing model weights saved using + [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + - A path or url to a *PyTorch state_dict save file* (e.g, `./pt_model/pytorch_model.bin`). In this + case, `from_pt` should be set to `True` and a configuration object should be provided as `config` + argument. This loading path is slower than converting the PyTorch model in a TensorFlow model + using the provided conversion scripts and loading the TensorFlow model afterwards. + model_args (additional positional arguments, *optional*): + Will be passed along to the underlying model `__init__()` method. + config ([`PretrainedConfig`], *optional*): + Configuration for the model to use instead of an automatically loaded configuration. Configuration can + be automatically loaded when: + + - The model is a model provided by the library (loaded with the *model id* string of a pretrained + model). + - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the + save directory. + - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a + configuration JSON file named *config.json* is found in the directory. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + from_pt (`bool`, *optional*, defaults to `False`): + Load the model weights from a PyTorch checkpoint save file (see docstring of + `pretrained_model_name_or_path` argument). + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. + Will be removed in v5 of Transformers. + proxies (`dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (e.g., not try downloading the model). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + trust_remote_code (`bool`, *optional*, defaults to `False`): + Whether or not to allow for custom models defined on the Hub in their own modeling files. This option + should only be set to `True` for repositories you trust and in which you have read the code, as it will + execute code present on the Hub on your local machine. + code_revision (`str`, *optional*, defaults to `"main"`): + The specific revision to use for the code on the Hub, if the code leaves in a different repository than + the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based + system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier + allowed by git. + kwargs (additional keyword arguments, *optional*): + Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., + `output_attentions=True`). Behaves differently depending on whether a `config` is provided or + automatically loaded: + + - If a configuration is provided with `config`, `**kwargs` will be directly passed to the + underlying model's `__init__` method (we assume all relevant updates to the configuration have + already been done) + - If a configuration is not provided, `kwargs` will be first passed to the configuration class + initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that + corresponds to a configuration attribute will be used to override said attribute with the + supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute + will be passed to the underlying model's `__init__` function. + + Examples: + + ```python + >>> from transformers import AutoConfig, BaseAutoModelClass + + >>> # Download model and configuration from huggingface.co and cache. + >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder") + + >>> # Update configuration during loading + >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder", output_attentions=True) + >>> model.config.output_attentions + True + + >>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower) + >>> config = AutoConfig.from_pretrained("./pt_model/shortcut_placeholder_pt_model_config.json") + >>> model = BaseAutoModelClass.from_pretrained( + ... "./pt_model/shortcut_placeholder_pytorch_model.bin", from_pt=True, config=config + ... ) + ``` +""" + + +def _get_model_class(config, model_mapping): + supported_models = model_mapping[type(config)] + if not isinstance(supported_models, (list, tuple)): + return supported_models + + name_to_model = {model.__name__: model for model in supported_models} + architectures = getattr(config, "architectures", []) + for arch in architectures: + if arch in name_to_model: + return name_to_model[arch] + elif f"TF{arch}" in name_to_model: + return name_to_model[f"TF{arch}"] + elif f"Flax{arch}" in name_to_model: + return name_to_model[f"Flax{arch}"] + + # If not architecture is set in the config or match the supported models, the first element of the tuple is the + # defaults. + return supported_models[0] + + +class _BaseAutoModelClass: + # Base class for auto models. + _model_mapping = None + + def __init__(self, *args, **kwargs) -> None: + raise OSError( + f"{self.__class__.__name__} is designed to be instantiated " + f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or " + f"`{self.__class__.__name__}.from_config(config)` methods." + ) + + @classmethod + def from_config(cls, config, **kwargs): + trust_remote_code = kwargs.pop("trust_remote_code", None) + has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map + has_local_code = type(config) in cls._model_mapping.keys() + if has_remote_code: + class_ref = config.auto_map[cls.__name__] + if "--" in class_ref: + upstream_repo = class_ref.split("--")[0] + else: + upstream_repo = None + trust_remote_code = resolve_trust_remote_code( + trust_remote_code, config._name_or_path, has_local_code, has_remote_code, upstream_repo=upstream_repo + ) + + if has_remote_code and trust_remote_code: + if "--" in class_ref: + repo_id, class_ref = class_ref.split("--") + else: + repo_id = config.name_or_path + model_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs) + # This block handles the case where the user is loading a model with `trust_remote_code=True` + # but a library model exists with the same name. We don't want to override the autoclass + # mappings in this case, or all future loads of that model will be the remote code model. + if not has_local_code: + cls.register(config.__class__, model_class, exist_ok=True) + model_class.register_for_auto_class(auto_class=cls) + _ = kwargs.pop("code_revision", None) + model_class = add_generation_mixin_to_remote_model(model_class) + return model_class._from_config(config, **kwargs) + elif type(config) in cls._model_mapping.keys(): + model_class = _get_model_class(config, cls._model_mapping) + return model_class._from_config(config, **kwargs) + + raise ValueError( + f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" + f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}." + ) + + @classmethod + def _prepare_config_for_auto_class(cls, config: PretrainedConfig) -> PretrainedConfig: + """Additional autoclass-specific config post-loading manipulation. May be overridden in subclasses.""" + return config + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike[str]], *model_args, **kwargs): + config = kwargs.pop("config", None) + trust_remote_code = kwargs.get("trust_remote_code", None) + kwargs["_from_auto"] = True + hub_kwargs_names = [ + "cache_dir", + "force_download", + "local_files_only", + "proxies", + "resume_download", + "revision", + "subfolder", + "use_auth_token", + "token", + ] + hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs} + code_revision = kwargs.pop("code_revision", None) + commit_hash = kwargs.pop("_commit_hash", None) + adapter_kwargs = kwargs.pop("adapter_kwargs", None) + + token = hub_kwargs.pop("token", None) + use_auth_token = hub_kwargs.pop("use_auth_token", None) + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + token = use_auth_token + + if token is not None: + hub_kwargs["token"] = token + + if commit_hash is None: + if not isinstance(config, PretrainedConfig): + # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible + resolved_config_file = cached_file( + pretrained_model_name_or_path, + CONFIG_NAME, + _raise_exceptions_for_gated_repo=False, + _raise_exceptions_for_missing_entries=False, + _raise_exceptions_for_connection_errors=False, + **hub_kwargs, + ) + commit_hash = extract_commit_hash(resolved_config_file, commit_hash) + else: + commit_hash = getattr(config, "_commit_hash", None) + + if is_peft_available(): + if adapter_kwargs is None: + adapter_kwargs = {} + if token is not None: + adapter_kwargs["token"] = token + + maybe_adapter_path = find_adapter_config_file( + pretrained_model_name_or_path, _commit_hash=commit_hash, **adapter_kwargs + ) + + if maybe_adapter_path is not None: + with open(maybe_adapter_path, "r", encoding="utf-8") as f: + adapter_config = json.load(f) + + adapter_kwargs["_adapter_model_path"] = pretrained_model_name_or_path + pretrained_model_name_or_path = adapter_config["base_model_name_or_path"] + + if not isinstance(config, PretrainedConfig): + kwargs_orig = copy.deepcopy(kwargs) + # ensure not to pollute the config object with torch_dtype="auto" - since it's + # meaningless in the context of the config object - torch.dtype values are acceptable + if kwargs.get("torch_dtype", None) == "auto": + _ = kwargs.pop("torch_dtype") + # to not overwrite the quantization_config if config has a quantization_config + if kwargs.get("quantization_config", None) is not None: + _ = kwargs.pop("quantization_config") + + config, kwargs = AutoConfig.from_pretrained( + pretrained_model_name_or_path, + return_unused_kwargs=True, + code_revision=code_revision, + _commit_hash=commit_hash, + **hub_kwargs, + **kwargs, + ) + + # if torch_dtype=auto was passed here, ensure to pass it on + if kwargs_orig.get("torch_dtype", None) == "auto": + kwargs["torch_dtype"] = "auto" + if kwargs_orig.get("quantization_config", None) is not None: + kwargs["quantization_config"] = kwargs_orig["quantization_config"] + + has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map + has_local_code = type(config) in cls._model_mapping.keys() + upstream_repo = None + if has_remote_code: + class_ref = config.auto_map[cls.__name__] + if "--" in class_ref: + upstream_repo = class_ref.split("--")[0] + trust_remote_code = resolve_trust_remote_code( + trust_remote_code, + pretrained_model_name_or_path, + has_local_code, + has_remote_code, + ) + kwargs["trust_remote_code"] = trust_remote_code + + # Set the adapter kwargs + kwargs["adapter_kwargs"] = adapter_kwargs + + if has_remote_code and trust_remote_code: + model_class = get_class_from_dynamic_module( + class_ref, pretrained_model_name_or_path, code_revision=code_revision, **hub_kwargs, **kwargs + ) + _ = hub_kwargs.pop("code_revision", None) + # This block handles the case where the user is loading a model with `trust_remote_code=True` + # but a library model exists with the same name. We don't want to override the autoclass + # mappings in this case, or all future loads of that model will be the remote code model. + if not has_local_code: + cls.register(config.__class__, model_class, exist_ok=True) + model_class.register_for_auto_class(auto_class=cls) + model_class = add_generation_mixin_to_remote_model(model_class) + return model_class.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs + ) + elif type(config) in cls._model_mapping.keys(): + model_class = _get_model_class(config, cls._model_mapping) + if model_class.config_class == config.sub_configs.get("text_config", None): + config = config.get_text_config() + return model_class.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs + ) + raise ValueError( + f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" + f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}." + ) + + @classmethod + def register(cls, config_class, model_class, exist_ok=False) -> None: + """ + Register a new model for this class. + + Args: + config_class ([`PretrainedConfig`]): + The configuration corresponding to the model to register. + model_class ([`PreTrainedModel`]): + The model to register. + """ + if hasattr(model_class, "config_class") and model_class.config_class.__name__ != config_class.__name__: + raise ValueError( + "The model class you are passing has a `config_class` attribute that is not consistent with the " + f"config class you passed (model has {model_class.config_class} and you passed {config_class}. Fix " + "one of those so they match!" + ) + cls._model_mapping.register(config_class, model_class, exist_ok=exist_ok) + + +class _BaseAutoBackboneClass(_BaseAutoModelClass): + # Base class for auto backbone models. + _model_mapping = None + + @classmethod + def _load_timm_backbone_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + requires_backends(cls, ["vision", "timm"]) + from ...models.timm_backbone import TimmBackboneConfig + + config = kwargs.pop("config", TimmBackboneConfig()) + + if kwargs.get("out_features", None) is not None: + raise ValueError("Cannot specify `out_features` for timm backbones") + + if kwargs.get("output_loading_info", False): + raise ValueError("Cannot specify `output_loading_info=True` when loading from timm") + + num_channels = kwargs.pop("num_channels", config.num_channels) + features_only = kwargs.pop("features_only", config.features_only) + use_pretrained_backbone = kwargs.pop("use_pretrained_backbone", config.use_pretrained_backbone) + out_indices = kwargs.pop("out_indices", config.out_indices) + config = TimmBackboneConfig( + backbone=pretrained_model_name_or_path, + num_channels=num_channels, + features_only=features_only, + use_pretrained_backbone=use_pretrained_backbone, + out_indices=out_indices, + ) + return super().from_config(config, **kwargs) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + use_timm_backbone = kwargs.pop("use_timm_backbone", False) + if use_timm_backbone: + return cls._load_timm_backbone_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + + return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + + +def insert_head_doc(docstring, head_doc: str = ""): + if len(head_doc) > 0: + return docstring.replace( + "one of the model classes of the library ", + f"one of the model classes of the library (with a {head_doc} head) ", + ) + return docstring.replace( + "one of the model classes of the library ", "one of the base model classes of the library " + ) + + +def auto_class_update(cls, checkpoint_for_example: str = "google-bert/bert-base-cased", head_doc: str = ""): + # Create a new class with the right name from the base class + model_mapping = cls._model_mapping + name = cls.__name__ + class_docstring = insert_head_doc(CLASS_DOCSTRING, head_doc=head_doc) + cls.__doc__ = class_docstring.replace("BaseAutoModelClass", name) + + # Now we need to copy and re-register `from_config` and `from_pretrained` as class methods otherwise we can't + # have a specific docstrings for them. + from_config = copy_func(_BaseAutoModelClass.from_config) + from_config_docstring = insert_head_doc(FROM_CONFIG_DOCSTRING, head_doc=head_doc) + from_config_docstring = from_config_docstring.replace("BaseAutoModelClass", name) + from_config_docstring = from_config_docstring.replace("checkpoint_placeholder", checkpoint_for_example) + from_config.__doc__ = from_config_docstring + from_config = replace_list_option_in_docstrings(model_mapping._model_mapping, use_model_types=False)(from_config) + cls.from_config = classmethod(from_config) + + if name.startswith("TF"): + from_pretrained_docstring = FROM_PRETRAINED_TF_DOCSTRING + elif name.startswith("Flax"): + from_pretrained_docstring = FROM_PRETRAINED_FLAX_DOCSTRING + else: + from_pretrained_docstring = FROM_PRETRAINED_TORCH_DOCSTRING + from_pretrained = copy_func(_BaseAutoModelClass.from_pretrained) + from_pretrained_docstring = insert_head_doc(from_pretrained_docstring, head_doc=head_doc) + from_pretrained_docstring = from_pretrained_docstring.replace("BaseAutoModelClass", name) + from_pretrained_docstring = from_pretrained_docstring.replace("checkpoint_placeholder", checkpoint_for_example) + shortcut = checkpoint_for_example.split("/")[-1].split("-")[0] + from_pretrained_docstring = from_pretrained_docstring.replace("shortcut_placeholder", shortcut) + from_pretrained.__doc__ = from_pretrained_docstring + from_pretrained = replace_list_option_in_docstrings(model_mapping._model_mapping)(from_pretrained) + cls.from_pretrained = classmethod(from_pretrained) + return cls + + +def get_values(model_mapping): + result = [] + for model in model_mapping.values(): + if isinstance(model, (list, tuple)): + result += list(model) + else: + result.append(model) + + return result + + +def getattribute_from_module(module, attr): + if attr is None: + return None + if isinstance(attr, tuple): + return tuple(getattribute_from_module(module, a) for a in attr) + if hasattr(module, attr): + return getattr(module, attr) + # Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the + # object at the top level. + transformers_module = importlib.import_module("transformers") + + if module != transformers_module: + try: + return getattribute_from_module(transformers_module, attr) + except ValueError: + raise ValueError(f"Could not find {attr} neither in {module} nor in {transformers_module}!") + else: + raise ValueError(f"Could not find {attr} in {transformers_module}!") + + +def add_generation_mixin_to_remote_model(model_class): + """ + Adds `GenerationMixin` to the inheritance of `model_class`, if `model_class` is a PyTorch model. + + This function is used for backwards compatibility purposes: in v4.45, we've started a deprecation cycle to make + `PreTrainedModel` stop inheriting from `GenerationMixin`. Without this function, older models dynamically loaded + from the Hub may not have the `generate` method after we remove the inheritance. + """ + # 1. If it is not a PT model (i.e. doesn't inherit Module), do nothing + if "torch.nn.modules.module.Module" not in str(model_class.__mro__): + return model_class + + # 2. If it already **directly** inherits from GenerationMixin, do nothing + if "GenerationMixin" in str(model_class.__bases__): + return model_class + + # 3. Prior to v4.45, we could detect whether a model was `generate`-compatible if it had its own `generate` and/or + # `prepare_inputs_for_generation` method. + has_custom_generate_in_class = hasattr(model_class, "generate") and "GenerationMixin" not in str( + getattr(model_class, "generate") + ) + has_custom_prepare_inputs = hasattr(model_class, "prepare_inputs_for_generation") and "GenerationMixin" not in str( + getattr(model_class, "prepare_inputs_for_generation") + ) + if has_custom_generate_in_class or has_custom_prepare_inputs: + model_class_with_generation_mixin = type( + model_class.__name__, (model_class, GenerationMixin), {**model_class.__dict__} + ) + return model_class_with_generation_mixin + return model_class + +__all__ = ["get_values"] \ No newline at end of file diff --git a/mindnlp/transformers/models/auto_bk/modeling_auto.py b/mindnlp/transformers/models/auto_bk/modeling_auto.py new file mode 100644 index 000000000..2dce59859 --- /dev/null +++ b/mindnlp/transformers/models/auto_bk/modeling_auto.py @@ -0,0 +1,448 @@ +import warnings + +from transformers.models.auto.modeling_auto import ( + MODEL_FOR_MASK_GENERATION_MAPPING, + MODEL_FOR_KEYPOINT_DETECTION_MAPPING, + MODEL_FOR_TEXT_ENCODING_MAPPING, + MODEL_FOR_IMAGE_TO_IMAGE_MAPPING, + MODEL_MAPPING, + MODEL_FOR_PRETRAINING_MAPPING, + MODEL_WITH_LM_HEAD_MAPPING, + MODEL_FOR_CAUSAL_LM_MAPPING, + MODEL_FOR_MASKED_LM_MAPPING, + MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, + MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, + MODEL_FOR_QUESTION_ANSWERING_MAPPING, + MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING, + MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING, + MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING, + MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, + MODEL_FOR_MULTIPLE_CHOICE_MAPPING, + MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, + MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, + MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING, + MODEL_FOR_IMAGE_SEGMENTATION_MAPPING, + MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING, + MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING, + MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING, + MODEL_FOR_OBJECT_DETECTION_MAPPING, + MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING, + MODEL_FOR_DEPTH_ESTIMATION_MAPPING, + MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING, + MODEL_FOR_VISION_2_SEQ_MAPPING, + MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING, + MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, + MODEL_FOR_CTC_MAPPING, + MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, + MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING, + MODEL_FOR_AUDIO_XVECTOR_MAPPING, + MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING, + MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING, + MODEL_FOR_BACKBONE_MAPPING, + MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING, + MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING, + MODEL_FOR_IMAGE_MAPPING, + MODEL_FOR_RETRIEVAL_MAPPING, + MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING, + MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING, +) +from .auto_factory import ( + _BaseAutoBackboneClass, + _BaseAutoModelClass, + auto_class_update, +) + +class AutoModelForMaskGeneration(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_MASK_GENERATION_MAPPING + + +class AutoModelForKeypointDetection(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_KEYPOINT_DETECTION_MAPPING + + +class AutoModelForTextEncoding(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_TEXT_ENCODING_MAPPING + + +class AutoModelForImageToImage(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_IMAGE_TO_IMAGE_MAPPING + + +class AutoModel(_BaseAutoModelClass): + _model_mapping = MODEL_MAPPING + + +AutoModel = auto_class_update(AutoModel) + + +class AutoModelForPreTraining(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_PRETRAINING_MAPPING + + +AutoModelForPreTraining = auto_class_update(AutoModelForPreTraining, head_doc="pretraining") + + +# Private on purpose, the public class will add the deprecation warnings. +class _AutoModelWithLMHead(_BaseAutoModelClass): + _model_mapping = MODEL_WITH_LM_HEAD_MAPPING + + +_AutoModelWithLMHead = auto_class_update(_AutoModelWithLMHead, head_doc="language modeling") + + +class AutoModelForCausalLM(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING + + +AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling") + + +class AutoModelForMaskedLM(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_MASKED_LM_MAPPING + + +AutoModelForMaskedLM = auto_class_update(AutoModelForMaskedLM, head_doc="masked language modeling") + + +class AutoModelForSeq2SeqLM(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING + + +AutoModelForSeq2SeqLM = auto_class_update( + AutoModelForSeq2SeqLM, + head_doc="sequence-to-sequence language modeling", + checkpoint_for_example="google-t5/t5-base", +) + + +class AutoModelForSequenceClassification(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING + + +AutoModelForSequenceClassification = auto_class_update( + AutoModelForSequenceClassification, head_doc="sequence classification" +) + + +class AutoModelForQuestionAnswering(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING + + +AutoModelForQuestionAnswering = auto_class_update(AutoModelForQuestionAnswering, head_doc="question answering") + + +class AutoModelForTableQuestionAnswering(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING + + +AutoModelForTableQuestionAnswering = auto_class_update( + AutoModelForTableQuestionAnswering, + head_doc="table question answering", + checkpoint_for_example="google/tapas-base-finetuned-wtq", +) + + +class AutoModelForVisualQuestionAnswering(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING + + +AutoModelForVisualQuestionAnswering = auto_class_update( + AutoModelForVisualQuestionAnswering, + head_doc="visual question answering", + checkpoint_for_example="dandelin/vilt-b32-finetuned-vqa", +) + + +class AutoModelForDocumentQuestionAnswering(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING + + +AutoModelForDocumentQuestionAnswering = auto_class_update( + AutoModelForDocumentQuestionAnswering, + head_doc="document question answering", + checkpoint_for_example='impira/layoutlm-document-qa", revision="52e01b3', +) + + +class AutoModelForTokenClassification(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING + + +AutoModelForTokenClassification = auto_class_update(AutoModelForTokenClassification, head_doc="token classification") + + +class AutoModelForMultipleChoice(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_MULTIPLE_CHOICE_MAPPING + + +AutoModelForMultipleChoice = auto_class_update(AutoModelForMultipleChoice, head_doc="multiple choice") + + +class AutoModelForNextSentencePrediction(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING + + +AutoModelForNextSentencePrediction = auto_class_update( + AutoModelForNextSentencePrediction, head_doc="next sentence prediction" +) + + +class AutoModelForImageClassification(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING + + +AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification") + + +class AutoModelForZeroShotImageClassification(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING + + +AutoModelForZeroShotImageClassification = auto_class_update( + AutoModelForZeroShotImageClassification, head_doc="zero-shot image classification" +) + + +class AutoModelForImageSegmentation(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING + + +AutoModelForImageSegmentation = auto_class_update(AutoModelForImageSegmentation, head_doc="image segmentation") + + +class AutoModelForSemanticSegmentation(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING + + +AutoModelForSemanticSegmentation = auto_class_update( + AutoModelForSemanticSegmentation, head_doc="semantic segmentation" +) + + +class AutoModelForUniversalSegmentation(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING + + +AutoModelForUniversalSegmentation = auto_class_update( + AutoModelForUniversalSegmentation, head_doc="universal image segmentation" +) + + +class AutoModelForInstanceSegmentation(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING + + +AutoModelForInstanceSegmentation = auto_class_update( + AutoModelForInstanceSegmentation, head_doc="instance segmentation" +) + + +class AutoModelForObjectDetection(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING + + +AutoModelForObjectDetection = auto_class_update(AutoModelForObjectDetection, head_doc="object detection") + + +class AutoModelForZeroShotObjectDetection(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING + + +AutoModelForZeroShotObjectDetection = auto_class_update( + AutoModelForZeroShotObjectDetection, head_doc="zero-shot object detection" +) + + +class AutoModelForDepthEstimation(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING + + +AutoModelForDepthEstimation = auto_class_update(AutoModelForDepthEstimation, head_doc="depth estimation") + + +class AutoModelForVideoClassification(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING + + +AutoModelForVideoClassification = auto_class_update(AutoModelForVideoClassification, head_doc="video classification") + + +class AutoModelForVision2Seq(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING + + +AutoModelForVision2Seq = auto_class_update(AutoModelForVision2Seq, head_doc="vision-to-text modeling") + + +class AutoModelForImageTextToText(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING + + +AutoModelForImageTextToText = auto_class_update(AutoModelForImageTextToText, head_doc="image-text-to-text modeling") + + +class AutoModelForAudioClassification(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING + + +AutoModelForAudioClassification = auto_class_update(AutoModelForAudioClassification, head_doc="audio classification") + + +class AutoModelForCTC(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_CTC_MAPPING + + +AutoModelForCTC = auto_class_update(AutoModelForCTC, head_doc="connectionist temporal classification") + + +class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING + + +AutoModelForSpeechSeq2Seq = auto_class_update( + AutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling" +) + + +class AutoModelForAudioFrameClassification(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING + + +AutoModelForAudioFrameClassification = auto_class_update( + AutoModelForAudioFrameClassification, head_doc="audio frame (token) classification" +) + + +class AutoModelForAudioXVector(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_AUDIO_XVECTOR_MAPPING + + +class AutoModelForTextToSpectrogram(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING + + +class AutoModelForTextToWaveform(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING + + +class AutoBackbone(_BaseAutoBackboneClass): + _model_mapping = MODEL_FOR_BACKBONE_MAPPING + + +AutoModelForAudioXVector = auto_class_update(AutoModelForAudioXVector, head_doc="audio retrieval via x-vector") + + +class AutoModelForMaskedImageModeling(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING + + +AutoModelForMaskedImageModeling = auto_class_update(AutoModelForMaskedImageModeling, head_doc="masked image modeling") + + +class AutoModelWithLMHead(_AutoModelWithLMHead): + @classmethod + def from_config(cls, config): + warnings.warn( + "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use " + "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and " + "`AutoModelForSeq2SeqLM` for encoder-decoder models.", + FutureWarning, + ) + return super().from_config(config) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + warnings.warn( + "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use " + "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and " + "`AutoModelForSeq2SeqLM` for encoder-decoder models.", + FutureWarning, + ) + return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + + +__all__ = [ + "MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING", + "MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING", + "MODEL_FOR_AUDIO_XVECTOR_MAPPING", + "MODEL_FOR_BACKBONE_MAPPING", + "MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING", + "MODEL_FOR_CAUSAL_LM_MAPPING", + "MODEL_FOR_CTC_MAPPING", + "MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING", + "MODEL_FOR_DEPTH_ESTIMATION_MAPPING", + "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING", + "MODEL_FOR_IMAGE_MAPPING", + "MODEL_FOR_IMAGE_SEGMENTATION_MAPPING", + "MODEL_FOR_IMAGE_TO_IMAGE_MAPPING", + "MODEL_FOR_KEYPOINT_DETECTION_MAPPING", + "MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING", + "MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING", + "MODEL_FOR_MASKED_LM_MAPPING", + "MODEL_FOR_MASK_GENERATION_MAPPING", + "MODEL_FOR_MULTIPLE_CHOICE_MAPPING", + "MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", + "MODEL_FOR_OBJECT_DETECTION_MAPPING", + "MODEL_FOR_PRETRAINING_MAPPING", + "MODEL_FOR_QUESTION_ANSWERING_MAPPING", + "MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING", + "MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING", + "MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", + "MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING", + "MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING", + "MODEL_FOR_TEXT_ENCODING_MAPPING", + "MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING", + "MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING", + "MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", + "MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING", + "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING", + "MODEL_FOR_VISION_2_SEQ_MAPPING", + "MODEL_FOR_RETRIEVAL_MAPPING", + "MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING", + "MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING", + "MODEL_MAPPING", + "MODEL_WITH_LM_HEAD_MAPPING", + "MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING", + "MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING", + "MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING", + "MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING", + "AutoModel", + "AutoBackbone", + "AutoModelForAudioClassification", + "AutoModelForAudioFrameClassification", + "AutoModelForAudioXVector", + "AutoModelForCausalLM", + "AutoModelForCTC", + "AutoModelForDepthEstimation", + "AutoModelForImageClassification", + "AutoModelForImageSegmentation", + "AutoModelForImageToImage", + "AutoModelForInstanceSegmentation", + "AutoModelForKeypointDetection", + "AutoModelForMaskGeneration", + "AutoModelForTextEncoding", + "AutoModelForMaskedImageModeling", + "AutoModelForMaskedLM", + "AutoModelForMultipleChoice", + "AutoModelForNextSentencePrediction", + "AutoModelForObjectDetection", + "AutoModelForPreTraining", + "AutoModelForQuestionAnswering", + "AutoModelForSemanticSegmentation", + "AutoModelForSeq2SeqLM", + "AutoModelForSequenceClassification", + "AutoModelForSpeechSeq2Seq", + "AutoModelForTableQuestionAnswering", + "AutoModelForTextToSpectrogram", + "AutoModelForTextToWaveform", + "AutoModelForTokenClassification", + "AutoModelForUniversalSegmentation", + "AutoModelForVideoClassification", + "AutoModelForVision2Seq", + "AutoModelForVisualQuestionAnswering", + "AutoModelForDocumentQuestionAnswering", + "AutoModelWithLMHead", + "AutoModelForZeroShotImageClassification", + "AutoModelForZeroShotObjectDetection", + "AutoModelForImageTextToText", +] \ No newline at end of file diff --git a/mindnlp/transformers/pipelines.py b/mindnlp/transformers/pipelines.py new file mode 100644 index 000000000..e69de29bb diff --git a/mindnlp/utils/__init__.py b/mindnlp/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/mindnlp/utils/safetensors_patch.py b/mindnlp/utils/safetensors_patch.py new file mode 100644 index 000000000..740b7ba5c --- /dev/null +++ b/mindnlp/utils/safetensors_patch.py @@ -0,0 +1,206 @@ +import json +import mmap +from typing import OrderedDict +import numpy as np +import mindspore +from mindspore import Tensor + +from mindnlp.core.configs import SUPPORT_BF16 + +if SUPPORT_BF16: + from mindspore.common.np_dtype import bfloat16 # pylint: disable=import-error +else: + from ml_dtypes import bfloat16 + +MAGIC_NUMBER = 0x1950A86A20F9469CFC6C +PROTOCOL_VERSION = 1001 +MAX_HEADER_SIZE = 100 * 1000 * 1000 + +_MS_TYPES = { + "F64": mindspore.float64, + "F32": mindspore.float32, + "F16": mindspore.float16, + "BF16": mindspore.bfloat16, + "I64": mindspore.int64, + "U64": mindspore.uint64, + "I32": mindspore.int32, + "U32": mindspore.uint32, + "I16": mindspore.int16, + "U16": mindspore.uint16, + "I8": mindspore.int8, + "U8": mindspore.uint8, + "BOOL": mindspore.bool_, +} + +_NP_TYPES = { + "F64": np.float64, + "F32": np.float32, + "F16": np.float16, + "BF16": bfloat16, + "I64": np.int64, + "U64": np.uint64, + "I32": np.int32, + "U32": np.uint32, + "I16": np.int16, + "U16": np.uint16, + "I8": np.int8, + "U8": np.uint8, + "BOOL": bool, +} + + +_DTYPE_SIZE = { + "BOOL": 1, + "U8": 1, + "I8": 1, + "F8_E5M2": 1, + "F8_E4M3": 1, + "I16": 2, + "U16": 2, + "I32": 4, + "U32": 4, + "I64": 8, + "U64": 8, + "F16": 2, + "BF16": 2, + "F32": 4, + "F64": 8, +} + +class PySafeSlice: + def __init__(self, info, bufferfile, base_ptr, buffermmap): + self.info = info + self.bufferfile = bufferfile + self.buffermmap = buffermmap + self.base_ptr = base_ptr + + self.start = [0 for _ in self.shape] + self.stop = list(self.shape) + self.step = [1 for _ in self.shape] + + @property + def ndim(self): + return len(self.shape) + + def get(self, *args, **kwargs): + nbytes = int(np.prod(self.shape)) * np.dtype(self.dtype).itemsize + offset = self.start_offset + tensor = np.frombuffer(self.buffermmap, dtype=self.dtype, offset=offset, + count=nbytes // np.dtype(self.dtype).itemsize) + tensor = tensor.reshape(self.shape) + if not SUPPORT_BF16 and self.info["dtype"] == 'BF16': + tensor = tensor.astype(np.float16) + tensor = Tensor.from_numpy(tensor) + return tensor + + @property + def start_offset(self): + return self.base_ptr + self.info["data_offsets"][0] + + def get_shape(self): + return self.shape + + def get_dtype(self): + return self.info["dtype"] + + @property + def shape(self): + return self.info["shape"] + + @property + def dtype(self): + return _NP_TYPES[self.info["dtype"]] + + @property + def nelements(self): + return np.prod(self.info["shape"]) + + @property + def bits(self): + return _DTYPE_SIZE[self.info["dtype"]] + + @property + def nbytes(self): + return self.nelements * self.bits + + def __getitem__(self, slice): + return self.get()[slice] + +def getSize(fileobject): + fileobject.seek(0, 2) # move the cursor to the end of the file + size = fileobject.tell() + fileobject.seek(0) # move the cursor to the start of the file + return size + + +def metadata_validate(metadata): + start = 0 + for key, info in metadata.items(): + s, e = info["data_offsets"] + if s != start or e < s: + raise ValueError(f"SafeTensorError::InvalidOffset({key})") + start = e + nelements = np.prod(info["shape"]) + nbytes = nelements * _DTYPE_SIZE[info["dtype"]] + if (e - s) != nbytes: + raise ValueError("SafeTensorError::TensorInvalidInfo") + return start + +def read_metadata(buffer): + buffer_len = getSize(buffer) + if buffer_len < 8: + raise ValueError("SafeTensorError::HeaderTooSmall") + + n = np.frombuffer(buffer.read(8), dtype=np.uint64).item() + + if n > MAX_HEADER_SIZE: + raise ValueError("SafeTensorError::HeaderTooLarge") + + stop = n + 8 + if stop > buffer_len: + raise ValueError("SafeTensorError::InvalidHeaderLength") + + tensors = json.loads(buffer.read(n), object_pairs_hook=OrderedDict) + + metadata = tensors.pop("__metadata__", None) + buffer_end = metadata_validate(tensors) + + if buffer_end + 8 + n != buffer_len: + raise ValueError("SafeTensorError::MetadataIncompleteBuffer") + + return stop, tensors, metadata + + +class fast_safe_open: + def __init__(self, filename, framework=None, device="cpu"): + self.filename = filename + self.framework = framework + self.file = open(self.filename, "rb") + self.file_mmap = mmap.mmap(self.file.fileno(), 0, access=mmap.ACCESS_COPY) + self.base, self.tensors_decs, self.__metadata__ = read_metadata(self.file) + self.tensors = OrderedDict() + for key, info in self.tensors_decs.items(): + self.tensors[key] = PySafeSlice(info, self.file, self.base, self.file_mmap) + self.tensors[key].key = key + + def __enter__(self): + return self + + def __exit__(self, *args): + self.file.close() + + def metadata(self): + return self.__metadata__ + + def keys(self): + return list(self.tensors.keys()) + + def get_tensor(self, name): + return self.tensors[name].get() + + def get_slice(self, name): + return self.tensors[name] + +def setup_safetensors_patch(): + import safetensors + safetensors.safe_open = fast_safe_open diff --git a/mindnlp/utils/torch_proxy.py b/mindnlp/utils/torch_proxy.py new file mode 100644 index 000000000..22e0c410a --- /dev/null +++ b/mindnlp/utils/torch_proxy.py @@ -0,0 +1,94 @@ +import sys +import types +import importlib +import importlib.metadata + +class TorchProxyModule(types.ModuleType): + def __init__(self): + super().__init__("torch") + # 保存真实模块的引用 + self._real_module = None + + def _load_real_module(self): + """按需加载真实模块""" + if self._real_module is None: + # 尝试直接导入mindnlp.core作为torch + self._real_module = importlib.import_module("mindnlp.core") + # 添加必要的元数据属性 + self._real_module.__name__ = "torch" + self._real_module.__package__ = "torch" + self._real_module.__file__ = "" + + return self._real_module + + def __getattr__(self, name): + """任何属性访问都重定向到真实模块""" + # 处理特殊元数据属性 + if name in {"__name__", "__package__", "__file__"}: + return getattr(self._load_real_module(), name) + + return getattr(self._load_real_module(), name) + + def __setattr__(self, name, value): + """属性设置也重定向到真实模块""" + # 跳过自身内部属性 + if name in {"_real_module", "__name__", "__package__", "__file__"}: + super().__setattr__(name, value) + else: + setattr(self._load_real_module(), name, value) + + def __dir__(self): + """返回真实模块的属性列表""" + return dir(self._load_real_module()) + + def __getattribute__(self, name): + """特殊处理元数据相关属性""" + if name == '__file__': + return '' + if name == '__package__': + return 'torch' + if name == '__spec__': + return self._create_mock_spec() + return super().__getattribute__(name) + +def initialize_torch_proxy(): + + torch_proxy = TorchProxyModule() + sys.modules["torch"] = torch_proxy + + # 设置必要的元数据 + torch_proxy.__version__ = "1.13.1+mindnlp" + + return torch_proxy + +def setup_metadata_patch(): + """解决 importlib.metadata 找不到 torch 的问题""" + # 保存原始函数 + orig_distribution = importlib.metadata.distribution + orig_distributions = importlib.metadata.distributions + + # 拦截对 torch 分发的查询 + def patched_distribution(dist_name): + if dist_name == "torch": + return types.SimpleNamespace( + version="1.13.1+mindnlp", + metadata={"Name": "torch", "Version": "1.13.1+mindnlp"}, + read_text=lambda f: f"Name: torch\nVersion: 1.13.1+mindnlp" if f == "METADATA" else None + ) + return orig_distribution(dist_name) + + # 确保分发列表中有 torch + def patched_distributions(**kwargs): + dists = list(orig_distributions(**kwargs)) + dists.append(types.SimpleNamespace( + name="torch", + version="1.13.1+mindnlp", + metadata={"Name": "torch", "Version": "1.13.1+mindnlp"}, + files=[], + locate_file=lambda p: None + )) + return dists + + # 应用补丁 + importlib.metadata.distribution = patched_distribution + importlib.metadata.distributions = patched_distributions diff --git a/setup.py b/setup.py index 04a5544d7..7bbfec1f4 100644 --- a/setup.py +++ b/setup.py @@ -15,14 +15,53 @@ setup packpage """ import os +import sys import stat import shlex import shutil import subprocess +import sysconfig from setuptools import find_packages from setuptools import setup from setuptools.command.egg_info import egg_info from setuptools.command.build_py import build_py +from setuptools.command.install import install + +def _create_namespace_links(): + # 获取目标路径 (site-packages/mindnlp/transformers) + install_lib = sysconfig.get_path("purelib") # 兼容虚拟环境 + target_dir = os.path.join(install_lib, "mindnlp", "transformers") + + print('target_dir', target_dir) + # 获取源路径 (site-packages/transformers) + try: + import transformers + source_path = os.path.dirname(transformers.__file__) + except ImportError: + # 如果 transformers 未安装则自动安装 + subprocess.check_call([sys.executable, "-m", "pip", "install", "transformers"]) + import transformers + source_path = os.path.dirname(transformers.__file__) + + # 清理旧链接 + if os.path.exists(target_dir): + if os.path.islink(target_dir) or sys.platform == "win32": + os.remove(target_dir) + else: + shutil.rmtree(target_dir) + + # 创建符号链接 + if sys.platform == "win32": + # Windows 需管理员权限或开发者模式 + subprocess.check_call(f'mklink /J "{target_dir}" "{source_path}"', shell=True) + else: + os.symlink(source_path, target_dir, target_is_directory=True) + +class CustomInstall(install): + def run(self): + super().run() + if "install" in sys.argv: + _create_namespace_links() # 安装后创建链接 version = '0.5.0' @@ -114,6 +153,7 @@ def run(self): cmdclass={ 'egg_info': EggInfo, 'build_py': BuildPy, + "install": CustomInstall, }, install_requires=[ 'mindspore>=2.5.0', @@ -131,7 +171,6 @@ def run(self): 'pyctcdecode', 'pytest==7.2.0', 'pillow>=10.0.0', - 'mindtorch@git+https://openi.pcl.ac.cn/lvyufeng/mindtorch.git' ], classifiers=[ 'License :: OSI Approved :: Apache Software License' diff --git a/tests/core/autograd/__init__.py b/tests/core/autograd/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/core/autograd/test_add.py b/tests/core/autograd/test_add.py new file mode 100644 index 000000000..d8dd9c70b --- /dev/null +++ b/tests/core/autograd/test_add.py @@ -0,0 +1,100 @@ +from unittest import TestCase +import mindnlp +from mindnlp.core import tensor + + +class TestAdd(TestCase): + + def test_simple_add(self): + # scalar add + t1 = tensor(1.0) + t2 = tensor(2.0) + t3 = t1 + t2 + self.assertEqual(t3.data.tolist(), 3.0) + + t1 = tensor(2.0, requires_grad=True) + t2 = tensor(3.0) + t3 = t1 + t2 + print(t1.grad_fn, t2.grad_fn) + t3.backward() + print(t3.grad_fn) + self.assertEqual(t1.grad.data.tolist(), 1.0) + + t1 = tensor(2.0) + t2 = tensor(3.0, requires_grad=True) + t3 = t1 + t2 + t3.backward() + self.assertEqual(t2.grad.data.tolist(), 1.0) + + t1 = tensor(2.0, requires_grad=True) + t2 = tensor(3.0, requires_grad=True) + t3 = t1 + t2 + t3.backward() + self.assertEqual(t1.grad.data.tolist(), 1.0) + self.assertEqual(t2.grad.data.tolist(), 1.0) + + # vector add + t1 = tensor([1.0, 2.0]) + t2 = tensor([2.0, 3.0]) + t3 = t1 + t2 + self.assertEqual(t3.data.tolist(), [3.0, 5.0]) + + t1 = tensor([1.0, 2.0], requires_grad=True) + t2 = tensor([2.0, 3.0]) + t3 = t1 + t2 + t3.backward(tensor([1.0, 1.0])) + self.assertEqual(t1.grad.data.tolist(), [1.0, 1.0]) + + t1 = tensor([1.0, 2.0]) + t2 = tensor([2.0, 3.0], requires_grad=True) + t3 = t1 + t2 + t3.backward(tensor([1.0, 1.0])) + self.assertEqual(t2.grad.data.tolist(), [1.0, 1.0]) + + t1 = tensor([1.0, 2.0], requires_grad=True) + t2 = tensor([2.0, 3.0], requires_grad=True) + t3 = t1 + t2 + t3.backward(tensor([1.0, 1.0])) + self.assertEqual(t1.grad.data.tolist(), [1.0, 1.0]) + self.assertEqual(t2.grad.data.tolist(), [1.0, 1.0]) + + def test_broadcast_add(self): + # (2,) + () + t1 = tensor([1.0, 2.0], requires_grad=True) + t2 = tensor(2.0, requires_grad=True) + t3 = t1 + t2 + t3.backward(tensor([1.0, 1.0])) + self.assertEqual(t1.grad.data.tolist(), [1.0, 1.0]) + self.assertEqual(t2.grad.data.tolist(), 2.0) + + # (2,) + (1,) + t1 = tensor([1.0, 2.0], requires_grad=True) + t2 = tensor([2.0], requires_grad=True) + t3 = t1 + t2 + t3.backward(tensor([1.0, 1.0])) + self.assertEqual(t1.grad.data.tolist(), [1.0, 1.0]) + self.assertEqual(t2.grad.data.tolist(), [2.0]) + + # (2, 2) + () + t1 = tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True) + t2 = tensor(2.0, requires_grad=True) + t3 = t1 + t2 + t3.backward(tensor([[1.0, 1.0], [1.0, 1.0]])) + self.assertEqual(t1.grad.data.tolist(), [[1.0, 1.0], [1.0, 1.0]]) + self.assertEqual(t2.grad.data.tolist(), 4.0) + + # (2, 2) + (1,) + t1 = tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True) + t2 = tensor([2.0], requires_grad=True) + t3 = t1 + t2 + t3.backward(tensor([[1.0, 1.0], [1.0, 1.0]])) + self.assertEqual(t1.grad.data.tolist(), [[1.0, 1.0], [1.0, 1.0]]) + self.assertEqual(t2.grad.data.tolist(), [4.0]) + + # (2, 2) + (2, ) + t1 = tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True) + t2 = tensor([2.0, 3.0], requires_grad=True) + t3 = t1 + t2 + t3.backward(tensor([[1.0, 1.0], [1.0, 1.0]])) + self.assertEqual(t1.grad.data.tolist(), [[1.0, 1.0], [1.0, 1.0]]) + self.assertEqual(t2.grad.data.tolist(), [2.0, 2.0]) \ No newline at end of file diff --git a/tests/core/autograd/test_div.py b/tests/core/autograd/test_div.py new file mode 100644 index 000000000..02c4e7de4 --- /dev/null +++ b/tests/core/autograd/test_div.py @@ -0,0 +1,98 @@ +from unittest import TestCase +import mindnlp +from mindnlp.core import tensor + + +class TestDiv(TestCase): + + def test_simple_div(self): + # scalar div + t1 = tensor(1.0) + t2 = tensor(2.0) + t3 = t1 / t2 + self.assertEqual(t3.tolist(), 0.5) + + t1 = tensor(1.0, requires_grad=True) + t2 = tensor(2.0) + t3 = t1 / t2 + t3.backward() + self.assertEqual(t1.grad.tolist(), 0.5) + + t1 = tensor(1.0) + t2 = tensor(2.0, requires_grad=True) + t3 = t1 / t2 + t3.backward() + self.assertEqual(t2.grad.tolist(), -0.25) + + t1 = tensor(1.0, requires_grad=True) + t2 = tensor(2.0, requires_grad=True) + t3 = t1 / t2 + t3.backward() + self.assertEqual(t1.grad.tolist(), 0.5) + self.assertEqual(t2.grad.tolist(), -0.25) + + # vector div + t1 = tensor([1.0, 2.0]) + t2 = tensor([2.0, 4.0]) + t3 = t1 / t2 + self.assertEqual(t3.tolist(), [0.5, 0.5]) + + t1 = tensor([1.0, 2.0], requires_grad=True) + t2 = tensor([2.0, 4.0]) + t3 = t1 / t2 + t3.backward(tensor([1.0, 1.0])) + self.assertEqual(t1.grad.tolist(), [0.5, 0.25]) + + t1 = tensor([1.0, 2.0]) + t2 = tensor([2.0, 4.0], requires_grad=True) + t3 = t1 / t2 + t3.backward(tensor([1.0, 1.0])) + self.assertEqual(t2.grad.tolist(), [-0.25, -1/8]) + + t1 = tensor([1.0, 2.0], requires_grad=True) + t2 = tensor([2.0, 4.0], requires_grad=True) + t3 = t1 / t2 + t3.backward(tensor([1.0, 1.0])) + self.assertEqual(t1.grad.tolist(), [0.5, 0.25]) + self.assertEqual(t2.grad.tolist(), [-0.25, -1/8]) + + def test_broadcast_div(self): + # (2,) / () + t1 = tensor([1.0, 2.0], requires_grad=True) + t2 = tensor(2.0, requires_grad=True) + t3 = t1 / t2 + t3.backward(tensor([1.0, 1.0])) + self.assertEqual(t1.grad.tolist(), [0.5, 0.5]) + self.assertEqual(t2.grad.tolist(), -0.75) + + # (2,) / (1,) + t1 = tensor([1.0, 2.0], requires_grad=True) + t2 = tensor([2.0], requires_grad=True) + t3 = t1 / t2 + t3.backward(tensor([1.0, 1.0])) + self.assertEqual(t1.grad.tolist(), [0.5, 0.5]) + self.assertEqual(t2.grad.tolist(), [-0.75]) + + # (2, 2) / () + t1 = tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True) + t2 = tensor(2.0, requires_grad=True) + t3 = t1 / t2 + t3.backward(tensor([[1.0, 1.0], [1.0, 1.0]])) + self.assertEqual(t1.grad.tolist(), [[0.5, 0.5], [0.5, 0.5]]) + self.assertEqual(t2.grad.tolist(), -2.5) + + # (2, 2) / (1,) + t1 = tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True) + t2 = tensor([2.0], requires_grad=True) + t3 = t1 / t2 + t3.backward(tensor([[1.0, 1.0], [1.0, 1.0]])) + self.assertEqual(t1.grad.tolist(), [[0.5, 0.5], [0.5, 0.5]]) + self.assertEqual(t2.grad.tolist(), [-2.5]) + + # (2, 2) / (2, ) + t1 = tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True) + t2 = tensor([2.0, 4.0], requires_grad=True) + t3 = t1 / t2 + t3.backward(tensor([[1.0, 1.0], [1.0, 1.0]])) + self.assertEqual(t1.grad.tolist(), [[0.5, 0.25], [0.5, 0.25]]) + self.assertEqual(t2.grad.tolist(), [-1.0, -0.375]) \ No newline at end of file diff --git a/tests/core/autograd/test_exp.py b/tests/core/autograd/test_exp.py new file mode 100644 index 000000000..625c84047 --- /dev/null +++ b/tests/core/autograd/test_exp.py @@ -0,0 +1,29 @@ +from unittest import TestCase + +import numpy as np +import mindnlp +from mindnlp.core import tensor + + +class TestExp(TestCase): + + def test_exp(self): + # scalar exp + t1 = tensor(2.0) + t2 = t1.exp() + np.testing.assert_allclose(t2.array, np.exp(2)) + + t1 = tensor(2.0, requires_grad=True) + t2 = t1.exp() + t2.backward() + np.testing.assert_allclose(t1.grad.array, np.exp(2)) + + # vector exp + t1 = tensor([1.0, 2.0]) + t2 = t1.exp() + np.testing.assert_allclose(t2.array, np.exp([1, 2])) + + t1 = tensor([1.0, 2.0], requires_grad=True) + t2 = t1.exp() + t2.backward(tensor([1.0, 1.0])) + np.testing.assert_allclose(t1.grad.array, np.exp([1, 2])) \ No newline at end of file diff --git a/tests/core/autograd/test_function.py b/tests/core/autograd/test_function.py new file mode 100644 index 000000000..d2607d7bb --- /dev/null +++ b/tests/core/autograd/test_function.py @@ -0,0 +1,27 @@ +import mindnlp +from mindnlp import core as torch +from mindnlp.core.autograd import Function + +class Test(Function): + + @staticmethod + def forward(ctx, x, y): + ctx.save_for_backward(x, y) + return x + y + 1 + + @staticmethod + def backward(ctx, grad): + x, y = ctx.saved_tensors + print(x, y) + return torch.ones_like(x), torch.zeros_like(y) + +def fn_test(x, y): + return Test.apply(x, y) + +def test_function_no_record_forward_inputs(): + x = torch.randn(3, 3, requires_grad=True) + y = torch.randn(3, requires_grad=True) + out = fn_test(x, y) + out.backward() + print(x.requires_grad) + print(y.requires_grad) diff --git a/tests/core/autograd/test_split.py b/tests/core/autograd/test_split.py new file mode 100644 index 000000000..b2c1c722f --- /dev/null +++ b/tests/core/autograd/test_split.py @@ -0,0 +1,22 @@ +from unittest import TestCase +import mindnlp +from mindnlp import core as torch +from mindnlp.core import tensor + + +class TestSplit(TestCase): + def test_simple_split(self): + x = torch.randn(3, 2) + y1, y2 = x.tensor_split(2, -1) + assert y1.shape == (3, 1) + assert y2.shape == (3, 1) + + def test_split_backward(self): + # scalar add + x = torch.randn(3, 2, requires_grad=True) + y1, y2 = x.tensor_split(2, -1) + assert y1.shape == (3, 1) + assert y2.shape == (3, 1) + z = y1 + y2 + z.sum().backward() + print(x.grad) diff --git a/tests/core/test_autograd.py b/tests/core/test_autograd.py new file mode 100644 index 000000000..e4c7f79ff --- /dev/null +++ b/tests/core/test_autograd.py @@ -0,0 +1,75 @@ +import mindnlp +import torch + +def test_model(): + """ + Test the model implemented using TinyTorch against a corresponding model implemented using PyTorch. + + This test compares the loss values and gradients obtained from both models and asserts their closeness. + + Raises: + AssertionError: If the loss values or gradients differ beyond the specified tolerance. + """ + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.l1 = torch.nn.Linear(3, 4, name='l1') + self.l2 = torch.nn.Linear(4, 1, name='l2') + + def forward(self, x): + z = self.l1(x).relu() + out = self.l2(z).sigmoid() + return out + + X = torch.tensor([[0., 1., 2.], [10., 20., 30.]]) + y = torch.tensor([[1.], [0.]]) + + model = Model() + + class ModelT(torch.nn.Module): + def __init__(self): + super().__init__() + self.l1 = torch.nn.Linear(3, 4) + self.l2 = torch.nn.Linear(4, 1) + + def forward(self, x): + z = self.l1(x).relu() + out = self.l2(z).sigmoid() + return out + + modelT = ModelT().double() + + XT = torch.tensor([[0., 1., 2.], [10., 20., 30.]]).double() + yT = torch.tensor([[1.], [0.]]).double() + + with torch.no_grad(): + modelT.l1.weight = torch.nn.Parameter(torch.tensor(model.l1.weight.data, dtype=torch.float64)) + modelT.l1.bias = torch.nn.Parameter(torch.tensor(model.l1.bias.data, dtype=torch.float64)) + + modelT.l2.weight = torch.nn.Parameter(torch.tensor(model.l2.weight.data, dtype=torch.float64)) + modelT.l2.bias = torch.nn.Parameter(torch.tensor(model.l2.bias.data, dtype=torch.float64)) + + optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, weight_decay=1e-3) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.5, total_iters=100) + + optimizerT = torch.optim.SGD(modelT.parameters(), lr=1e-2, weight_decay=1e-3) + schedulerT = torch.optim.lr_scheduler.LinearLR(optimizerT, start_factor=1.0, end_factor=0.5, total_iters=100) + + loss_fn = torch.nn.BCELoss() + + tol = 1e-6 + + for i in range(1000): + y_hat = model(X) + y_hatT = modelT(XT) + loss = torch.nn.functional.binary_cross_entropy(y_hat, y) + lossT = torch.nn.functional.binary_cross_entropy(y_hatT, yT) + assert abs(loss.data - lossT.data.item()) < tol + loss.backward() + lossT.backward() + optimizer.step() + optimizerT.step() + scheduler.step() + schedulerT.step() + optimizer.zero_grad() + optimizerT.zero_grad() \ No newline at end of file