Commit: Merge branch 'master' into loadams/switch-hf-home
loadams committed Apr 22, 2024
2 parents 630d329 + c292b03 commit 618a6a3
Showing 1 changed file with 10 additions and 24 deletions.
deepspeed/checkpoint/ds_to_universal.py (34 changes: 10 additions & 24 deletions)
@@ -10,7 +10,7 @@
 import argparse
 import glob
 import itertools
-import multiprocessing
+from concurrent.futures import ProcessPoolExecutor
 import os
 import re
 import shutil
@@ -292,27 +292,18 @@ def get_matched_sub_params_pattern(name_):
     return unmatched_patterns
 
 
-def _get_chunks(l, n):
-    for i in range(0, len(l), n):
-        yield l[i:i + n]
-
-
 def _do_parallel_work(do_work, work_chunks, num_workers):
+    results = []
     if num_workers > 1:
-        pool = multiprocessing.Pool(num_workers)
-        results = []
-        for batch in tqdm.tqdm(work_chunks):
-            res = pool.map(do_work, batch)
-            results.extend(res)
-        pool.close()
-        pool.join()
+        with ProcessPoolExecutor(max_workers=num_workers) as executor:
+            future_list = [executor.submit(do_work, work) for work in work_chunks]
+            for f in tqdm.tqdm(future_list):
+                results.append(f.result())
     else:
         # No parallel pass for unit testing
        # We can't create child processes in tests
-        results = []
-        for batch in tqdm.tqdm(work_chunks):
-            res = [do_work(x) for x in batch]
-            results.extend(res)
+        for work in tqdm.tqdm(work_chunks):
+            results.append(do_work(work))
     return results
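A note on the new branch: each work item is submitted as its own future, and tqdm iterates over the future list, so the progress bar advances as results land instead of once per chunk. A minimal, self-contained sketch of that pattern (toy square worker and toy values, not DeepSpeed code; assumes the third-party tqdm package is installed):

from concurrent.futures import ProcessPoolExecutor

import tqdm


def square(x):
    # Stand-in for the real do_work callable; must be module-level so it pickles.
    return x * x


def run_parallel(do_work, work_items, num_workers):
    results = []
    if num_workers > 1:
        with ProcessPoolExecutor(max_workers=num_workers) as executor:
            # One future per work item; no manual chunking step is needed.
            futures = [executor.submit(do_work, item) for item in work_items]
            for f in tqdm.tqdm(futures):
                results.append(f.result())  # blocks until that future finishes
    else:
        # Serial fallback, mirroring the unit-test path in the patch
        for item in tqdm.tqdm(work_items):
            results.append(do_work(item))
    return results


if __name__ == '__main__':  # guard required on platforms that spawn workers
    print(run_parallel(square, list(range(8)), num_workers=2))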


@@ -321,20 +312,15 @@ def _extract_zero_shard_files(args, ds_checkpoint, temp_dir):
     _3d_range_list = list(
         itertools.product(range(ds_checkpoint.pp_degree), range(ds_checkpoint.tp_degree),
                           range(ds_checkpoint.dp_degree)))
     #pprint(f'{_3d_range_list=}')
-    work_chunks = list(_get_chunks(_3d_range_list, args.num_extract_workers))
-    #pprint(f'{work_chunks=}')
 
-    # extract_zero_shards(temp_dir, ds_checkpoint, _3d_range_list[0])
     do_work = partial(extract_zero_shards, temp_dir, ds_checkpoint)
-    _do_parallel_work(do_work, work_chunks, args.num_extract_workers)
+    _do_parallel_work(do_work, _3d_range_list, args.num_extract_workers)
 
 
 def _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir):
-    work_chunks = list(_get_chunks(list(slice_shapes.items()), args.num_merge_workers))
-    #pprint(work_chunks)
     zero_output_folder = os.path.join(args.output_folder, "zero")
     do_work = partial(merge_tp_slices, ds_checkpoint, zero_output_folder, temp_dir, ds_checkpoint.tp_degree)
-    unmatched_patterns_lists = _do_parallel_work(do_work, work_chunks, args.num_merge_workers)
+    unmatched_patterns_lists = _do_parallel_work(do_work, list(slice_shapes.items()), args.num_merge_workers)
 
     # verify that all patterns were used
     # if a pattern was not used by any of the workers, then it was not used at all -> assert/alert
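The call sites simplify accordingly: with _get_chunks gone, the flat work list goes straight to _do_parallel_work, and functools.partial binds the fixed leading arguments once. A self-contained sketch of that shape (hypothetical extract_shard worker, toy parallel degrees, and a made-up temp path, standing in for the real extract_zero_shards call):

from concurrent.futures import ProcessPoolExecutor
from functools import partial
import itertools


def extract_shard(temp_dir, rank):
    # Hypothetical stand-in for extract_zero_shards(temp_dir, ds_checkpoint, rank)
    pp, tp, dp = rank  # pipeline-, tensor-, and data-parallel indices
    return f'{temp_dir}/pp{pp}_tp{tp}_dp{dp}'


if __name__ == '__main__':
    # The flat (pp, tp, dp) product is handed over whole; no batching step.
    ranks = list(itertools.product(range(2), range(2), range(2)))
    do_work = partial(extract_shard, '/tmp/universal')  # bind fixed args once
    with ProcessPoolExecutor(max_workers=2) as executor:
        for future in [executor.submit(do_work, r) for r in ranks]:
            print(future.result())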
