[pipeline/rank_recorder] fix bug when processing data before backward | add a tool for multi-rank debugging #1681

Merged · 13 commits · Oct 9, 2022
24 changes: 13 additions & 11 deletions colossalai/pipeline/rpc/_pipeline_base.py
@@ -5,19 +5,21 @@
from abc import ABC, abstractmethod
import sys
import os
import time
import inspect

import torch
from torch import nn
import torch.distributed.rpc as rpc
from torch.futures import Future
from torch._C._distributed_rpc import PyRRef

from torch import autograd
from torch import optim

from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.pipeline.rpc.utils import (color_debug, tensor_shape_list, get_batch_lengths, split_batch, type_detail,
pytree_map, get_real_args_kwargs, use_color_debug)
pytree_map, pytree_filter, get_real_args_kwargs, use_color_debug)


class Phase(Enum):
@@ -469,6 +471,7 @@ def _consume_work_item_by_phase(self, work_item: WorkItem):

else:
consume_result = self.module_partition(*args, **kwargs)

# print(f'model{self.pp_rank + 1}(param_sum: {sum([p.sum().item() for p in self.module_partition.parameters()])}) input sum: {args[0].sum().item()} forward output sum: {consume_result.sum().item()}', )

if is_last_stage and self.criterion:
@@ -495,7 +498,6 @@ def _consume_work_item_by_phase(self, work_item: WorkItem):
stage_input_kwargs,
stage_outputs,
checkpoint=use_checkpoint)

# if not forward_only, do the backward
if not forward_only:
if is_last_stage: # if it is the last stage, trigger backward automatic
@@ -521,19 +523,19 @@ def _consume_work_item_by_phase(self, work_item: WorkItem):
if use_checkpoint:
stage_outputs = [self.module_partition(*stage_input_args, **stage_input_kwargs)]

# take tensor only (for only tensor can do backward)
stage_outputs_tensors = []
pytree_map(stage_outputs, stage_outputs_tensors.append, process_types=torch.Tensor)

# overlap recompute and future.wait
grad_tensors = get_real_args_kwargs(args)
if not is_last_stage:
grad_tensors = get_real_args_kwargs(args)
else:
grad_tensors = None
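                    # note: at the last stage the output is the loss produced by criterion, so backward
                    # starts from it with grad_tensors=None; earlier stages receive their gradients from
                    # the next stage through the futures in args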

                    # keep only the tensors that require grad (only tensors can take part in backward)
stage_outputs = pytree_filter(lambda x: x.requires_grad, stage_outputs, process_types=torch.Tensor)
grad_tensors = pytree_filter(lambda x: x is not None, grad_tensors, process_types=torch.Tensor)

# print('rank', self.pp_rank, tensor_shape_list(stage_outputs_tensors), tensor_shape_list(grad_tensors))
autograd.backward(stage_outputs_tensors, grad_tensors=grad_tensors)
autograd.backward(stage_outputs, grad_tensors=grad_tensors)

# collect grad of input tensor
                # we assume that a node passed through kwargs cannot be a non-leaf node of the graph,
                # so we do not need to save the grads of nodes in kwargs.
consume_result = []
if not is_first_stage:
pytree_map(stage_input_args, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
2 changes: 1 addition & 1 deletion colossalai/pipeline/rpc/_pipeline_schedule.py
@@ -110,7 +110,7 @@ def __init__(self,
if chunk > 1:
assert num_microbatches % stage_num == 0, \
"if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
assert num_microbatches > stage_num * chunk, "num_microbatches must be greater than stage_num * chunk"
# assert num_microbatches > stage_num * chunk, "num_microbatches must be greater than stage_num * chunk"
use_1F1B = True

super().__init__(OneFOneBWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
17 changes: 16 additions & 1 deletion colossalai/pipeline/rpc/utils.py
@@ -20,7 +20,8 @@ def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] =
Args:
obj (:class:`Any`): object to process
fn (:class:`Callable`): a function to process subobject in obj
process_types(:class: `type | tuple[type]`): types to determine the type to process
process_types (:class: `type | tuple[type]`): types to determine the type to process
        map_all (:class:`bool`): if map_all is True, fn is applied to elements of every type

Returns:
        :class:`Any`: the return value has the same structure as `obj`, with elements of the given process_types mapped through `fn`
@@ -59,6 +60,20 @@ def type_detail(obj):
return pytree_map(obj, lambda x: type(x), map_all=True)


def pytree_filter(fn, obj, process_types):
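    """Traverse ``obj`` with pytree_map and return a flat list of the sub-objects
    of the given process_types for which fn evaluates to True.

    Example (illustrative):
        pytree_filter(lambda x: x > 0, [1.0, -2.0, 3.0], process_types=float)
        # -> [1.0, 3.0]
    """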
if obj is None:
return None

filters = []

def condition_append(obj):
if fn(obj):
filters.append(obj)

pytree_map(obj, fn=condition_append, process_types=process_types)
return filters


def get_real_args_kwargs(args_or_kwargs):
args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
# TODO : combine producer and consumer
72 changes: 72 additions & 0 deletions colossalai/utils/rank_recorder/README.md
@@ -0,0 +1,72 @@
# Rank Recorder
This is a useful tool for recording the execution of selected procedures on each rank. The records of every rank are dumped into a JSON file after the multi-process program ends, so you can parse and visualise them easily.

Before using the tool, make sure `dist.is_initialized()` returns `True` before the program exits.

## Usage

It is very simple:

```python
from colossalai.utils.rank_recorder import recorder

...
...

with recorder(record_name, current_rank) as r:
"""procedure to record
"""

```

## Example
This is a demo that records CUDA kernel selection for different input sizes and visualises the cost of several procedures on each rank.

```python
import time
import os
import logging
logging.disable(logging.INFO)

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from colossalai.utils.rank_recorder import recorder


WORLD_SIZE = 4

# configure the exported image here
# if you want to dive into the details, the 'svg' format is recommended
recorder.export_format = 'png'
recorder.export_name = 'kernel_select'
recorder.dpi = 500

def calc(x, y):
a = torch.randn(x, y).cuda()
b = torch.randn(x, y).cuda()
c = sum(a * b)
return c

def worker(rank):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29020'
dist.init_process_group(backend='nccl', world_size=WORLD_SIZE, rank=rank)
print(dist.get_rank(), "enter")
time.sleep(0.1 * rank)

with recorder("calc_1(x100)", rank) as r:
calc(100, 100)

with recorder("calc_2(x400)", rank) as r:
calc(400, 400)

with recorder("calc_2(x200)", rank) as r:
calc(200, 200)

if __name__ == "__main__":
mp.spawn(worker, nprocs=WORLD_SIZE)
```

Run the script directly and you will get `kernel_select.json` and `kernel_select.png` in your current folder.
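
If you want to post-process the records yourself, the exported JSON maps each rank to its list of events, where each event has `start`, `end`, and `name` fields (times in seconds, relative to the recorder's base time). Below is a minimal parsing sketch (not part of the library), assuming the `kernel_select` export name from the example above:

```python
import json

# load the merged record file written at program exit
# (the name follows recorder.export_name in the example above)
with open('kernel_select.json', 'r', encoding='utf-8') as f:
    records = json.load(f)

# keys are ranks (as strings); values are lists of {'start', 'end', 'name'}
# with times in seconds relative to the recorder's base time
for rank in sorted(records, key=int):
    for event in records[rank]:
        duration = event['end'] - event['start']
        print(f"rank {rank}: {event['name']} took {duration:.4f} s")
```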
3 changes: 3 additions & 0 deletions colossalai/utils/rank_recorder/__init__.py
@@ -0,0 +1,3 @@
from colossalai.utils.rank_recorder.rank_recorder import recorder

__all__ = ["recorder"]
178 changes: 178 additions & 0 deletions colossalai/utils/rank_recorder/rank_recorder.py
@@ -0,0 +1,178 @@
import atexit
import json
import os
import shutil
import time
from typing import Dict, List

import torch
import torch.distributed as dist

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

cmap = list(mcolors.TABLEAU_COLORS.values())

LOG_FOLDER = "record.log"
MAX_WAIT_TIME = 20


class Event:

def __init__(self, start: int, end: int, name: str, rank: int) -> None:
self.start = start
self.end = end
self.name = name
self.rank = rank


class Recorder:

def __init__(self) -> None:
self.rank_to_history: Dict[int, List[Event]] = {}
self.base_time = time.time()
self.temp_event = None

self.export_format = 'png'
self.export_name = 'test'
self.dpi = 500
self.theme = 'dark_background'
self.figure_width = 30
self.figure_height = 10
self.legend_fontsize = 16
self.device_fontsize = 20
self.bar_height = 0.2

if not os.path.exists(LOG_FOLDER):
os.makedirs(LOG_FOLDER)

def start(self, name: str, rank: int):
# TODO : add lock to prevent conflict
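        # synchronize so the timestamp is taken only after all previously queued CUDA work has finished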
torch.cuda.synchronize()
start_time = time.time()
self.temp_event = Event(start_time, None, name, rank)

def end(self):
assert self.temp_event is not None, "`start` before `end`"
torch.cuda.synchronize()
end_time = time.time()
self.temp_event.end = end_time
rank = self.temp_event.rank
if rank not in self.rank_to_history:
self.rank_to_history[rank] = []
self.rank_to_history[rank].append(self.temp_event)
self.temp_event = None

def get_history(self):
        return self.rank_to_history

    def __call__(self, name: str, rank: int):
self.temp_name = name
self.temp_rank = rank
return self

def __enter__(self):
name = self.temp_name
rank = self.temp_rank
self.start(name, rank)

def __exit__(self, *args):
self.end()

def dump_record(self):
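        """Dump this rank's events, together with the base time, to LOG_FOLDER/<rank>.json."""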
rank = dist.get_rank()
rank_to_history = self.rank_to_history
records = {'base_time': self.base_time, 'content': {}}
for record_rank in rank_to_history:
history = rank_to_history[record_rank]
recs = []
for event in history:
rec = {'start': event.start, 'end': event.end, 'name': event.name}
recs.append(rec)
records['content'][record_rank] = recs

dump_name = f'{rank}.json'
dump_path = os.path.join(LOG_FOLDER, dump_name)
with open(dump_path, 'w', encoding='utf-8') as f:
json.dump(records, f, ensure_ascii=False)

def merge_recode(self):
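        """Wait (with a bounded timeout) for every rank's dump, merge them into
        '<export_name>.json' with timestamps relative to this rank's base time,
        and remove LOG_FOLDER afterwards."""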
base_time = self.base_time
world_size = dist.get_world_size()

wait_time = 0
while True:
time.sleep(0.1)
log_num = len(os.listdir(LOG_FOLDER))
if log_num == world_size:
break

wait_time += 1
if wait_time >= MAX_WAIT_TIME:
break

# merge
logs_path = [os.path.join(LOG_FOLDER, file) for file in os.listdir(LOG_FOLDER)]
recoders = {}
for path in logs_path:
with open(path, 'r', encoding='utf-8') as f:
recs = json.load(f)
for record_rank in recs['content']:
history = recs['content'][record_rank]
recoders[record_rank] = []
for rec in history:
recoders[record_rank].append({
'start': rec['start'] - base_time,
'end': rec['end'] - base_time,
'name': rec['name']
})

shutil.rmtree(LOG_FOLDER)
with open(self.export_name + '.json', 'w', encoding='utf-8') as f:
json.dump(recoders, f, ensure_ascii=False)

def visualise_record(self):
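        """Plot the merged records as one horizontal bar per event, grouped by rank, and save the figure."""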
with open(self.export_name + '.json', 'r', encoding='utf-8') as f:
records = json.load(f)
records = dict(records)
ranks = list(sorted(records.keys()))

name_list = {}
plots = {}
plt.figure(dpi=self.dpi, figsize=[self.figure_width, self.figure_height])
plt.style.use(self.theme)

for rank in ranks:
rank_records = records[rank]
for rec in rank_records:
s = rec['start']
e = rec['end']
name = rec['name']
if name not in name_list:
name_list[name] = len(name_list)
bar = plt.barh(rank, width=e - s, height=self.bar_height, left=s, color=cmap[name_list[name]])
if name not in plots:
plots[name] = bar

plt.legend(list(plots.values()), list(plots.keys()), loc="upper left", fontsize=self.legend_fontsize)
plt.yticks(ticks=ranks, labels=[f'Device:{rank}' for rank in ranks], fontsize=self.device_fontsize)
plt.grid(axis='x')
plt.savefig("{}.{}".format(self.export_name, self.export_format))

def exit_worker(self):
if len(self.rank_to_history) == 0:
return
self.dump_record()
        # a single rank waits for the other ranks' dumps and merges them
rank = dist.get_rank()

if rank == 1:
            # merge the per-rank dumps, using this rank's base time as the reference, then visualise
self.merge_recode()
self.visualise_record()


recorder = Recorder()
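# `recorder` is a module-level singleton (re-exported by colossalai.utils.rank_recorder);
# records are dumped, merged, and visualised automatically at process exit via atexit below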
atexit.register(recorder.exit_worker)