[shardformer] update bert finetune example with HybridParallelPlugin (#4584)

* [shardformer] fix opt test hanging

* fix

* test

* test

* test

* fix test

* fix test

* remove print

* add fix

* [shardformer] add bert finetune example

* [shardformer] add bert finetune example

* [shardformer] add bert finetune example

* [shardformer] add bert finetune example

* [shardformer] add bert finetune example

* [shardformer] add bert finetune example

* [shardformer] fix epoch change

* [shardformer] broadcast add pp group

* [shardformer] fix opt test hanging

* fix

* test

* test

* [shardformer] zero1+pp and the corresponding tests (#4517)

* pause

* finish pp+zero1

* Update test_shard_vit.py

* [shardformer/fix overlap bug] fix overlap bug, add overlap as an option in shardco… (#4516)

* fix overlap bug and support bert, add overlap as an option in shardconfig

* support overlap for chatglm and bloom

* [shardformer] fix emerged bugs after updating transformers (#4526)

* test

* fix test

* fix test

* remove print

* add fix

* [shardformer] add bert finetune example

* [shardformer] add bert finetune example

* [shardformer] Add overlap support for gpt2 (#4535)

* add overlap support for gpt2

* remove unused code

* remove unused code

* [shardformer] support pp+tp+zero1 tests (#4531)

* [shardformer] fix opt test hanging

* fix

* test

* test

* test

* fix test

* fix test

* remove print

* add fix

* [shardformer] pp+tp+zero1

[shardformer] pp+tp+zero1

[shardformer] pp+tp+zero1

[shardformer] pp+tp+zero1

[shardformer] pp+tp+zero1

[shardformer] pp+tp+zero1

* [shardformer] pp+tp+zero1

* [shardformer] pp+tp+zero1

* [shardformer] pp+tp+zero1

* [shardformer] pp+tp+zero1

* [shardformer] fix submodule replacement bug when enabling pp (#4544)

* [shardformer] support sharded optimizer checkpointIO of HybridParallelPlugin (#4540)

* implement sharded optimizer saving

* add more param info

* finish implementation of sharded optimizer saving

* fix bugs in optimizer sharded saving

* add pp+zero test

* param group loading

* greedy loading of optimizer

* fix bug when loading

* implement optimizer sharded saving

* add optimizer test & arrange checkpointIO utils

* fix gemini sharding state_dict

* add verbose option

* add loading of master params

* fix typehint

* fix master/working mapping in fp16 amp

* [shardformer] add bert finetune example

* [shardformer] add bert finetune example

* [shardformer] add bert finetune example

* [shardformer] add bert finetune example

* [shardformer] fix epoch change

* [shardformer] broadcast add pp group

* rebase feature/shardformer

* update pipeline

* [shardformer] fix

* [shardformer] fix

* [shardformer] bert finetune fix

* [shardformer] add all_reduce operation to loss

add all_reduce operation to loss

* [shardformer] make compatible with pytree.

make compatible with pytree.

* [shardformer] disable tp

disable tp

* [shardformer] add 3d plugin to ci test

* [shardformer] update num_microbatches to None

* [shardformer] update microbatchsize

* [shardformer] update assert

* update scheduler

* update scheduler

---------

Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com>
Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: Baizhou Zhang <eddiezhang@pku.edu.cn>
4 people committed Sep 4, 2023
1 parent 24c0768 commit 0a94fcd
Showing 4 changed files with 134 additions and 36 deletions.
2 changes: 1 addition & 1 deletion colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -325,7 +325,7 @@ def __init__(self,
self.schedule = None
assert zero_stage in (0, 1, 2)
if self.pp_size > 1:
assert num_microbatches is not None, 'num_microbatches must be specified when using pipeline parallelism'
assert num_microbatches is not None or microbatch_size is not None, 'num_microbatches or microbatch_size must be specified when using pipeline parallelism'
assert self.zero_stage <= 1, 'zero stage must be 0 or 1 when using pipeline parallelism'
self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
self.schedule = OneForwardOneBackwardSchedule(self.stage_manager,
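For reference, a minimal sketch of how the relaxed assertion above can be exercised: with pipeline parallelism enabled, either num_microbatches or microbatch_size may now be supplied. The values below are illustrative, and any constructor arguments not shown are left at their defaults.

    from colossalai.booster.plugin import HybridParallelPlugin

    # Option A: fix the number of microbatches per step (illustrative values).
    plugin_a = HybridParallelPlugin(tp_size=1, pp_size=2, num_microbatches=4)

    # Option B: fix the microbatch size and let the schedule derive the count.
    plugin_b = HybridParallelPlugin(tp_size=1, pp_size=2, num_microbatches=None, microbatch_size=1)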
3 changes: 2 additions & 1 deletion colossalai/pipeline/schedule/one_f_one_b.py
@@ -46,6 +46,7 @@ def __init__(self,
self.batch: Optional[Any] = None
self.batch_size: Optional[int] = None
self.microbatch_offset: Optional[int] = None
self._use_microbatch_size = num_microbatches is None

def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
"""Load a batch from data iterator.
@@ -60,7 +61,7 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None)
self.batch = batch
self.batch_size = get_batch_size(batch)
self.microbatch_offset = 0
if self.num_microbatches is not None:
if not self._use_microbatch_size:
assert self.batch_size % self.num_microbatches == 0, \
"Batch size should divided by the number of microbatches"
self.microbatch_size = self.batch_size // self.num_microbatches
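The hunk above switches load_batch from checking num_microbatches directly to the _use_microbatch_size flag. A paraphrased sketch of the sizing logic this implies (not the exact library code; it assumes exactly one of the two quantities is provided):

    def resolve_microbatching(batch_size: int, num_microbatches=None, microbatch_size=None):
        # Return (microbatch_size, num_microbatches) for a global batch.
        if num_microbatches is not None:
            assert batch_size % num_microbatches == 0, "batch size must be divisible by num_microbatches"
            return batch_size // num_microbatches, num_microbatches
        assert microbatch_size is not None and batch_size % microbatch_size == 0
        return microbatch_size, batch_size // microbatch_size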
163 changes: 130 additions & 33 deletions examples/language/bert/finetune.py
@@ -1,12 +1,14 @@
import argparse
from typing import List, Union
from contextlib import nullcontext
from typing import Callable, List, Union

import evaluate
import torch
import torch.distributed as dist
import torch.nn as nn
from data import GLUEDataBuilder
from torch.optim import Optimizer
from torch.optim import Adam, Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
@@ -18,8 +20,9 @@

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

@@ -32,38 +35,93 @@
WEIGHT_DECAY = 0.01
WARMUP_FRACTION = 0.1

output_transform_fn = lambda x: x
criterion = lambda x: x.loss


def move_to_cuda(batch):
return {k: v.cuda() for k, v in batch.items()}


@torch.no_grad()
def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[DataLoader]], num_labels: int, task_name: str,
eval_splits: List[str], coordinator: DistCoordinator):
def evaluate_model(
model: nn.Module,
optimizer,
criterion,
test_dataloader: Union[DataLoader, List[DataLoader]],
num_labels: int,
task_name: str,
eval_splits: List[str],
booster: Booster,
coordinator: DistCoordinator,
):
metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size)
model.eval()

def evaluate_subset(dataloader: DataLoader):
accum_loss = torch.zeros(1, device=get_current_device())
for batch in dataloader:
batch = move_to_cuda(batch)
outputs = model(**batch)
val_loss, logits = outputs[:2]
accum_loss.add_(val_loss)

if num_labels > 1:
preds = torch.argmax(logits, axis=1)
elif num_labels == 1:
preds = logits.squeeze()

labels = batch["labels"]

metric.add_batch(predictions=preds, references=labels)
batch_size = batch["input_ids"].shape[0]
if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None:
pg_mesh = booster.plugin.pg_mesh
pp_group = booster.plugin.pp_group
current_pp_group_ranks = pg_mesh.get_ranks_in_group(pp_group)
current_rank = dist.get_rank()
#TODO pass dataloader to execute_pipeline directly
batch = iter([batch])
outputs = booster.execute_pipeline(batch,
model,
criterion,
optimizer,
return_loss=True,
return_outputs=True)

if booster.plugin.stage_manager.is_last_stage():
val_loss = outputs["loss"]

logits = outputs["outputs"]["logits"]

accum_loss.add_(val_loss)

if num_labels > 1:
preds = torch.argmax(logits, axis=1)
elif num_labels == 1:
preds = logits.squeeze()

dist.broadcast(preds, src=current_rank, group=pp_group)
dist.broadcast(val_loss, src=current_rank, group=pp_group)

metric.add_batch(predictions=preds, references=labels)
elif current_rank in current_pp_group_ranks:
val_loss = torch.empty((1,), device=get_current_device())
preds = torch.empty((batch_size,), dtype=torch.int64, device=get_current_device())

dist.broadcast(preds, src=current_pp_group_ranks[-1], group=pp_group)
dist.broadcast(val_loss, src=current_pp_group_ranks[-1], group=pp_group)

accum_loss.add_(val_loss)
metric.add_batch(predictions=preds, references=labels)

else:
batch = move_to_cuda(batch)
outputs = model(**batch)
val_loss, logits = outputs[:2]
accum_loss.add_(val_loss)

if num_labels > 1:
preds = torch.argmax(logits, axis=1)
elif num_labels == 1:
preds = logits.squeeze()

metric.add_batch(predictions=preds, references=labels)

results = metric.compute()
dist.all_reduce(accum_loss.div_(len(dataloader)))
if coordinator.is_master():
if coordinator.is_master() and results is not None:
results['loss'] = accum_loss.item() / coordinator.world_size

return results

if isinstance(test_dataloader, DataLoader):
@@ -77,25 +135,43 @@ def evaluate_subset(dataloader: DataLoader):
return final_results


def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, lr_scheduler, train_dataloader: DataLoader,
booster: Booster, coordinator: DistCoordinator):
def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: Callable, lr_scheduler: LRScheduler,
train_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator):

model.train()
with tqdm(train_dataloader, desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not coordinator.is_master()) as pbar:
is_pp_last_stage = hasattr(
booster.plugin,
"stage_manager") and booster.plugin.stage_manager is not None and booster.plugin.stage_manager.is_last_stage()
with tqdm(train_dataloader,
desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]',
disable=not (coordinator.is_master() or is_pp_last_stage)) as pbar:
for batch in pbar:
# Forward pass
batch = move_to_cuda(batch)
outputs = model(**batch)
loss = outputs[0]
if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None:
#TODO pass train_dataloader to execute_pipeline directly
batch = iter([batch])
outputs = booster.execute_pipeline(batch,
model,
_criterion,
optimizer,
return_loss=True,
return_outputs=True)
# Backward and optimize
if booster.plugin.stage_manager.is_last_stage():
loss = outputs['loss']
pbar.set_postfix({'loss': loss.item()})
else:
outputs = model(**batch)
loss = _criterion(outputs, None)
# Backward
booster.backward(loss, optimizer)
pbar.set_postfix({'loss': loss.item()})

# Backward and optimize
booster.backward(loss, optimizer)
optimizer.step()
optimizer.zero_grad()
lr_scheduler.step()

# Print log info
pbar.set_postfix({'loss': loss.item()})


def main():
# ==============================
@@ -107,7 +183,7 @@ def main():
'--plugin',
type=str,
default='torch_ddp',
choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero', 'hybrid_parallel'],
help="plugin to use")
parser.add_argument(
"--model_type",
@@ -116,6 +192,7 @@
help="bert or albert",
)
parser.add_argument('--target_f1', type=float, default=None, help="target f1 score. Raise exception if not reached")
parser.add_argument('--use_lazy_init', type=bool, default=False, help="for initiating lazy init context")
args = parser.parse_args()

if args.model_type == 'bert':
@@ -145,6 +222,17 @@ def main():
plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5)
elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2**5)
elif args.plugin == 'hybrid_parallel':

# modify the param accordingly for finetuning test cases
plugin = HybridParallelPlugin(tp_size=1,
pp_size=2,
num_microbatches=None,
microbatch_size=1,
enable_all_optimization=True,
zero_stage=1,
precision='fp16',
initial_scale=1)

booster = Booster(plugin=plugin, **booster_kwargs)

@@ -165,8 +253,9 @@
# bert pretrained model

cfg = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels)

if model_name == "bert-base-uncased":
model = BertForSequenceClassification.from_pretrained(model_name, config=cfg)
model = BertForSequenceClassification.from_pretrained(model_name, config=cfg).cuda()
elif model_name == "albert-xxlarge-v2":
model = AlbertForSequenceClassification.from_pretrained(model_name, config=cfg)
else:
@@ -196,19 +285,27 @@ def main():
num_training_steps=total_steps,
)

def _criterion(outputs, inputs):
outputs = output_transform_fn(outputs)
loss = criterion(outputs)
return loss

# ==============================
# Boost with ColossalAI
# ==============================
model, optimizer, _, _, lr_scheduler = booster.boost(model, optimizer, lr_scheduler=lr_scheduler)
model, optimizer, _criterion, _, lr_scheduler = booster.boost(model,
optimizer,
criterion=_criterion,
lr_scheduler=lr_scheduler)

# ==============================
# Train model
# ==============================
for epoch in range(NUM_EPOCHS):
train_epoch(epoch, model, optimizer, lr_scheduler, train_dataloader, booster, coordinator)
train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator)

results = evaluate_model(model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits,
coordinator)
results = evaluate_model(model, optimizer, _criterion, test_dataloader, data_builder.num_labels, args.task,
data_builder.eval_splits, booster, coordinator)

if coordinator.is_master():
print(results)
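The updated example routes each training batch either through booster.execute_pipeline (when a pipeline stage manager is present) or through a plain forward/backward pass. A condensed sketch of that pattern, assuming the same Booster APIs used in the diff above and a batch that is already on the correct device:

    def train_step(batch, model, optimizer, booster, _criterion):
        # _criterion extracts .loss from the HuggingFace model output, as in the example above.
        if getattr(booster.plugin, "stage_manager", None) is not None:
            # Pipeline parallel: forward and backward both happen inside execute_pipeline.
            outputs = booster.execute_pipeline(iter([batch]), model, _criterion, optimizer,
                                               return_loss=True, return_outputs=True)
            loss = outputs["loss"] if booster.plugin.stage_manager.is_last_stage() else None
        else:
            loss = _criterion(model(**batch), None)
            booster.backward(loss, optimizer)
        optimizer.step()
        optimizer.zero_grad()
        return loss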
2 changes: 1 addition & 1 deletion examples/language/bert/test_ci.sh
@@ -3,6 +3,6 @@ set -xe

pip install -r requirements.txt

for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero"; do
for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero" "hybrid_parallel"; do
torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin $plugin --model_type "bert"
done
