Skip to content

Commit

Permalink
Merge pull request #5668 from siddhu001/Fix_transformers_load
Browse files Browse the repository at this point in the history
Fix loading pre-trained model from transformers
  • Loading branch information
mergify[bot] committed Feb 21, 2024
2 parents a50d6a0 + 8ca64cb commit 6834444
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions espnet2/tasks/abs_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2089,8 +2089,20 @@ def build_model_from_file(
}
model.load_state_dict(state_dict, strict=not use_lora)
else:
raise
if any(["postdecoder" in k for k in state_dict.keys()]):
model.load_state_dict(
state_dict,
strict=False,
)
else:
raise
else:
raise
if any(["postdecoder" in k for k in state_dict.keys()]):
model.load_state_dict(
state_dict,
strict=False,
)
else:
raise

return model, args

0 comments on commit 6834444

Please sign in to comment.