Skip to content

Commit

Permalink
Tapas tf (#13393)
Browse files Browse the repository at this point in the history
* TF Tapas first commit

* updated docs

* updated logger message

* updated pytorch weight conversion
script to support scalar array

* added use_cache to tapas model config to
work properly with tf input_processing

* 1. rm embeddings_sum
2. added # Copied
3. + TFTapasMLMHead
4. and lot other small fixes

* updated docs

* + test for tapas

* updated testing_utils to check
is_tensorflow_probability_available

* converted model logits post processing using
numpy to work with both PT and TF models

* + TFAutoModelForTableQuestionAnswering

* added TF support

* added test for
TFAutoModelForTableQuestionAnswering

* added test for
TFAutoModelForTableQuestionAnswering pipeline

* updated auto model docs

* fixed typo in import

* added tensorflow_probability to run tests

* updated MLM head

* updated tapas.rst with TF  model docs

* fixed optimizer import in docs

* updated convert to np
data from pt model is not
`transformers.tokenization_utils_base.BatchEncoding`
after pipeline upgrade

* updated pipeline:
1. with torch.no_gard removed, pipeline forward handles
2. token_type_ids converted to numpy

* updated docs.

* removed `use_cache` from config

* removed floats_tensor

* updated code comment

* updated Copyright Year and
logits_aggregation Optional

* updated docs and comments

* updated docstring

* fixed model weight loading

* make fixup

* fix indentation

* added tf slow pipeline test

* pip upgrade

* upgrade python to 3.7

* removed from_pt from tests

* revert commit f18cfa9
  • Loading branch information
kamalkraj committed Nov 30, 2021
1 parent 6fc38ad commit c468a87
Show file tree
Hide file tree
Showing 19 changed files with 4,327 additions and 78 deletions.
8 changes: 7 additions & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ jobs:
run_tests_torch_and_tf:
working_directory: ~/transformers
docker:
- image: circleci/python:3.6
- image: circleci/python:3.7
environment:
OMP_NUM_THREADS: 1
RUN_PT_TF_CROSS_TESTS: yes
Expand All @@ -82,6 +82,7 @@ jobs:
- run: pip install --upgrade pip
- run: pip install .[sklearn,tf-cpu,torch,testing,sentencepiece,torch-speech,vision]
- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.10.0+cpu.html
- run: pip install tensorflow_probability
- save_cache:
key: v0.4-{{ checksum "setup.py" }}
paths:
Expand Down Expand Up @@ -118,6 +119,7 @@ jobs:
- run: pip install --upgrade pip
- run: pip install .[sklearn,tf-cpu,torch,testing,sentencepiece,torch-speech,vision]
- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.10.0+cpu.html
- run: pip install tensorflow_probability
- save_cache:
key: v0.4-{{ checksum "setup.py" }}
paths:
Expand Down Expand Up @@ -278,6 +280,7 @@ jobs:
- v0.4-{{ checksum "setup.py" }}
- run: pip install --upgrade pip
- run: pip install .[sklearn,tf-cpu,testing,sentencepiece,tf-speech,vision]
- run: pip install tensorflow_probability
- save_cache:
key: v0.4-tf-{{ checksum "setup.py" }}
paths:
Expand Down Expand Up @@ -311,6 +314,7 @@ jobs:
- v0.4-{{ checksum "setup.py" }}
- run: pip install --upgrade pip
- run: pip install .[sklearn,tf-cpu,testing,sentencepiece,tf-speech,vision]
- run: pip install tensorflow_probability
- save_cache:
key: v0.4-tf-{{ checksum "setup.py" }}
paths:
Expand Down Expand Up @@ -468,6 +472,7 @@ jobs:
- v0.4-{{ checksum "setup.py" }}
- run: pip install --upgrade pip
- run: pip install .[sklearn,tf-cpu,testing,sentencepiece]
- run: pip install tensorflow_probability
- save_cache:
key: v0.4-tf-{{ checksum "setup.py" }}
paths:
Expand Down Expand Up @@ -502,6 +507,7 @@ jobs:
- v0.4-{{ checksum "setup.py" }}
- run: pip install --upgrade pip
- run: pip install .[sklearn,tf-cpu,testing,sentencepiece]
- run: pip install tensorflow_probability
- save_cache:
key: v0.4-tf-{{ checksum "setup.py" }}
paths:
Expand Down
2 changes: 1 addition & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| T5 ||||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| TAPAS |||| ||
| TAPAS |||| ||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Transformer-XL ||||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
Expand Down
7 changes: 7 additions & 0 deletions docs/source/model_doc/auto.rst
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,13 @@ TFAutoModelForMultipleChoice
:members:


TFAutoModelForTableQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFAutoModelForTableQuestionAnswering
:members:


TFAutoModelForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
231 changes: 224 additions & 7 deletions docs/source/model_doc/tapas.rst

Large diffs are not rendered by default.

22 changes: 22 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1446,6 +1446,7 @@
"TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
"TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
"TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
"TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"TF_MODEL_MAPPING",
"TF_MODEL_WITH_LM_HEAD_MAPPING",
Expand All @@ -1458,6 +1459,7 @@
"TFAutoModelForQuestionAnswering",
"TFAutoModelForSeq2SeqLM",
"TFAutoModelForSequenceClassification",
"TFAutoModelForTableQuestionAnswering",
"TFAutoModelForTokenClassification",
"TFAutoModelWithLMHead",
]
Expand Down Expand Up @@ -1767,6 +1769,16 @@
"TFT5PreTrainedModel",
]
)
_import_structure["models.tapas"].extend(
[
"TF_TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFTapasForMaskedLM",
"TFTapasForQuestionAnswering",
"TFTapasForSequenceClassification",
"TFTapasModel",
"TFTapasPreTrainedModel",
]
)
_import_structure["models.transfo_xl"].extend(
[
"TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST",
Expand Down Expand Up @@ -3225,6 +3237,7 @@
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
TF_MODEL_MAPPING,
TF_MODEL_WITH_LM_HEAD_MAPPING,
Expand All @@ -3237,6 +3250,7 @@
TFAutoModelForQuestionAnswering,
TFAutoModelForSeq2SeqLM,
TFAutoModelForSequenceClassification,
TFAutoModelForTableQuestionAnswering,
TFAutoModelForTokenClassification,
TFAutoModelWithLMHead,
)
Expand Down Expand Up @@ -3483,6 +3497,14 @@
TFT5Model,
TFT5PreTrainedModel,
)
from .models.tapas import (
TF_TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST,
TFTapasForMaskedLM,
TFTapasForQuestionAnswering,
TFTapasForSequenceClassification,
TFTapasModel,
TFTapasPreTrainedModel,
)
from .models.transfo_xl import (
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
TFAdaptiveEmbedding,
Expand Down
19 changes: 19 additions & 0 deletions src/transformers/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,14 @@
_soundfile_available = False


_tensorflow_probability_available = importlib.util.find_spec("tensorflow_probability") is not None
try:
_tensorflow_probability_version = importlib_metadata.version("tensorflow_probability")
logger.debug(f"Successfully imported tensorflow-probability version {_tensorflow_probability_version}")
except importlib_metadata.PackageNotFoundError:
_tensorflow_probability_available = False


_timm_available = importlib.util.find_spec("timm") is not None
try:
_timm_version = importlib_metadata.version("timm")
Expand Down Expand Up @@ -444,6 +452,10 @@ def is_pytorch_quantization_available():
return _pytorch_quantization_available


def is_tensorflow_probability_available():
return _tensorflow_probability_available


def is_pandas_available():
return importlib.util.find_spec("pandas") is not None

Expand Down Expand Up @@ -629,6 +641,12 @@ def wrapper(*args, **kwargs):
`pip install pytorch-quantization --extra-index-url https://pypi.ngc.nvidia.com`
"""

# docstyle-ignore
TENSORFLOW_PROBABILITY_IMPORT_ERROR = """
{0} requires the tensorflow_probability library but it was not found in your environment. You can install it with pip as
explained here: https://github.com/tensorflow/probability.
"""


# docstyle-ignore
PANDAS_IMPORT_ERROR = """
Expand Down Expand Up @@ -684,6 +702,7 @@ def wrapper(*args, **kwargs):
("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)),
("sklearn", (is_sklearn_available, SKLEARN_IMPORT_ERROR)),
("speech", (is_speech_available, SPEECH_IMPORT_ERROR)),
("tensorflow_probability", (is_tensorflow_probability_available, TENSORFLOW_PROBABILITY_IMPORT_ERROR)),
("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)),
("timm", (is_timm_available, TIMM_IMPORT_ERROR)),
("tokenizers", (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)),
Expand Down
4 changes: 3 additions & 1 deletion src/transformers/modeling_tf_pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,9 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
raise e

# logger.warning(f"Initialize PyTorch weight {pt_weight_name}")

# Make sure we have a proper numpy array
if numpy.isscalar(array):
array = numpy.array(array)
new_pt_params_dict[pt_weight_name] = torch.from_numpy(array)
loaded_pt_weights_data_ptr[pt_weight.data_ptr()] = torch.from_numpy(array)
all_tf_weights.discard(pt_weight_name)
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/models/auto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
"TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
"TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
"TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
"TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"TF_MODEL_MAPPING",
"TF_MODEL_WITH_LM_HEAD_MAPPING",
Expand All @@ -95,6 +96,7 @@
"TFAutoModelForQuestionAnswering",
"TFAutoModelForSeq2SeqLM",
"TFAutoModelForSequenceClassification",
"TFAutoModelForTableQuestionAnswering",
"TFAutoModelForTokenClassification",
"TFAutoModelWithLMHead",
]
Expand Down Expand Up @@ -189,6 +191,7 @@
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
TF_MODEL_MAPPING,
TF_MODEL_WITH_LM_HEAD_MAPPING,
Expand All @@ -201,6 +204,7 @@
TFAutoModelForQuestionAnswering,
TFAutoModelForSeq2SeqLM,
TFAutoModelForSequenceClassification,
TFAutoModelForTableQuestionAnswering,
TFAutoModelForTokenClassification,
TFAutoModelWithLMHead,
)
Expand Down
27 changes: 27 additions & 0 deletions src/transformers/models/auto/modeling_tf_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
("funnel", ("TFFunnelModel", "TFFunnelBaseModel")),
("dpr", "TFDPRQuestionEncoder"),
("mpnet", "TFMPNetModel"),
("tapas", "TFTapasModel"),
("mbart", "TFMBartModel"),
("marian", "TFMarianModel"),
("pegasus", "TFPegasusModel"),
Expand Down Expand Up @@ -92,6 +93,7 @@
("xlm", "TFXLMWithLMHeadModel"),
("ctrl", "TFCTRLLMHeadModel"),
("electra", "TFElectraForPreTraining"),
("tapas", "TFTapasForMaskedLM"),
("funnel", "TFFunnelForPreTraining"),
("mpnet", "TFMPNetForMaskedLM"),
]
Expand Down Expand Up @@ -124,6 +126,7 @@
("xlm", "TFXLMWithLMHeadModel"),
("ctrl", "TFCTRLLMHeadModel"),
("electra", "TFElectraForMaskedLM"),
("tapas", "TFTapasForMaskedLM"),
("funnel", "TFFunnelForMaskedLM"),
("mpnet", "TFMPNetForMaskedLM"),
]
Expand Down Expand Up @@ -172,6 +175,7 @@
("flaubert", "TFFlaubertWithLMHeadModel"),
("xlm", "TFXLMWithLMHeadModel"),
("electra", "TFElectraForMaskedLM"),
("tapas", "TFTapasForMaskedLM"),
("funnel", "TFFunnelForMaskedLM"),
("mpnet", "TFMPNetForMaskedLM"),
]
Expand Down Expand Up @@ -215,6 +219,7 @@
("flaubert", "TFFlaubertForSequenceClassification"),
("xlm", "TFXLMForSequenceClassification"),
("electra", "TFElectraForSequenceClassification"),
("tapas", "TFTapasForSequenceClassification"),
("funnel", "TFFunnelForSequenceClassification"),
("gpt2", "TFGPT2ForSequenceClassification"),
("mpnet", "TFMPNetForSequenceClassification"),
Expand Down Expand Up @@ -249,6 +254,14 @@
]
)

TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
# Model for Table Question Answering mapping
("tapas", "TFTapasForQuestionAnswering"),
]
)


TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Token Classification mapping
Expand Down Expand Up @@ -323,6 +336,9 @@
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
)
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES
)
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
)
Expand Down Expand Up @@ -402,6 +418,17 @@ class TFAutoModelForQuestionAnswering(_BaseAutoModelClass):
TFAutoModelForQuestionAnswering = auto_class_update(TFAutoModelForQuestionAnswering, head_doc="question answering")


class TFAutoModelForTableQuestionAnswering(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING


TFAutoModelForTableQuestionAnswering = auto_class_update(
TFAutoModelForTableQuestionAnswering,
head_doc="table question answering",
checkpoint_for_example="google/tapas-base-finetuned-wtq",
)


class TFAutoModelForTokenClassification(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING

Expand Down
22 changes: 21 additions & 1 deletion src/transformers/models/tapas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from typing import TYPE_CHECKING

from ...file_utils import _LazyModule, is_torch_available
from ...file_utils import _LazyModule, is_tf_available, is_torch_available


_import_structure = {
Expand All @@ -36,6 +36,15 @@
"TapasPreTrainedModel",
"load_tf_weights_in_tapas",
]
if is_tf_available():
_import_structure["modeling_tf_tapas"] = [
"TF_TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFTapasForMaskedLM",
"TFTapasForQuestionAnswering",
"TFTapasForSequenceClassification",
"TFTapasModel",
"TFTapasPreTrainedModel",
]


if TYPE_CHECKING:
Expand All @@ -53,6 +62,17 @@
load_tf_weights_in_tapas,
)

if is_tf_available():
from .modeling_tf_tapas import (
TF_TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST,
TFTapasForMaskedLM,
TFTapasForQuestionAnswering,
TFTapasForSequenceClassification,
TFTapasModel,
TFTapasPreTrainedModel,
)


else:
import sys

Expand Down
Loading

0 comments on commit c468a87

Please sign in to comment.