Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
t5 init fix (#4897)
Browse files: browse the repository at this point in the history
  • Loading branch information
klshuster committed Nov 30, 2022
1 parent 07ba788 commit 3005c9e
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions parlai/agents/hugging_face/t5.py
Expand Up @@ -42,9 +42,18 @@ def build_t5(opt: Opt) -> T5ForConditionalGeneration:
if not check_hf_version(HF_VERSION):
raise RuntimeError('Must use transformers package >= 4.3 to use t5')
torch_dtype = torch.float16 if opt['fp16'] else torch.float32
return T5ForConditionalGeneration.from_pretrained(
opt['t5_model_arch'], dropout_rate=opt['t5_dropout'], torch_dtype=torch_dtype
)
try:
return T5ForConditionalGeneration.from_pretrained(
opt['t5_model_arch'],
dropout_rate=opt['t5_dropout'],
torch_dtype=torch_dtype,
)
except TypeError:
# it's not clear when HF added the `torch_dtype` option, but it is not
# available in 4.3.3, which is the earliest we support.
return T5ForConditionalGeneration.from_pretrained(
opt['t5_model_arch'], dropout_rate=opt['t5_dropout']
)


def set_device(func):
Expand Down

0 comments on commit 3005c9e

Please sign in to comment.