Skip to content

Commit

Permalink
deepspeed save model temp fix (#374)
Browse files Browse the repository at this point in the history
* fix deepspeed model saving

* fix deepspeed zero stage-3 model save

fixes #369

Co-Authored-By: Kovvuri Satyanarayana Reddy <54667784+KOVVURISATYANARAYANAREDDY@users.noreply.github.com>

Co-authored-by: Kovvuri Satyanarayana Reddy <54667784+KOVVURISATYANARAYANAREDDY@users.noreply.github.com>
  • Loading branch information
pacman100 and KOVVURISATYANARAYANAREDDY committed May 19, 2022
1 parent d33dc39 commit 6163e20
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,9 +886,10 @@ def get_state_dict(self, model):
model = self.unwrap_model(model)
state_dict = model.state_dict()

for k in state_dict:
if state_dict[k].dtype == torch.float16:
state_dict[k] = state_dict[k].float()
if state_dict is not None:
for k in state_dict:
if state_dict[k].dtype == torch.float16:
state_dict[k] = state_dict[k].float()

return state_dict

Expand Down

0 comments on commit 6163e20

Please sign in to comment.