Fix Flax params dtype (#13098)

* fix inits

* fix embed dtype

* fix embed dtype

* add test to check default dtype

* quality

* add type conversion methods for flax models

* more robust casting

* cast sinusoidal positions

* update pegasus

* update albert

* update test

* make sure dtype is passed to every module

* style

* fix electra dense

* fix t5

* quality

* add more tests

* better name

* use the dtype for lm head computation

* fix albert

* style

* fix albert embed dtype

* more tests

* fix vision enc-dec

* cleanup

* fix embed dtype pegasus

* fix default param test

* doc

* update template

* fix final_logits_bias dtype

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix doc

* fix doc

* add detailed docstring for dtype parameter

* remove un-necessary import

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
patil-suraj and patrickvonplaten authored Nov 11, 2021
1 parent 1c76a51 commit e92190c
Showing 23 changed files with 731 additions and 262 deletions.
@@ -50,13 +50,13 @@ def setup(self):
self.visual_projection = nn.Dense(
self.projection_dim,
dtype=self.dtype,
- kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
+ kernel_init=jax.nn.initializers.normal(0.02),
use_bias=False,
)
self.text_projection = nn.Dense(
self.projection_dim,
dtype=self.dtype,
- kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
+ kernel_init=jax.nn.initializers.normal(0.02),
use_bias=False,
)
self.logit_scale = self.param("logit_scale", jax.nn.initializers.ones, [])
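For context (not part of the commit), a short sketch of the behavior the hunk above changes: passing ``self.dtype`` to ``jax.nn.initializers.normal`` bakes that dtype into the initializer, so a module created with ``dtype=jnp.bfloat16`` would also get half-precision parameters. Dropping the argument keeps parameter initialization in fp32, leaving the module ``dtype`` to control only the computation.

# Illustrative sketch only; not part of the diff.
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

init_fp32 = jax.nn.initializers.normal(0.02)                      # default: fp32 parameters
init_bf16 = jax.nn.initializers.normal(0.02, dtype=jnp.bfloat16)  # would create bf16 parameters

print(init_fp32(key, (4, 4)).dtype)  # float32
print(init_bf16(key, (4, 4)).dtype)  # bfloat16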
131 changes: 130 additions & 1 deletion src/transformers/modeling_flax_utils.py
@@ -16,7 +16,7 @@
import os
from functools import partial
from pickle import UnpicklingError
- from typing import Dict, Set, Tuple, Union
+ from typing import Any, Dict, Set, Tuple, Union

import flax.linen as nn
import jax
@@ -154,6 +154,122 @@ def params(self, params: Union[Dict, FrozenDict]):
)
self._params = params

def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
"""
Helper method to cast floating-point values of given parameter ``PyTree`` to given ``dtype``.
"""

# taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
def conditional_cast(param):
if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
param = param.astype(dtype)
return param

if mask is None:
return jax.tree_map(conditional_cast, params)

flat_params = flatten_dict(params)
flat_mask, _ = jax.tree_flatten(mask)

for masked, key in zip(flat_mask, flat_params.keys()):
if masked:
param = flat_params[key]
flat_params[key] = conditional_cast(param)

return unflatten_dict(flat_params)
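
For context (not part of the commit), a minimal sketch of the masked-cast logic this helper implements, using ``flatten_dict``/``unflatten_dict`` from ``flax.traverse_util`` (the same utilities imported in ``modeling_flax_utils.py``) on a toy parameter tree:

# Illustrative sketch only; mirrors what _cast_floating_to does with a mask.
import jax.numpy as jnp
from flax.traverse_util import flatten_dict, unflatten_dict

params = {
    "dense": {"kernel": jnp.ones((2, 2)), "bias": jnp.zeros((2,))},
    "LayerNorm": {"scale": jnp.ones((2,)), "bias": jnp.zeros((2,))},
}

# Cast everything except the LayerNorm parameters to bfloat16.
flat = flatten_dict(params)
mask = {path: path[0] != "LayerNorm" for path in flat}
flat = {path: leaf.astype(jnp.bfloat16) if mask[path] else leaf for path, leaf in flat.items()}
params = unflatten_dict(flat)

print(params["dense"]["kernel"].dtype)     # bfloat16
print(params["LayerNorm"]["scale"].dtype)  # float32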

def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
r"""
Cast the floating-point ``params`` to ``jax.numpy.bfloat16``. This returns a new ``params`` tree and does not
cast the ``params`` in place.
This method can be used on TPU to explicitly convert the model parameters to bfloat16 precision to do full
half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.
Arguments:
params (:obj:`Union[Dict, FrozenDict]`):
A ``PyTree`` of model parameters.
mask (:obj:`Union[Dict, FrozenDict]`):
A ``PyTree`` with the same structure as the ``params`` tree. The leaves should be booleans: :obj:`True`
for params you want to cast, :obj:`False` for those you want to skip.
Examples::
>>> from transformers import FlaxBertModel
>>> # load model
>>> model = FlaxBertModel.from_pretrained('bert-base-cased')
>>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision
>>> model.params = model.to_bf16(model.params)
>>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
>>> # then pass the mask as follows
>>> from flax import traverse_util
>>> model = FlaxBertModel.from_pretrained('bert-base-cased')
>>> flat_params = traverse_util.flatten_dict(model.params)
>>> mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
>>> mask = traverse_util.unflatten_dict(mask)
>>> model.params = model.to_bf16(model.params, mask)
"""
return self._cast_floating_to(params, jnp.bfloat16, mask)

def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
r"""
Cast the floating-point ``params`` to ``jax.numpy.float32``. This method can be used to explicitly convert the
model parameters to fp32 precision. This returns a new ``params`` tree and does not cast the ``params`` in
place.
Arguments:
params (:obj:`Union[Dict, FrozenDict]`):
A ``PyTree`` of model parameters.
mask (:obj:`Union[Dict, FrozenDict]`):
A ``PyTree`` with the same structure as the ``params`` tree. The leaves should be booleans: :obj:`True`
for params you want to cast, :obj:`False` for those you want to skip.
Examples::
>>> from transformers import FlaxBertModel
>>> # Download model and configuration from huggingface.co
>>> model = FlaxBertModel.from_pretrained('bert-base-cased')
>>> # By default, the model params will be in fp32, to illustrate the use of this method,
>>> # we'll first cast to fp16 and back to fp32
>>> model.params = model.to_fp16(model.params)
>>> # now cast back to fp32
>>> model.params = model.to_fp32(model.params)
"""
return self._cast_floating_to(params, jnp.float32, mask)

def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
r"""
Cast the floating-point ``params`` to ``jax.numpy.float16``. This returns a new ``params`` tree and does not
cast the ``params`` in place.
This method can be used on GPU to explicitly convert the model parameters to float16 precision to do full
half-precision training or to save weights in float16 for inference in order to save memory and improve speed.
Arguments:
params (:obj:`Union[Dict, FrozenDict]`):
A ``PyTree`` of model parameters.
mask (:obj:`Union[Dict, FrozenDict]`):
A ``PyTree`` with the same structure as the ``params`` tree. The leaves should be booleans: :obj:`True`
for params you want to cast, :obj:`False` for those you want to skip.
Examples::
>>> from transformers import FlaxBertModel
>>> # Download model and configuration from huggingface.co
>>> model = FlaxBertModel.from_pretrained('bert-base-cased')
>>> # By default, the model params will be in fp32, to cast these to float16
>>> model.params = model.to_fp16(model.params)
>>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
>>> # then pass the mask as follows
>>> from flax import traverse_util
>>> model = FlaxBertModel.from_pretrained('bert-base-cased')
>>> flat_params = traverse_util.flatten_dict(model.params)
>>> mask = {path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
>>> mask = traverse_util.unflatten_dict(mask)
>>> model.params = model.to_fp16(model.params, mask)
"""
return self._cast_floating_to(params, jnp.float16, mask)
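
As a hedged usage sketch building on the methods above (the checkpoint name and save path are illustrative only): parameters can be cast to half precision and then saved, so that inference checkpoints take roughly half the memory.

# Illustrative sketch only; not part of the diff.
from transformers import FlaxBertModel

model = FlaxBertModel.from_pretrained("bert-base-cased")   # params are loaded in fp32
model.params = model.to_fp16(model.params)                 # cast the whole param tree to float16
model.save_pretrained("./bert-base-cased-fp16")            # save the fp16 weights for inference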

@classmethod
def from_pretrained(
cls,
@@ -184,6 +300,19 @@ def from_pretrained(
:func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
- A path or url to a `pt index checkpoint file` (e.g, ``./tf_model/model.ckpt.index``). In this
case, ``from_pt`` should be set to :obj:`True`.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified, all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and
:meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
model_args (sequence of positional arguments, `optional`):
All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
config (:obj:`Union[PretrainedConfig, str, os.PathLike]`, `optional`):
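
For illustration (not part of the commit), a minimal sketch of the new ``dtype`` argument documented above; it only sets the computation dtype, while the parameters stay in fp32 unless cast explicitly:

# Illustrative sketch only; the checkpoint name is just an example.
import jax.numpy as jnp
from transformers import FlaxBertModel

# dtype controls the computation dtype only; parameters are still created in fp32.
model = FlaxBertModel.from_pretrained("bert-base-cased", dtype=jnp.bfloat16)

# To also store the parameters in bfloat16, cast them explicitly.
model.params = model.to_bf16(model.params)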
31 changes: 20 additions & 11 deletions src/transformers/models/albert/modeling_flax_albert.py
@@ -105,6 +105,18 @@ class FlaxAlbertForPreTrainingOutput(ModelOutput):
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified, all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
"""

ALBERT_INPUTS_DOCSTRING = r"""
@@ -152,19 +164,16 @@ def setup(self):
self.config.vocab_size,
self.config.embedding_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
- dtype=self.dtype,
)
self.position_embeddings = nn.Embed(
self.config.max_position_embeddings,
self.config.embedding_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
- dtype=self.dtype,
)
self.token_type_embeddings = nn.Embed(
self.config.type_vocab_size,
self.config.embedding_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
- dtype=self.dtype,
)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
@@ -199,21 +208,21 @@ def setup(self):
self.query = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
- kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.key = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
- kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.value = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
- kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.dense = nn.Dense(
self.config.hidden_size,
- kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
@@ -278,13 +287,13 @@ def setup(self):
self.attention = FlaxAlbertSelfAttention(self.config, dtype=self.dtype)
self.ffn = nn.Dense(
self.config.intermediate_size,
- kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.activation = ACT2FN[self.config.hidden_act]
self.ffn_output = nn.Dense(
self.config.hidden_size,
- kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.full_layer_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
@@ -437,7 +446,7 @@ class FlaxAlbertEncoder(nn.Module):
def setup(self):
self.embedding_hidden_mapping_in = nn.Dense(
self.config.hidden_size,
- kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.albert_layer_groups = FlaxAlbertLayerGroups(self.config, dtype=self.dtype)
@@ -596,7 +605,7 @@ def setup(self):
if self.add_pooling_layer:
self.pooler = nn.Dense(
self.config.hidden_size,
- kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
name="pooler",
)
(Diffs for the remaining changed files are not shown here.)