Skip to content

Commit

Permalink
Fixing requirements for TF LM models and use correct model mappings (#…
Browse files Browse the repository at this point in the history
…14372)

* Fixing requirements for TF LM models and use correct model mappings

* make style
  • Loading branch information
Rocketknight1 authored Nov 11, 2021
1 parent 4c35c8d commit 7f20bf0
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 6 deletions.
2 changes: 2 additions & 0 deletions examples/tensorflow/language-modeling/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
datasets >= 1.8.0
sentencepiece != 0.1.92
6 changes: 3 additions & 3 deletions examples/tensorflow/language-modeling/run_clm.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@
from transformers import (
CONFIG_MAPPING,
CONFIG_NAME,
MODEL_FOR_CAUSAL_LM_MAPPING,
TF2_WEIGHTS_NAME,
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
AutoConfig,
AutoTokenizer,
HfArgumentParser,
Expand All @@ -57,8 +57,8 @@


logger = logging.getLogger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
require_version("datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/language-modeling/requirements.txt")
MODEL_CONFIG_CLASSES = list(TF_MODEL_FOR_CAUSAL_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
# endregion

Expand Down
6 changes: 3 additions & 3 deletions examples/tensorflow/language-modeling/run_mlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@
from transformers import (
CONFIG_MAPPING,
CONFIG_NAME,
MODEL_FOR_MASKED_LM_MAPPING,
TF2_WEIGHTS_NAME,
TF_MODEL_FOR_MASKED_LM_MAPPING,
AutoConfig,
AutoTokenizer,
HfArgumentParser,
Expand All @@ -59,8 +59,8 @@


logger = logging.getLogger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
require_version("datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/language-modeling/requirements.txt")
MODEL_CONFIG_CLASSES = list(TF_MODEL_FOR_MASKED_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


Expand Down

0 comments on commit 7f20bf0

Please sign in to comment.