[sync] Sync feature/colossal-infer with main #5737

Merged

Changes from all commits (21 commits)
9efc79e
add parallel output for mistral model
wangbluo Apr 30, 2024
2632916
remove useless code
wangbluo May 1, 2024
c25f83c
fix missing pad token (#5690)
Edenzzzz May 6, 2024
77ec773
[zero]remove registered gradients hooks (#5687)
flybird11111 May 7, 2024
58954b2
[misc] Add an existing issue checkbox in bug report (#5691)
Edenzzzz May 7, 2024
88f057c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 7, 2024
108ddfb
add parallel_output for the opt model
wangbluo May 3, 2024
ca56b93
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 7, 2024
a8408b4
remove comment code
wangbluo May 7, 2024
4e50cce
fix the mistral model
wangbluo May 7, 2024
2229778
Merge pull request #5684 from wangbluo/parallel_output
wangbluo May 8, 2024
d4c5ef4
[gemini]remove registered gradients hooks (#5696)
flybird11111 May 9, 2024
a3cc68c
[Shardformer] Support the Qwen2 model (#5699)
wangbluo May 9, 2024
537f6a3
[Shardformer]fix the num_heads assert for llama model and qwen model …
wangbluo May 10, 2024
785cd9a
[misc] Update PyTorch version in docs (#5711)
Edenzzzz May 13, 2024
393c8f5
[hotfix] fix inference typo (#5438)
hugo-syn May 13, 2024
43995ee
[Feature] Distributed optimizers: Lamb, Galore, CAME and Adafactor (#…
Edenzzzz May 14, 2024
913c920
[Colossal-LLaMA] Fix sft issue for llama2 (#5719)
TongLi3701 May 15, 2024
2011b13
[misc] Update PyTorch version in docs (#5724)
binmakeswell May 16, 2024
9d83c6d
[lazy] fix lazy cls init (#5720)
flybird11111 May 17, 2024
8633c15
[sync] Sync feature/colossal-infer with main
yuanheng-zhao May 20, 2024
7 changes: 7 additions & 0 deletions .github/ISSUE_TEMPLATE/bug-report.yml
@@ -8,6 +8,13 @@ body:
attributes:
value: >
#### Not suitable for your needs? [Open a blank issue](https://github.com/hpcaitech/ColossalAI/issues/new).
- type: checkboxes
attributes:
label: Is there an existing issue for this bug?
description: Please search [here](https://github.com/hpcaitech/ColossalAI/issues) to see if an open or closed issue already exists for the bug you have encountered.
options:
- label: I have searched the existing issues
required: true
- type: textarea
attributes:
label: 🐛 Describe the bug
2 changes: 1 addition & 1 deletion .github/workflows/build_on_pr.yml
@@ -140,7 +140,7 @@ jobs:

- name: Install Colossal-AI
run: |
BUILD_EXT=1 pip install -v -e .
pip install -v -e .
pip install -r requirements/requirements-test.txt

- name: Store Colossal-AI Cache
2 changes: 1 addition & 1 deletion README.md
@@ -418,7 +418,7 @@ Please visit our [documentation](https://www.colossalai.org/) and [examples](htt
## Installation

Requirements:
- PyTorch >= 1.11 and PyTorch <= 2.1
- PyTorch >= 2.1
- Python >= 3.7
- CUDA >= 11.0
- [NVIDIA GPU Compute Capability](https://developer.nvidia.com/cuda-gpus) >= 7.0 (V100/RTX20 and higher)
3 changes: 2 additions & 1 deletion applications/Colossal-LLaMA/prepare_sft_dataset.py
@@ -10,7 +10,7 @@
import os
from multiprocessing import cpu_count

from colossal_llama.dataset.conversation import default_conversation
from colossal_llama.dataset.conversation import LLaMA2_Conv
from colossal_llama.dataset.spliced_and_tokenized_dataset import supervised_tokenize_sft
from datasets import dataset_dict, load_dataset
from transformers import AddedToken, AutoTokenizer
@@ -78,6 +78,7 @@ def main():
# Fix </s> split issue: https://github.com/huggingface/transformers/issues/23833
if args.llama_version == 2:
tokenizer.add_tokens(AddedToken("</s>", normalized=False, special=True), special_tokens=True)
default_conversation = LLaMA2_Conv

tokenizer.add_bos_token = False
tokenizer.add_eos_token = False
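The hunk above registers </s> as an unsplittable special token to work around huggingface/transformers#23833 and switches the conversation template to LLaMA2_Conv for LLaMA-2. The same tokenizer fix in isolation looks roughly like the sketch below; the checkpoint path is a placeholder.

from transformers import AddedToken, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("/path/to/llama-2-checkpoint")  # placeholder path
# Register </s> as a special, non-normalized token so the fast tokenizer does not
# split it into "<", "/", "s", ">" pieces during SFT data tokenization.
tokenizer.add_tokens(AddedToken("</s>", normalized=False, special=True), special_tokens=True)
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False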
40 changes: 31 additions & 9 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1,7 +1,9 @@
import ctypes
import random
import warnings
from collections import defaultdict
from contextlib import contextmanager
from copy import deepcopy
from functools import partial
from types import MethodType
from typing import Any, Callable, Dict, Iterator, List, Optional, OrderedDict, Tuple, Union
@@ -24,6 +26,8 @@
from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptim
from colossalai.nn.optimizer import DistGaloreAwamW
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
@@ -735,7 +739,7 @@ def _get_grads_to_sync(all_working_grads) -> Union[List[Tensor], None]:
# Get all working gradients and gradients to be synchronized.
all_working_grads = _get_all_working_grads()
grads_to_sync = _get_grads_to_sync(all_working_grads)
if self.require_grad_sync and grads_to_sync is not None:
if self._grad_store.require_grad_sync and grads_to_sync is not None:
# Synchronize sequence parallelism gradients if required.
SeqParallelUtils.allreduce_partial_data_grad(process_group=self.tp_pg, grads=grads_to_sync)
else:
@@ -759,7 +763,7 @@ def backward(self, loss, retain_graph=False):
# Call the superclass backward method to compute gradients.
super().backward(loss, retain_graph)

if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
if self._grad_store.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
# If gradient synchronization is required, sync sequence parallelism gradients.
self._sync_sp_grads()
else:
@@ -784,7 +788,7 @@ def backward_by_grad(self, tensor, grad):
# Call the superclass backward_by_grad method to compute gradients.
super().backward_by_grad(tensor, grad)

if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
if self._grad_store.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
# If gradient synchronization is required, sync sequence parallelism gradients.
self._sync_sp_grads()
else:
@@ -1171,6 +1175,15 @@ def configure(
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
param_info = get_param_info(optimizer)

# TODO: Support Galore + ZeRO
zero_stage = self.zero_stage
zero_config = deepcopy(self.zero_config)
if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0:
warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.")
zero_config["partition_grad"] = False
zero_stage = 0

if not isinstance(model, ModelWrapper):
use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
self.dp_size == 1
@@ -1194,7 +1207,8 @@ def configure(
custom_policy=self.custom_policy,
)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if self.zero_stage == 0:
if zero_stage == 0:
is_zero = False
if self.precision in ["fp16", "bf16"]:
optimizer = HybridParallelAMPOptimizer(
optimizer,
@@ -1218,11 +1232,11 @@
tp_process_group=self.tp_group,
)
else:
zero_dp_size = dist.get_world_size(dp_group)
if zero_dp_size == 1:
is_zero = self.dp_size > 1
if self.dp_size == 1:
warnings.warn(
"Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
"If you are not intended to use cpu_offload, please consider set zero_stage=0."
"If you do not intend to use cpu_offload, please consider set zero_stage=0."
)

assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
@@ -1236,11 +1250,19 @@
pp_process_group=self.pp_group,
verbose=True,
clip_grad_norm=self.max_norm,
**self.zero_config,
**zero_config,
**self.amp_config,
)
# inject update_master_params
model.update_master_params = MethodType(optimizer.update_master_params, model)

# Setup optimizers that require global states
optim = optimizer.optim
if isinstance(optim, DistributedOptim):
shard_to_param = optimizer.get_master_to_working_map() if is_zero else {}
padding_map = optimizer.get_param_padding_map() if is_zero else defaultdict(int)
optim.setup_distributed(self.tp_group, self.dp_group, shard_to_param, padding_map, is_zero)

return model, optimizer, criterion, dataloader, lr_scheduler

def execute_pipeline(
@@ -1272,7 +1294,7 @@ def execute_pipeline(

# run with gradients accumulation
if model.require_grad_sync == False or (
isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer._grad_store.require_grad_sync == False
):
return outputs

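For context, a minimal sketch (not part of this diff) of how the new DistributedOptim hook is reached through the booster; the DistGaloreAwamW constructor arguments and the launch call are assumptions and may differ from the actual API.

import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.nn.optimizer import DistGaloreAwamW

colossalai.launch_from_torch()  # launch signature may differ across versions
model = nn.Sequential(nn.Linear(512, 512), nn.GELU(), nn.Linear(512, 512))
optimizer = DistGaloreAwamW(model.parameters(), lr=1e-3)  # assumed constructor args

# Requesting ZeRO together with Galore now hits the warning added above and
# configure() falls back to zero_stage=0 (TP / vanilla DP only).
plugin = HybridParallelPlugin(tp_size=2, pp_size=1, zero_stage=1, precision="bf16")
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)
# Inside configure(), any DistributedOptim subclass then receives
# optimizer.optim.setup_distributed(tp_group, dp_group, shard_to_param, padding_map, is_zero).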
25 changes: 24 additions & 1 deletion colossalai/booster/plugin/low_level_zero_plugin.py
@@ -8,7 +8,10 @@
from typing import Callable, Dict, Iterator, List, Optional, Tuple

import torch
import torch.distributed
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.distributed_c10d import _get_default_group
from torch.nn import Parameter
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
@@ -28,6 +31,8 @@
sharded_optimizer_loading_epilogue,
)
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptim
from colossalai.nn.optimizer import DistGaloreAwamW
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.zero import LowLevelZeroOptimizer

@@ -428,13 +433,31 @@ def configure(
if not isinstance(model, ModelWrapper):
model = LowLevelZeroModel(model, self.precision)

# TODO: Support Galore + ZeRO
zero_stage = self.stage
zero_optim_kwargs = {**self.zero_optim_kwargs}
dp_size = dist.get_world_size()
if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and dp_size > 0:
warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.")
zero_optim_kwargs["partition_grad"] = False
zero_stage = 0

if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(
optimizer, **self.zero_optim_kwargs, verbose=self.verbose
optimizer, **zero_optim_kwargs, verbose=self.verbose
)
# inject update_master_params
model.update_master_params = MethodType(optimizer.update_master_params, model)

# Setup optimizers that require global states
optim = optimizer.optim
is_zero = dp_size > 1 and zero_stage > 0
dp_group = _get_default_group() # Use the whole world
if isinstance(optim, DistributedOptim):
shard_to_param = optimizer.get_master_to_working_map()
padding_map = optimizer.get_param_padding_map()
optim.setup_distributed(None, dp_group, shard_to_param, padding_map, is_zero)

return model, optimizer, criterion, dataloader, lr_scheduler

def control_checkpoint_io(self) -> bool:
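The LowLevelZeroPlugin path mirrors the hybrid plugin above; a brief sketch of the differences, with the same caveat that the optimizer constructor arguments are assumptions.

import torch.nn as nn

from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.nn.optimizer import DistGaloreAwamW

model = nn.Linear(1024, 1024)
optimizer = DistGaloreAwamW(model.parameters(), lr=1e-3)  # assumed constructor args
plugin = LowLevelZeroPlugin(stage=1, precision="bf16")    # stage is forced to 0 for Galore
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)
# Here there is no TP group, so setup_distributed() receives tp_group=None and
# dp_group=_get_default_group(), i.e. the whole world acts as the ZeRO/DP group.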
7 changes: 6 additions & 1 deletion colossalai/cluster/process_group_mesh.py
@@ -38,7 +38,12 @@ class ProcessGroupMesh:

def __init__(self, *size: int) -> None:
assert dist.is_initialized(), "Please initialize torch.distributed first."
assert prod(size) == dist.get_world_size(), "The product of the size must be equal to the world size."
world_size = dist.get_world_size()
prod_size = prod(size)
assert (
prod_size == world_size
), f"The product of the size({prod_size}) must be equal to the world size({world_size})."

self._shape = size
self._rank = dist.get_rank()
self._coord = ProcessGroupMesh.unravel(self._rank, self._shape)
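A quick illustration of the more informative assertion, assuming torch.distributed runs on 8 ranks (e.g. launched via torchrun):

import torch.distributed as dist

from colossalai.cluster import ProcessGroupMesh

dist.init_process_group(backend="nccl")  # normally handled by colossalai.launch
mesh = ProcessGroupMesh(2, 4)            # fine: 2 * 4 == world size of 8
# ProcessGroupMesh(2, 3) would now fail with
# "The product of the size(6) must be equal to the world size(8)."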
6 changes: 3 additions & 3 deletions colossalai/device/device_mesh.py
@@ -306,9 +306,8 @@ def _init_global_to_logical_rank_mapping(
# index means the local rank in the current axis
# inner_tensor refers to the processes with the same local rank

if inner_tensor.numel() == 1:
# if the inner_tensor only has one element, it means that
# it already reaches the last axis
if inner_tensor.dim() == 0:
# if the inner_tensor already reaches the last axis,
# we append its local_rank in the last axis to the index_list
# and assign to the mapping
# the value of the mapping is the the local rank at the indexed axis of the device mesh
@@ -459,6 +458,7 @@ def _collate_global_ranks_in_same_process_group(self, global_rank):

# replace the local rank in the given dimension with the
# local rank of the current process iterated

process_coordinates[dim] = _local_rank
processes_in_the_same_process_group[dim].append(process_coordinates)

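The switch from numel() == 1 to dim() == 0 matters when a mesh axis has size 1: a partially indexed slice can already hold a single element before the last axis is reached. A small sketch of the distinction:

import torch

mesh = torch.tensor([[3]])  # logical mesh of shape (1, 1) holding global rank 3
inner = mesh[0]             # shape (1,): one element left, but one axis still remains
print(inner.numel() == 1, inner.dim() == 0)  # True False -> the old check fires too early
leaf = inner[0]             # shape (): a true scalar, i.e. the last axis was consumed
print(leaf.dim() == 0)      # True -> the new check only fires at the bottom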
25 changes: 24 additions & 1 deletion colossalai/interface/optimizer.py
@@ -1,6 +1,7 @@
from typing import Union
from typing import Dict, Optional, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
@@ -133,3 +134,25 @@ def unwrap(self):
Unwrap the optimizer for checkpoint saving/loading.
"""
return self.optim


class DistributedOptim(Optimizer):
def setup_distributed(
self,
tp_group: Optional[dist.ProcessGroup] = None,
dp_group: Optional[dist.ProcessGroup] = None,
shard_to_working_param: Optional[Dict] = {},
padding_map: Optional[Dict] = None,
is_zero: Optional[bool] = False,
):
"""Assign process groups for TP and ZeRO 2.
Arguments:
tp_group (dist.ProcessGroup): Tensor Parallel process group
dp_group (dist.ProcessGroup): ZeRO stage 2 process group
shard_to_working_param (Dict): ZeRO stage 2 feeds the optimizer a sharded param view to match grad shape.
This maps from id(view) to model params used in forward & backward.
padding_map (Dict): Per-param padding from ZeRO stage 2
is_zero (bool): Whether to use ZeRO stage 2.
"""

raise NotImplementedError("setup_distributed for TP/DP isn't supported by this optimizer yet!")
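
A minimal sketch of what a concrete subclass is expected to do with these arguments; the class below is hypothetical and only stashes the state that step() would later need (step() itself is omitted).

from typing import Dict, Optional

import torch.distributed as dist

from colossalai.interface.optimizer import DistributedOptim


class MyDistributedAdamW(DistributedOptim):  # hypothetical example, not part of this PR
    def __init__(self, params, lr: float = 1e-3):
        super().__init__(params, defaults=dict(lr=lr))

    def setup_distributed(
        self,
        tp_group: Optional[dist.ProcessGroup] = None,
        dp_group: Optional[dist.ProcessGroup] = None,
        shard_to_working_param: Optional[Dict] = None,
        padding_map: Optional[Dict] = None,
        is_zero: Optional[bool] = False,
    ):
        # Keep the process groups and ZeRO-2 bookkeeping so that step() can
        # later all-reduce/gather optimizer states across TP and ZeRO shards.
        self.tp_group = tp_group
        self.dp_group = dp_group
        self.shard_to_working_param = shard_to_working_param or {}
        self.padding_map = padding_map or {}
        self.is_zero = is_zero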
23 changes: 23 additions & 0 deletions colossalai/lazy/pretrained.py
@@ -1,3 +1,4 @@
import copy
import os
from typing import Callable, Optional, Union

@@ -74,6 +75,24 @@ def new_from_pretrained(
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)
variant = kwargs.pop("variant", None)

kwargs.pop("state_dict", None)
kwargs.pop("from_tf", False)
kwargs.pop("from_flax", False)
kwargs.pop("output_loading_info", False)
kwargs.pop("trust_remote_code", None)
kwargs.pop("low_cpu_mem_usage", None)
kwargs.pop("device_map", None)
kwargs.pop("max_memory", None)
kwargs.pop("offload_folder", None)
kwargs.pop("offload_state_dict", False)
kwargs.pop("load_in_8bit", False)
kwargs.pop("load_in_4bit", False)
kwargs.pop("quantization_config", None)
kwargs.pop("adapter_kwargs", {})
kwargs.pop("adapter_name", "default")
kwargs.pop("use_flash_attention_2", False)

use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)

if len(kwargs) > 0:
@@ -108,6 +127,10 @@ def new_from_pretrained(
**kwargs,
)
else:
config = copy.deepcopy(config)
kwarg_attn_imp = kwargs.pop("attn_implementation", None)
if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp:
config._attn_implementation = kwarg_attn_imp
model_kwargs = kwargs

if commit_hash is None:
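These pops mirror the loader-only kwargs that transformers' own from_pretrained consumes, so they no longer leak into model_kwargs during lazy initialization. A hedged usage sketch follows; the checkpoint path is a placeholder and exact LazyInitContext arguments may differ by version.

from colossalai.lazy import LazyInitContext
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("/path/to/checkpoint")  # placeholder path
with LazyInitContext():
    # The patched new_from_pretrained() strips loader-only kwargs (device_map,
    # load_in_8bit, ...) and, when a config is passed explicitly, applies
    # attn_implementation to a deepcopy instead of mutating the caller's config.
    model = AutoModelForCausalLM.from_pretrained(
        "/path/to/checkpoint",
        config=config,
        attn_implementation="eager",
    )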
6 changes: 3 additions & 3 deletions colossalai/legacy/inference/async_manager.py
@@ -55,14 +55,14 @@ def _step(self):
self.stats_tool.count_prompt_tokens(new_batch)
self.running_batch = new_batch
has_new_finished, outputs = self._prefill_batch(self.running_batch)
self._filter_runing_batch()
self._filter_running_batch()
self.has_wait_tokens = 0

else:
if self.has_wait_tokens < self.max_wait_tokens:
self.stats_tool.count_output_tokens(self.running_batch)
has_new_finished, outputs = self._decode_batch(self.running_batch)
self._filter_runing_batch()
self._filter_running_batch()
self.has_wait_tokens += 1

else:
@@ -78,7 +78,7 @@ def _step(self):
else:
self.stats_tool.count_output_tokens(self.running_batch)
has_new_finished, outputs = self._decode_batch(self.running_batch)
self._filter_runing_batch()
self._filter_running_batch()
self.has_wait_tokens += 1

if has_new_finished:
8 changes: 4 additions & 4 deletions colossalai/legacy/inference/manager.py
@@ -131,14 +131,14 @@ def _step(self):
self.stats_tool.count_prompt_tokens(new_batch)
self.running_batch = new_batch
yield from self._prefill_batch(self.running_batch)
self._filter_runing_batch()
self._filter_running_batch()
self.has_wait_tokens = 0
return

if self.has_wait_tokens < self.max_wait_tokens:
self.stats_tool.count_output_tokens(self.running_batch)
yield from self._decode_batch(self.running_batch)
self._filter_runing_batch()
self._filter_running_batch()
self.has_wait_tokens += 1
return
else:
@@ -154,7 +154,7 @@ def _step(self):
else:
self.stats_tool.count_output_tokens(self.running_batch)
yield from self._decode_batch(self.running_batch)
self._filter_runing_batch()
self._filter_running_batch()
self.has_wait_tokens += 1

return
@@ -243,7 +243,7 @@ def _handle_finish_req(self, batch: Batch, has_new_finished_req):
self._filter_batch(batch)
yield from self._output_process(finished_reqs)

def _filter_runing_batch(self):
def _filter_running_batch(self):
if self.running_batch is not None and self.running_batch.is_clear():
self.running_batch = None
