Fix Flax params dtype #13098

Merged
36 commits merged on Nov 11, 2021

Commits (36)
0a4e66d
fix inits
patil-suraj Aug 12, 2021
e568b02
fix embed dtype
patil-suraj Aug 12, 2021
f3fc704
fix embed dtype
patil-suraj Aug 12, 2021
142a665
add test to check default dtype
patil-suraj Aug 12, 2021
d59bab1
quality
patil-suraj Aug 12, 2021
1e5feea
add type conversion methods for flax models
patil-suraj Aug 12, 2021
f6f92b0
more robust casting
patil-suraj Sep 14, 2021
a7e1ddc
cast sinusoidal positions
patil-suraj Sep 15, 2021
d3209c4
update pegasus
patil-suraj Sep 15, 2021
3a92bc8
update albert
patil-suraj Sep 15, 2021
622f215
update test
patil-suraj Sep 15, 2021
2f0370d
make sure dtype is passed to every module
patil-suraj Sep 15, 2021
9f62f57
style
patil-suraj Sep 16, 2021
cc6e655
fix electra dense
patil-suraj Sep 16, 2021
a804880
fix t5
patil-suraj Sep 16, 2021
ae2b6d1
quality
patil-suraj Sep 16, 2021
633c0a8
add more tests
patil-suraj Sep 16, 2021
65c2455
better name
patil-suraj Sep 16, 2021
9dfb565
use the dtype for lm head computation
patil-suraj Sep 16, 2021
623b47f
Merge branch 'master' into fix-flax-dtype
patil-suraj Nov 10, 2021
fa39008
fix albert
patil-suraj Nov 10, 2021
6c861aa
style
patil-suraj Nov 10, 2021
f594571
fix albert embed dtype
patil-suraj Nov 10, 2021
084b7cb
more tests
patil-suraj Nov 10, 2021
9384382
fix vision enc-dec
patil-suraj Nov 10, 2021
4b66ae4
cleanup
patil-suraj Nov 10, 2021
52716d6
fix embed dtype pegasus
patil-suraj Nov 10, 2021
8e0572b
fix default param test
patil-suraj Nov 10, 2021
347f287
doc
patil-suraj Nov 10, 2021
d26afc0
update template
patil-suraj Nov 10, 2021
2b4c001
fix final_logits_bias dtype
patil-suraj Nov 10, 2021
785a3e5
Apply suggestions from code review
patil-suraj Nov 11, 2021
a17463b
fix doc
patil-suraj Nov 11, 2021
2a993df
fix doc
patil-suraj Nov 11, 2021
6dcffb8
add detailed docstring for dtype parameter
patil-suraj Nov 11, 2021
fe8614c
remove un-necessary import
patil-suraj Nov 11, 2021
@@ -50,13 +50,13 @@ def setup(self):
self.visual_projection = nn.Dense(
self.projection_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
kernel_init=jax.nn.initializers.normal(0.02),
use_bias=False,
)
self.text_projection = nn.Dense(
self.projection_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
kernel_init=jax.nn.initializers.normal(0.02),
use_bias=False,
)
self.logit_scale = self.param("logit_scale", jax.nn.initializers.ones, [])
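A minimal sketch (not part of the diff; toy shapes, recent Flax assumed) of why the dtype argument is dropped from the kernel initializer: the initializer's dtype determines how the parameter is stored, while the module-level dtype only controls the precision of the computation, so the stored kernel stays in float32 even when the module runs in half precision.

# Sketch only: toy shapes; exact initializer defaults can differ across Flax versions.
import jax
import jax.numpy as jnp
import flax.linen as nn

dense = nn.Dense(
    4,
    dtype=jnp.bfloat16,  # computation dtype
    kernel_init=jax.nn.initializers.normal(0.02),  # no dtype -> kernel stored in fp32
    use_bias=False,
)
params = dense.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))
print(params["params"]["kernel"].dtype)             # float32
print(dense.apply(params, jnp.ones((1, 8))).dtype)  # bfloat16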
111 changes: 110 additions & 1 deletion src/transformers/modeling_flax_utils.py
@@ -16,7 +16,7 @@
import os
from functools import partial
from pickle import UnpicklingError
from typing import Dict, Set, Tuple, Union
from typing import Any, Dict, Set, Tuple, Union

import flax.linen as nn
import jax
@@ -154,6 +154,115 @@ def params(self, params: Union[Dict, FrozenDict]):
)
self._params = params

def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
"""
Helper method to cast the floating-point values in the given ``params`` tree to the given ``dtype``.
"""

# taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
def conditional_cast(param):
if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
param = param.astype(dtype)
return param

if mask is None:
return jax.tree_map(conditional_cast, params)

flat_params = flatten_dict(params)
flat_mask, _ = jax.tree_flatten(mask)

for masked, key in zip(flat_mask, flat_params.keys()):
if masked:
param = flat_params[key]
flat_params[key] = conditional_cast(param)

return unflatten_dict(flat_params)

def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
r"""
Cast the floating-point ``params`` to ``jax.numpy.bfloat16``. This method can be used to explicitly convert the
model parameters to ``bfloat16``.

Arguments:
params (:obj:`Union[Dict, FrozenDict]`):
A ``PyTree`` of model parameters.
mask (:obj:`Union[Dict, FrozenDict]`):
A ``PyTree`` with the same structure as the ``params`` tree. The leaves should be booleans: :obj:`True` for
params you want to cast, :obj:`False` for those you want to skip.

Examples::

>>> from transformers import FlaxBertModel
>>> # Download model and configuration from huggingface.co
>>> model = FlaxBertModel.from_pretrained('bert-base-cased')
>>> # By default, the model parameters will be in fp32; to cast them to bfloat16
>>> params = model.to_bf16(model.params)
>>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
>>> # then pass the mask as follows
>>> from flax import traverse_util
>>> model = FlaxBertModel.from_pretrained('bert-base-cased')
>>> flat_params = traverse_util.flatten_dict(model.params)
>>> mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
>>> mask = traverse_util.unflatten_dict(mask)
>>> params = model.to_bf16(model.params, mask)
"""
return self._cast_floating_to(params, jnp.bfloat16, mask)

def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
r"""
Cast the floating-point ``params`` to ``jax.numpy.float32``. This method can be used to explicitly convert the
model parameters to ``float32``.

Arguments:
params (:obj:`Union[Dict, FrozenDict]`):
A ``PyTree`` of model parameters.
mask (:obj:`Union[Dict, FrozenDict]`):
A ``PyTree`` with the same structure as the ``params`` tree. The leaves should be booleans: :obj:`True` for
params you want to cast, :obj:`False` for those you want to skip.

Examples::

>>> from transformers import FlaxBertModel
>>> # Download model and configuration from huggingface.co
>>> model = FlaxBertModel.from_pretrained('bert-base-cased')
>>> # By default, the model parameters will be in fp32; to illustrate the use of this method,
>>> # we'll first cast to fp16 and back to fp32
>>> params = model.to_fp16(model.params)
>>> # now cast back to fp32
>>> params = model.to_fp32(params)
"""
return self._cast_floating_to(params, jnp.float32, mask)

def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
r"""
Cast the floating-point ``params`` to ``jax.numpy.float16``. This method can be used to explicitly convert the
model parameters to ``float16``.

Arguments:
params (:obj:`Union[Dict, FrozenDict]`):
A ``PyTree`` of model parameters.
mask (:obj:`Union[Dict, FrozenDict]`):
A ``PyTree`` with the same structure as the ``params`` tree. The leaves should be booleans: :obj:`True` for
params you want to cast, :obj:`False` for those you want to skip.

Examples::

>>> from transformers import FlaxBertModel
>>> # Download model and configuration from huggingface.co
>>> model = FlaxBertModel.from_pretrained('bert-base-cased')
>>> # By default, the model parameters will be in fp32; to cast them to float16
>>> params = model.to_fp16(model.params)
>>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
>>> # then pass the mask as follows
>>> from flax import traverse_util
>>> model = FlaxBertModel.from_pretrained('bert-base-cased')
>>> flat_params = traverse_util.flatten_dict(model.params)
>>> mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
>>> mask = traverse_util.unflatten_dict(mask)
>>> params = model.to_fp16(model.params, mask)
"""
return self._cast_floating_to(params, jnp.float16, mask)

@classmethod
def from_pretrained(
cls,
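As a standalone illustration of the new casting helpers (hypothetical toy dicts; it mirrors the logic of _cast_floating_to rather than calling the model API): only floating-point leaves selected by the mask are cast, while masked-out leaves and non-floating leaves keep their dtype.

# Standalone sketch of the masked cast; the PR itself uses the older
# jax.tree_map / jax.tree_flatten aliases and operates on the model's params.
import jax
import jax.numpy as jnp
from flax.traverse_util import flatten_dict, unflatten_dict

def cast_floating_to(params, dtype, mask=None):
    def conditional_cast(param):
        # only floating-point arrays are touched
        if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
            param = param.astype(dtype)
        return param

    if mask is None:
        return jax.tree_util.tree_map(conditional_cast, params)

    flat_params = flatten_dict(params)
    flat_mask, _ = jax.tree_util.tree_flatten(mask)
    for masked, key in zip(flat_mask, flat_params.keys()):
        if masked:
            flat_params[key] = conditional_cast(flat_params[key])
    return unflatten_dict(flat_params)

params = {"layer": {"bias": jnp.zeros(2), "kernel": jnp.ones((2, 2))}, "step": jnp.array(0)}
mask = {"layer": {"bias": False, "kernel": True}, "step": True}
casted = cast_floating_to(params, jnp.bfloat16, mask)
print(casted["layer"]["kernel"].dtype)  # bfloat16 (masked in, floating-point)
print(casted["layer"]["bias"].dtype)    # float32  (masked out)
print(casted["step"].dtype)             # int32    (non-floating, never cast)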
19 changes: 8 additions & 11 deletions src/transformers/models/albert/modeling_flax_albert.py
@@ -152,19 +152,16 @@ def setup(self):
self.config.vocab_size,
self.config.embedding_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.position_embeddings = nn.Embed(
self.config.max_position_embeddings,
self.config.embedding_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.token_type_embeddings = nn.Embed(
self.config.type_vocab_size,
self.config.embedding_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
@@ -199,21 +196,21 @@ def setup(self):
self.query = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.key = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.value = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
@@ -278,13 +275,13 @@ def setup(self):
self.attention = FlaxAlbertSelfAttention(self.config, dtype=self.dtype)
self.ffn = nn.Dense(
self.config.intermediate_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.activation = ACT2FN[self.config.hidden_act]
self.ffn_output = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.full_layer_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
@@ -437,7 +434,7 @@ class FlaxAlbertEncoder(nn.Module):
def setup(self):
self.embedding_hidden_mapping_in = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.albert_layer_groups = FlaxAlbertLayerGroups(self.config, dtype=self.dtype)
@@ -596,7 +593,7 @@ def setup(self):
if self.add_pooling_layer:
self.pooler = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
name="pooler",
)
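The same reasoning applies to the embedding layers here and in the Bart changes below: dtype is no longer forwarded to nn.Embed, so the embedding table itself is created in full precision. A small sketch with a toy vocabulary (defaults can vary between Flax releases):

# Sketch only: toy vocabulary; param-dtype defaults differ across Flax versions.
import jax
import jax.numpy as jnp
import flax.linen as nn

embed = nn.Embed(
    num_embeddings=10,
    features=4,
    embedding_init=jax.nn.initializers.normal(stddev=0.02),  # no dtype forwarded
)
variables = embed.init(jax.random.PRNGKey(0), jnp.array([1, 2, 3]))
print(variables["params"]["embedding"].dtype)  # float32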
40 changes: 19 additions & 21 deletions src/transformers/models/bart/modeling_flax_bart.py
@@ -248,7 +248,7 @@ def setup(self) -> None:
self.embed_dim,
use_bias=self.bias,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)

self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
@@ -404,6 +404,7 @@ def setup(self) -> None:
embed_dim=self.embed_dim,
num_heads=self.config.encoder_attention_heads,
dropout=self.config.attention_dropout,
dtype=self.dtype,
)
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -412,10 +413,10 @@ def setup(self) -> None:
self.fc1 = nn.Dense(
self.config.encoder_ffn_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)

@@ -514,6 +515,7 @@ def setup(self) -> None:
num_heads=self.config.decoder_attention_heads,
dropout=self.config.attention_dropout,
causal=True,
dtype=self.dtype,
)
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
self.activation_fn = ACT2FN[self.config.activation_function]
@@ -525,15 +527,16 @@ def setup(self) -> None:
embed_dim=self.embed_dim,
num_heads=self.config.decoder_attention_heads,
dropout=self.config.attention_dropout,
dtype=self.dtype,
)
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.fc1 = nn.Dense(
self.config.encoder_ffn_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)

@@ -668,13 +671,13 @@ class FlaxBartClassificationHead(nn.Module):

def setup(self):
self.dense = nn.Dense(
self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
self.dropout = nn.Dropout(rate=self.pooler_dropout)
self.out_proj = nn.Dense(
self.num_classes,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)

def __call__(self, hidden_states: jnp.ndarray, deterministic: bool):
@@ -703,8 +706,7 @@ def setup(self):
self.embed_tokens = nn.Embed(
self.config.vocab_size,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)

# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
@@ -713,8 +715,7 @@ def setup(self):
self.embed_positions = nn.Embed(
self.config.max_position_embeddings + self.offset,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype)
@@ -776,8 +777,7 @@ def setup(self):
self.embed_tokens = nn.Embed(
self.config.vocab_size,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)

# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
@@ -786,8 +786,7 @@ def setup(self):
self.embed_positions = nn.Embed(
self.config.max_position_embeddings + self.offset,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)

self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
@@ -850,8 +849,7 @@ def setup(self):
self.shared = nn.Embed(
self.config.vocab_size,
self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)

self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
@@ -1256,7 +1254,7 @@ def setup(self):
self.model.shared.num_embeddings,
use_bias=False,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings))

@@ -1300,7 +1298,7 @@ def __call__(
else:
lm_logits = self.lm_head(hidden_states)

lm_logits += self.final_logits_bias
lm_logits += self.final_logits_bias.astype(self.dtype)

if not return_dict:
output = (lm_logits,) + outputs[1:]
@@ -1416,7 +1414,7 @@ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_
else:
lm_logits = module.lm_head(hidden_states)

lm_logits += module.final_logits_bias
lm_logits += module.final_logits_bias.astype(self.dtype)
return lm_logits, outputs

outputs = self.module.apply(
@@ -1647,7 +1645,7 @@ class FlaxBartForQuestionAnsweringModule(nn.Module):
def setup(self):
self.model = FlaxBartModule(config=self.config, dtype=self.dtype)
self.qa_outputs = nn.Dense(
self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)

def _get_encoder_module(self):
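The .astype(self.dtype) added to final_logits_bias avoids an implicit upcast of the logits; a short sketch of the promotion behavior it guards against (placeholder shapes, not taken from the PR):

# Sketch of the dtype promotion the cast avoids (placeholder shapes).
import jax.numpy as jnp

lm_logits = jnp.zeros((1, 4), dtype=jnp.bfloat16)         # half-precision logits
final_logits_bias = jnp.zeros((1, 4), dtype=jnp.float32)  # bias kept in fp32

print((lm_logits + final_logits_bias).dtype)                          # float32: silent upcast
print((lm_logits + final_logits_bias.astype(lm_logits.dtype)).dtype)  # bfloat16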