Skip to content
Merged
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
0d01e79
submit changes
ShijieZZZZ Mar 16, 2023
846c017
update format
ShijieZZZZ Mar 16, 2023
d81f494
Merge remote-tracking branch 'origin/master' into dev/shared_param
ShijieZZZZ Mar 23, 2023
b3503c3
Merge branch 'master' into dev/shared_param
tjruwase Mar 24, 2023
c9d8cd7
merge master
ShijieZZZZ Mar 30, 2023
8f8e435
fix fomrat
ShijieZZZZ Mar 30, 2023
cb48aa3
Merge branch 'master' into dev/shared_param
ShijieZZZZ Mar 31, 2023
77d0bdd
Merge remote-tracking branch 'remotes/origin/master' into dev/shared_…
ShijieZZZZ Mar 31, 2023
19292ff
revert
ShijieZZZZ Apr 1, 2023
23c9eb6
test
ShijieZZZZ Apr 1, 2023
570ce8d
add top
ShijieZZZZ Apr 1, 2023
55a63a1
Merge branch 'master' into dev/shared_param
tjruwase Apr 4, 2023
ac764af
Merge branch 'master' into dev/shared_param
tjruwase Apr 5, 2023
e20c60f
Merge remote-tracking branch 'remotes/origin/master' into dev/shared_…
ShijieZZZZ Apr 5, 2023
a40530a
treat z1 as z2
ShijieZZZZ Apr 5, 2023
9229c65
Merge remote-tracking branch 'origin/master' into dev/fp32
ShijieZZZZ Apr 5, 2023
7330cb0
Merge remote-tracking branch 'remotes/origin/master' into dev/fp32
ShijieZZZZ Apr 20, 2023
e4fe467
fix shared
ShijieZZZZ Apr 20, 2023
c01ebfd
Merge remote-tracking branch 'remotes/origin/master' into dev/fp32
ShijieZZZZ Apr 20, 2023
1f02e6b
remove old changes
ShijieZZZZ Apr 20, 2023
04a2290
Merge branch 'master' into dev/fp32
ShijieZZZZ Apr 21, 2023
7608832
Merge branch 'master' into dev/fp32
tjruwase Apr 24, 2023
7538583
Merge branch 'master' into dev/fp32
tjruwase Apr 25, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions deepspeed/utils/zero_to_fp32.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def parse_model_states(files):
shared_params = []
for param in state_dict["module"]:
if param not in [*param_names, *buffer_names]:
for share_param in state_dict["module"]:
for share_param in [*param_names, *buffer_names]:
if (state_dict["module"][share_param].data_ptr() == state_dict["module"][param].data_ptr()
and share_param != param):
shared_params.append([param, share_param])
Expand Down Expand Up @@ -340,7 +340,8 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zer

# recover shared parameters
for pair in zero_model_states[0].shared_params:
state_dict[pair[0]] = state_dict[pair[1]]
if pair[1] in state_dict:
state_dict[pair[0]] = state_dict[pair[1]]

return state_dict

Expand Down Expand Up @@ -457,7 +458,8 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zer

# recover shared parameters
for pair in zero_model_states[0].shared_params:
state_dict[pair[0]] = state_dict[pair[1]]
if pair[1] in state_dict:
state_dict[pair[0]] = state_dict[pair[1]]

return state_dict

Expand Down