Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
eeb2dfe
Initial commit
stancld Feb 9, 2022
daec2ee
Add TFGPTJModel
stancld Feb 9, 2022
148326d
[WIP] Fix some basic issues
stancld Feb 10, 2022
1e0cf12
Fix a forward pass
stancld Feb 11, 2022
b8bee61
Merge branch 'master' into tf_gpt-j
stancld Feb 11, 2022
ff06081
Add TFGPTJCausalLM
stancld Feb 11, 2022
d9ea032
Add TFGPTJForSequenceClassification
stancld Feb 11, 2022
342a9dd
Add TFGPTJForQuestionAnswering
stancld Feb 11, 2022
f824b25
Fix docs
stancld Feb 11, 2022
a5b0fd4
make fix-copies
stancld Feb 11, 2022
089f4ad
Merge branch 'master' into tf_gpt-j
stancld Feb 11, 2022
d839d10
Add models into the auto factory
stancld Feb 11, 2022
05556c1
Fix shape_list import in a test file
stancld Feb 11, 2022
2866c49
make style
stancld Feb 11, 2022
0eff264
Merge branch 'master' into tf_gpt-j
stancld Feb 12, 2022
25ffba9
Merge branch 'master' into tf_gpt-j
stancld Feb 19, 2022
91e409f
Fix - Unable to create link (name already exists)
stancld Feb 19, 2022
37ea8d8
Fix model compilation
stancld Feb 19, 2022
8302737
Deal with TF dynamic shapes
stancld Feb 19, 2022
ec87443
Add Loss parents to models
stancld Feb 19, 2022
aec1a32
Fix imports
stancld Feb 19, 2022
66c7dc8
Update keys to ignore + fix scale_attn assignment
stancld Feb 19, 2022
76b6174
Fix PT-TF equivalence
stancld Feb 19, 2022
00b4854
Define product of list items due to python<=3.7
stancld Feb 19, 2022
81c1c73
Merge branch 'master' into tf_gpt-j
stancld Mar 11, 2022
1327179
Apply some suggestions from code review
stancld Mar 11, 2022
34e4e44
Remove one import added by unintentionally IDE
stancld Mar 11, 2022
7aec58d
Apply some suggestions from code review
stancld Mar 11, 2022
5332030
[WIP] Add some new slow/tooslow tests
stancld Mar 11, 2022
ad18a5c
.
stancld Mar 11, 2022
878db21
Merge branch 'master' into tf_gpt-j
stancld Mar 11, 2022
730dba4
Add token_type_ids to prepare_inputs_for_generation
stancld Mar 11, 2022
7442fab
Add a type hint and a test
stancld Mar 11, 2022
be9ca28
Move test_batch_generation among TFGPTJModelLanguageGenerationTest tests
stancld Mar 11, 2022
2a12082
Merge remote-tracking branch 'upstream/master' into tf_gpt-j
stancld Mar 11, 2022
947031e
Merge branch 'main' into tf_gpt-j
stancld Mar 25, 2022
25e67de
Resolve some remaining issues
stancld Mar 25, 2022
b101109
Merge remote-tracking branch 'upstream/main' into tf_gpt-j
stancld Mar 25, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/index.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ Flax), PyTorch, and/or TensorFlow.
| Funnel Transformer | ✅ | ✅ | ✅ | ✅ | ❌ |
| GLPN | ❌ | ❌ | ✅ | ❌ | ❌ |
| GPT Neo | ❌ | ❌ | ✅ | ❌ | ✅ |
| GPT-J | ❌ | ❌ | ✅ | ❌ | ✅ |
| GPT-J | ❌ | ❌ | ✅ | ✅ | ✅ |
| Hubert | ❌ | ❌ | ✅ | ✅ | ❌ |
| I-BERT | ❌ | ❌ | ✅ | ❌ | ❌ |
| ImageGPT | ❌ | ❌ | ✅ | ❌ | ❌ |
Expand Down
20 changes: 20 additions & 0 deletions docs/source/model_doc/gptj.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,26 @@ model.
[[autodoc]] GPTJForQuestionAnswering
- forward

## TFGPTJModel

[[autodoc]] TFGPTJModel
- call

## TFGPTJForCausalLM

[[autodoc]] TFGPTJForCausalLM
- call

## TFGPTJForSequenceClassification

[[autodoc]] TFGPTJForSequenceClassification
- call

## TFGPTJForQuestionAnswering

[[autodoc]] TFGPTJForQuestionAnswering
- call

## FlaxGPTJModel

[[autodoc]] FlaxGPTJModel
Expand Down
16 changes: 16 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1929,6 +1929,15 @@
"TFGPT2PreTrainedModel",
]
)
_import_structure["models.gptj"].extend(
[
"TFGPTJForCausalLM",
"TFGPTJForQuestionAnswering",
"TFGPTJForSequenceClassification",
"TFGPTJModel",
"TFGPTJPreTrainedModel",
]
)
_import_structure["models.hubert"].extend(
[
"TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
Expand Down Expand Up @@ -4003,6 +4012,13 @@
TFGPT2Model,
TFGPT2PreTrainedModel,
)
from .models.gptj import (
TFGPTJForCausalLM,
TFGPTJForQuestionAnswering,
TFGPTJForSequenceClassification,
TFGPTJModel,
TFGPTJPreTrainedModel,
)
from .models.hubert import (
TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFHubertForCTC,
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/models/auto/modeling_tf_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
("bert", "TFBertModel"),
("openai-gpt", "TFOpenAIGPTModel"),
("gpt2", "TFGPT2Model"),
("gptj", "TFGPTJModel"),
("mobilebert", "TFMobileBertModel"),
("transfo-xl", "TFTransfoXLModel"),
("xlnet", "TFXLNetModel"),
Expand Down Expand Up @@ -123,6 +124,7 @@
("bert", "TFBertForMaskedLM"),
("openai-gpt", "TFOpenAIGPTLMHeadModel"),
("gpt2", "TFGPT2LMHeadModel"),
("gptj", "TFGPTJForCausalLM"),
("mobilebert", "TFMobileBertForMaskedLM"),
("transfo-xl", "TFTransfoXLLMHeadModel"),
("xlnet", "TFXLNetLMHeadModel"),
Expand All @@ -146,6 +148,7 @@
("bert", "TFBertLMHeadModel"),
("openai-gpt", "TFOpenAIGPTLMHeadModel"),
("gpt2", "TFGPT2LMHeadModel"),
("gptj", "TFGPTJForCausalLM"),
("transfo-xl", "TFTransfoXLLMHeadModel"),
("xlnet", "TFXLNetLMHeadModel"),
("xlm", "TFXLMWithLMHeadModel"),
Expand Down Expand Up @@ -239,6 +242,7 @@
("tapas", "TFTapasForSequenceClassification"),
("funnel", "TFFunnelForSequenceClassification"),
("gpt2", "TFGPT2ForSequenceClassification"),
("gptj", "TFGPTJForSequenceClassification"),
("mpnet", "TFMPNetForSequenceClassification"),
("openai-gpt", "TFOpenAIGPTForSequenceClassification"),
("transfo-xl", "TFTransfoXLForSequenceClassification"),
Expand Down Expand Up @@ -267,6 +271,7 @@
("xlm", "TFXLMForQuestionAnsweringSimple"),
("electra", "TFElectraForQuestionAnswering"),
("funnel", "TFFunnelForQuestionAnswering"),
("gptj", "TFGPTJForQuestionAnswering"),
("mpnet", "TFMPNetForQuestionAnswering"),
]
)
Expand Down
20 changes: 19 additions & 1 deletion src/transformers/models/gptj/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule, is_flax_available, is_torch_available
from ...utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available


_import_structure = {
Expand All @@ -34,6 +34,15 @@
"GPTJPreTrainedModel",
]

if is_tf_available():
_import_structure["modeling_tf_gptj"] = [
"TFGPTJForCausalLM",
"TFGPTJForQuestionAnswering",
"TFGPTJForSequenceClassification",
"TFGPTJModel",
"TFGPTJPreTrainedModel",
]

if is_flax_available():
_import_structure["modeling_flax_gptj"] = [
"FlaxGPTJForCausalLM",
Expand All @@ -55,6 +64,15 @@
GPTJPreTrainedModel,
)

if is_tf_available():
from .modeling_tf_gptj import (
TFGPTJForCausalLM,
TFGPTJForQuestionAnswering,
TFGPTJForSequenceClassification,
TFGPTJModel,
TFGPTJPreTrainedModel,
)

if is_flax_available():
from .modeling_flax_gptj import FlaxGPTJForCausalLM, FlaxGPTJModel, FlaxGPTJPreTrainedModel

Expand Down
Loading