
Wrong weight initialization for TF t5 model #13327

Closed
2 of 3 tasks
danshirron opened this issue Aug 30, 2021 · 3 comments

@danshirron
Contributor

Environment info

  - `transformers` version: 4.9.2
  - Platform: Linux-4.15.0-142-generic-x86_64-with-glibc2.29
  - Python version: 3.8.10
  - PyTorch version (GPU?): not installed (NA)
  - Tensorflow version (GPU?): 2.5.0 (True)
  - Flax version (CPU?/GPU?/TPU?): 0.3.4 (gpu)
  - Jax version: 0.2.18
  - JaxLib version: 0.1.69
  - Using GPU in script?: Yes
  - Using distributed or parallel set-up in script?: yes

Who can help

@patil-suraj
@patrickvonplaten

Information

Model I am using: T5-base (pre-training)

The problem arises when using:

  • the official example scripts: (give details below)
  • my own modified scripts: (give details below): added the T5 data collator and the Keras Adafactor optimizer to run_mlm.py

The task I am working on is:

  • my own task or dataset: pre-training T5-base on the OSCAR dataset (as in the FLAX example)

Expected behavior

Before updating the init weights to a normal distribution (as in transformers/src/transformers/models/t5/modeling_flax_t5.py), the loss was stuck at 4.5 (unlike the FLAX behaviour). After updating the init weights, I get the same behaviour as in FLAX and reach a loss below 2.

Example:
In the Flax code (class FlaxT5DenseReluDense, lines 95-96):
wi_init_std = self.config.initializer_factor * (self.config.d_model ** -0.5)
wo_init_std = self.config.initializer_factor * (self.config.d_ff ** -0.5)
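
For context, the Flax module wires these stds into zero-mean normal kernel initializers. A minimal standalone sketch, paraphrasing rather than quoting modeling_flax_t5.py, with hypothetical T5-base-like values (d_model=768, d_ff=3072, initializer_factor defaults to 1.0 in T5Config):

import jax
import flax.linen as nn

# Hypothetical T5-base-like values for illustration.
d_model, d_ff, initializer_factor = 768, 3072, 1.0
wi_init_std = initializer_factor * (d_model ** -0.5)
wo_init_std = initializer_factor * (d_ff ** -0.5)

# Each Dense layer gets a normal kernel initializer with the matching std.
wi = nn.Dense(d_ff, use_bias=False, kernel_init=jax.nn.initializers.normal(wi_init_std))
wo = nn.Dense(d_model, use_bias=False, kernel_init=jax.nn.initializers.normal(wo_init_std))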

In the TF code, the default initializer is used. My suggested fix:
wi_initializer = tf.keras.initializers.RandomNormal(mean=0.0, stddev=config.initializer_factor * (config.d_model ** -0.5))
wo_initializer = tf.keras.initializers.RandomNormal(mean=0.0, stddev=config.initializer_factor * (config.d_ff ** -0.5))
self.wi = tf.keras.layers.Dense(config.d_ff, use_bias=False, name="wi", kernel_initializer=wi_initializer)
self.wo = tf.keras.layers.Dense(config.d_model, use_bias=False, name="wo", kernel_initializer=wo_initializer)

This applies to the initialization of all weights and embeddings.
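
As a quick sanity check, here is a minimal self-contained sketch of the suggested fix (hypothetical config values standing in for a real T5Config; T5-base uses d_model=768, d_ff=3072) that builds the two layers and compares the empirical stddev of their kernels with the expected values:

import tensorflow as tf

# Hypothetical values for illustration; initializer_factor defaults to 1.0 in T5Config.
d_model, d_ff, initializer_factor = 768, 3072, 1.0

wi_initializer = tf.keras.initializers.RandomNormal(mean=0.0, stddev=initializer_factor * (d_model ** -0.5))
wo_initializer = tf.keras.initializers.RandomNormal(mean=0.0, stddev=initializer_factor * (d_ff ** -0.5))
wi = tf.keras.layers.Dense(d_ff, use_bias=False, name="wi", kernel_initializer=wi_initializer)
wo = tf.keras.layers.Dense(d_model, use_bias=False, name="wo", kernel_initializer=wo_initializer)

# Build the layers so the kernels are created, then compare empirical vs. expected stddev.
wi.build((None, d_model))
wo.build((None, d_ff))
print(float(tf.math.reduce_std(wi.kernel)), d_model ** -0.5)  # both ~0.0361
print(float(tf.math.reduce_std(wo.kernel)), d_ff ** -0.5)     # both ~0.0180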

@patrickvonplaten
Contributor

I agree! Would you like to open a PR to fix it? :-)

@danshirron
Contributor Author

Will try to do it in the coming days.

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this as completed Oct 7, 2021
dshirron pushed a commit to dshirron/transformers that referenced this issue Nov 2, 2021
Rocketknight1 pushed a commit that referenced this issue Nov 3, 2021

Fix of issue #13327: Wrong weight initialization for TF t5 model (#14241)

* Fix of issue #13327: Wrong weight initialization for TF t5 model
* run black formatter
* fix typo
* remove my name tag from comments

Co-authored-by: Shirron <dan.shirron@intel.com>

LysandreJik pushed a commit that referenced this issue Nov 16, 2021
Albertobegue pushed a commit to Albertobegue/transformers that referenced this issue Jan 27, 2022