Fix Flax params dtype #13098

Merged: 36 commits, Nov 11, 2021

Commits (all by patil-suraj)
0a4e66d  fix inits (Aug 12, 2021)
e568b02  fix embed dtype (Aug 12, 2021)
f3fc704  fix embed dtype (Aug 12, 2021)
142a665  add test to check default dtype (Aug 12, 2021)
d59bab1  quality (Aug 12, 2021)
1e5feea  add type conversion methods for flax models (Aug 12, 2021)
f6f92b0  more robust casting (Sep 14, 2021)
a7e1ddc  cast sinusoidal positions (Sep 15, 2021)
d3209c4  update pegasus (Sep 15, 2021)
3a92bc8  update albert (Sep 15, 2021)
622f215  update test (Sep 15, 2021)
2f0370d  make sure dtype is passed to every module (Sep 15, 2021)
9f62f57  style (Sep 16, 2021)
cc6e655  fix electra dense (Sep 16, 2021)
a804880  fix t5 (Sep 16, 2021)
ae2b6d1  quality (Sep 16, 2021)
633c0a8  add more tests (Sep 16, 2021)
65c2455  better name (Sep 16, 2021)
9dfb565  use the dtype for lm head computation (Sep 16, 2021)
623b47f  Merge branch 'master' into fix-flax-dtype (Nov 10, 2021)
fa39008  fix albert (Nov 10, 2021)
6c861aa  style (Nov 10, 2021)
f594571  fix albert embed dtype (Nov 10, 2021)
084b7cb  more tests (Nov 10, 2021)
9384382  fix vision enc-dec (Nov 10, 2021)
4b66ae4  cleanup (Nov 10, 2021)
52716d6  fix embed dtype pegasus (Nov 10, 2021)
8e0572b  fix default param test (Nov 10, 2021)
347f287  doc (Nov 10, 2021)
d26afc0  update template (Nov 10, 2021)
2b4c001  fix final_logits_bias dtype (Nov 10, 2021)
785a3e5  Apply suggestions from code review (Nov 11, 2021)
a17463b  fix doc (Nov 11, 2021)
2a993df  fix doc (Nov 11, 2021)
6dcffb8  add detailed docstring for dtype parameter (Nov 11, 2021)
fe8614c  remove un-necessary import (Nov 11, 2021)
@@ -50,13 +50,13 @@ def setup(self):
self.visual_projection = nn.Dense(
self.projection_dim,
dtype=self.dtype,
- kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
+ kernel_init=jax.nn.initializers.normal(0.02),
use_bias=False,
)
self.text_projection = nn.Dense(
self.projection_dim,
dtype=self.dtype,
- kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
+ kernel_init=jax.nn.initializers.normal(0.02),
use_bias=False,
)
self.logit_scale = self.param("logit_scale", jax.nn.initializers.ones, [])
src/transformers/modeling_flax_utils.py (130 additions, 1 deletion)
@@ -16,7 +16,7 @@
import os
from functools import partial
from pickle import UnpicklingError
- from typing import Dict, Set, Tuple, Union
+ from typing import Any, Dict, Set, Tuple, Union

import flax.linen as nn
import jax
@@ -154,6 +154,122 @@ def params(self, params: Union[Dict, FrozenDict]):
)
self._params = params

def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
"""
Helper method to cast the floating-point values of a given parameter ``PyTree`` to the given ``dtype``.
"""

# taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
def conditional_cast(param):
if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
param = param.astype(dtype)
return param

if mask is None:
return jax.tree_map(conditional_cast, params)

flat_params = flatten_dict(params)
flat_mask, _ = jax.tree_flatten(mask)

for masked, key in zip(flat_mask, flat_params.keys()):
if masked:
param = flat_params[key]
flat_params[key] = conditional_cast(param)

return unflatten_dict(flat_params)

def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
r"""
Cast the floating-point ``params`` to ``jax.numpy.bfloat16``. This returns a new ``params`` tree and does not
cast the ``params`` in place.

This method can be used on TPU to explicitly convert the model parameters to bfloat16 precision to do full
half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.

Arguments:
params (:obj:`Union[Dict, FrozenDict]`):
A ``PyTree`` of model parameters.
mask (:obj:`Union[Dict, FrozenDict]`):
A ``PyTree`` with the same structure as the ``params`` tree. The leaves should be booleans, :obj:`True` for
params you want to cast and :obj:`False` for those you want to skip.

Examples::

>>> from transformers import FlaxBertModel
>>> # load model
>>> model = FlaxBertModel.from_pretrained('bert-base-cased')
>>> # By default, the model parameters are in fp32 precision; to cast them to bfloat16 precision
>>> model.params = model.to_bf16(model.params)
>>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
>>> # then pass the mask as follows
>>> from flax import traverse_util
>>> model = FlaxBertModel.from_pretrained('bert-base-cased')
>>> flat_params = traverse_util.flatten_dict(model.params)
>>> mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
>>> mask = traverse_util.unflatten_dict(mask)
>>> model.params = model.to_bf16(model.params, mask)
"""
return self._cast_floating_to(params, jnp.bfloat16, mask)

def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
r"""
Cast the floating-point ``params`` to ``jax.numpy.float32``. This method can be used to explicitly convert the
model parameters to fp32 precision. This returns a new ``params`` tree and does not cast the ``params`` in
place.

Arguments:
params (:obj:`Union[Dict, FrozenDict]`):
A ``PyTree`` of model parameters.
mask (:obj:`Union[Dict, FrozenDict]`):
A ``PyTree`` with the same structure as the ``params`` tree. The leaves should be booleans, :obj:`True` for
params you want to cast and :obj:`False` for those you want to skip.

Examples::

>>> from transformers import FlaxBertModel
>>> # Download model and configuration from huggingface.co
>>> model = FlaxBertModel.from_pretrained('bert-base-cased')
>>> # By default, the model params are in fp32; to illustrate the use of this method,
>>> # we'll first cast to fp16 and back to fp32
>>> model.params = model.to_fp16(model.params)
>>> # now cast back to fp32
>>> model.params = model.to_fp32(model.params)
"""
return self._cast_floating_to(params, jnp.float32, mask)

def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
r"""
Cast the floating-point ``params`` to ``jax.numpy.float16``. This returns a new ``params`` tree and does not
cast the ``params`` in place.

This method can be used on GPU to explicitly convert the model parameters to float16 precision to do full
half-precision training or to save weights in float16 for inference in order to save memory and improve speed.

Arguments:
params (:obj:`Union[Dict, FrozenDict]`):
A ``PyTree`` of model parameters.
mask (:obj:`Union[Dict, FrozenDict]`):
A ``PyTree`` with the same structure as the ``params`` tree. The leaves should be booleans, :obj:`True` for
params you want to cast and :obj:`False` for those you want to skip.

Examples::

>>> from transformers import FlaxBertModel
>>> # Download model and configuration from huggingface.co
>>> model = FlaxBertModel.from_pretrained('bert-base-cased')
>>> # By default, the model params are in fp32; to cast them to float16
>>> model.params = model.to_fp16(model.params)
>>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
>>> # then pass the mask as follows
>>> from flax import traverse_util
>>> model = FlaxBertModel.from_pretrained('bert-base-cased')
>>> flat_params = traverse_util.flatten_dict(model.params)
>>> mask = {path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
>>> mask = traverse_util.unflatten_dict(mask)
>>> model.params = model.to_fp16(model.params, mask)
"""
return self._cast_floating_to(params, jnp.float16, mask)

@classmethod
def from_pretrained(
cls,
@@ -184,6 +300,19 @@ def from_pretrained(
:func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
- A path or url to a `pt index checkpoint file` (e.g, ``./tf_model/model.ckpt.index``). In this
case, ``from_pt`` should be set to :obj:`True`.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).

This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified, all the computation will be performed with the given ``dtype``.

**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**

If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and
:meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
model_args (sequence of positional arguments, `optional`):
All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
config (:obj:`Union[PretrainedConfig, str, os.PathLike]`, `optional`):
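Taken together, the new docstring and the casting helpers split responsibilities cleanly: the ``dtype`` passed to ``from_pretrained`` only selects the computation dtype, while ``to_fp16``/``to_bf16`` are what actually cast the parameter tree. A minimal usage sketch (my own illustration based on the docstrings above, not part of the diff):

```python
import jax
import jax.numpy as jnp
from transformers import FlaxBertModel

# dtype only selects the computation dtype; the loaded parameters stay in float32
model = FlaxBertModel.from_pretrained("bert-base-cased", dtype=jnp.bfloat16)
print({p.dtype for p in jax.tree_util.tree_leaves(model.params)})  # all float32

# to_bf16 returns a new tree with every floating-point leaf cast to bfloat16
model.params = model.to_bf16(model.params)
print({p.dtype for p in jax.tree_util.tree_leaves(model.params)})  # all bfloat16
```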
src/transformers/models/albert/modeling_flax_albert.py (20 additions, 11 deletions)
@@ -105,6 +105,18 @@ class FlaxAlbertForPreTrainingOutput(ModelOutput):
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).

This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified, all the computation will be performed with the given ``dtype``.

**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**

If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
"""

ALBERT_INPUTS_DOCSTRING = r"""
@@ -152,19 +164,16 @@ def setup(self):
self.config.vocab_size,
self.config.embedding_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
- dtype=self.dtype,
)
self.position_embeddings = nn.Embed(
self.config.max_position_embeddings,
self.config.embedding_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
- dtype=self.dtype,
)
self.token_type_embeddings = nn.Embed(
self.config.type_vocab_size,
self.config.embedding_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
- dtype=self.dtype,
)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
@@ -199,21 +208,21 @@ def setup(self):
self.query = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
- kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.key = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
- kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.value = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
- kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.dense = nn.Dense(
self.config.hidden_size,
- kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
@@ -278,13 +287,13 @@ def setup(self):
self.attention = FlaxAlbertSelfAttention(self.config, dtype=self.dtype)
self.ffn = nn.Dense(
self.config.intermediate_size,
- kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.activation = ACT2FN[self.config.hidden_act]
self.ffn_output = nn.Dense(
self.config.hidden_size,
- kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.full_layer_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
@@ -437,7 +446,7 @@ class FlaxAlbertEncoder(nn.Module):
def setup(self):
self.embedding_hidden_mapping_in = nn.Dense(
self.config.hidden_size,
- kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.albert_layer_groups = FlaxAlbertLayerGroups(self.config, dtype=self.dtype)
@@ -596,7 +605,7 @@ def setup(self):
if self.add_pooling_layer:
self.pooler = nn.Dense(
self.config.hidden_size,
- kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
name="pooler",
)
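The modeling changes above all follow one recipe: keep passing ``dtype`` to each layer so the forward pass can run in half precision, but drop the ``dtype`` argument from the weight initializers so the parameters themselves are created in float32. A small standalone sketch (my own illustration with a bare ``flax.linen.Dense``, not code from this PR) of the behaviour that recipe relies on:

```python
import flax.linen as nn
import jax
import jax.numpy as jnp

# dtype controls the computation; the initializer is left at its float32 default,
# so the stored kernel is created in full precision.
dense = nn.Dense(
    features=4,
    dtype=jnp.bfloat16,                            # half-precision computation
    kernel_init=jax.nn.initializers.normal(0.02),  # no dtype argument -> fp32 parameters
)

variables = dense.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))
print(variables["params"]["kernel"].dtype)             # float32: parameters keep full precision
print(dense.apply(variables, jnp.ones((1, 8))).dtype)  # bfloat16: computation runs in `dtype`
```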