fix loading flax bf16 weights in pt (#14369)
* fix loading flax bf16 weights in pt

* fix clip test

* fix t5 test

* add logging statement

* Update src/transformers/modeling_flax_pytorch_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* switch back to native any

* fix check for bf16 weights

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
patil-suraj and patrickvonplaten authored Nov 11, 2021
1 parent 7f20bf0 commit 3d607df
Showing 4 changed files with 83 additions and 0 deletions.
14 changes: 14 additions & 0 deletions src/transformers/modeling_flax_pytorch_utils.py
@@ -21,6 +21,7 @@

import numpy as np

import jax
import jax.numpy as jnp
import transformers
from flax.serialization import from_bytes
@@ -189,6 +190,19 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
        )
        raise

    # check if we have bf16 weights
    is_type_bf16 = flatten_dict(jax.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values()
    if any(is_type_bf16):
        # convert all weights to fp32 if they are bf16, since torch.from_numpy cannot handle bf16
        # and bf16 is not fully supported in PT yet.
        logger.warning(
            "Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` "
            "before loading those in PyTorch model."
        )
        flax_state = jax.tree_map(
            lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state
        )

    flax_state_dict = flatten_dict(flax_state)
    pt_model_dict = pt_model.state_dict()
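The core of the fix is the detect-and-cast step above: torch.from_numpy has no bfloat16 support, so any bf16 leaves are upcast to float32 before conversion. A minimal, self-contained sketch of the same pattern outside transformers (a sketch only, assuming jax is installed; the params tree is a made-up example, not a real checkpoint):

import jax
import jax.numpy as jnp
import numpy as np

# hypothetical params tree with one bf16 leaf and one fp32 leaf
params = {
    "dense": {
        "kernel": jnp.ones((2, 2), dtype=jnp.bfloat16),
        "bias": jnp.zeros((2,), dtype=jnp.float32),
    }
}

# any() over the flattened leaves tells us whether a cast is needed at all
has_bf16 = any(leaf.dtype == jnp.bfloat16 for leaf in jax.tree_util.tree_leaves(params))

if has_bf16:
    # cast only the bfloat16 leaves; float32 leaves pass through untouched
    params = jax.tree_map(
        lambda p: p.astype(np.float32) if p.dtype == jnp.bfloat16 else p, params
    )

print(jax.tree_map(lambda p: p.dtype, params))
# {'dense': {'bias': dtype('float32'), 'kernel': dtype('float32')}}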
10 changes: 10 additions & 0 deletions tests/test_modeling_flax_clip.py
@@ -227,6 +227,11 @@ def test_save_load_from_base_pt(self):
    def test_save_load_to_base_pt(self):
        pass

    # FlaxCLIPVisionModel does not have any base model
    @is_pt_flax_cross_test
    def test_save_load_bf16_to_base_pt(self):
        pass

    @slow
    def test_model_from_pretrained(self):
        for model_class_name in self.all_model_classes:
@@ -332,6 +337,11 @@ def test_save_load_from_base_pt(self):
    def test_save_load_to_base_pt(self):
        pass

    # FlaxCLIPVisionModel does not have any base model
    @is_pt_flax_cross_test
    def test_save_load_bf16_to_base_pt(self):
        pass

    @slow
    def test_model_from_pretrained(self):
        for model_class_name in self.all_model_classes:
29 changes: 29 additions & 0 deletions tests/test_modeling_flax_common.py
@@ -384,6 +384,35 @@ def test_save_load_to_base_pt(self):
                    max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

    @is_pt_flax_cross_test
    def test_save_load_bf16_to_base_pt(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = FLAX_MODEL_MAPPING[config.__class__]

        for model_class in self.all_model_classes:
            if model_class == base_class:
                continue

            model = model_class(config)
            model.params = model.to_bf16(model.params)
            base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))

            # convert Flax model to PyTorch model
            pt_model_class = getattr(transformers, model_class.__name__[4:])  # Skip the "Flax" at the beginning
            pt_model = pt_model_class(config).eval()
            pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)

            # check that all base model weights are loaded correctly
            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_model.save_pretrained(tmpdirname)
                base_model = base_class.from_pretrained(tmpdirname, from_pt=True)

                base_params = flatten_dict(unfreeze(base_model.params))

                for key in base_params_from_head.keys():
                    max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

    def test_jit_compilation(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

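The assertions above key the flattened parameter dicts by tuple paths, which is what flax's flatten_dict and unfreeze produce. A tiny illustration (assuming flax is installed; the params dict is made up):

from flax.core.frozen_dict import freeze, unfreeze
from flax.traverse_util import flatten_dict

params = freeze({"encoder": {"layer_0": {"kernel": [[1.0]]}}})
flat = flatten_dict(unfreeze(params))

# keys are tuples giving the nested dict path, values are the leaf arrays
print(list(flat.keys()))  # [('encoder', 'layer_0', 'kernel')]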
30 changes: 30 additions & 0 deletions tests/test_modeling_flax_t5.py
@@ -430,6 +430,36 @@ def test_save_load_to_base_pt(self):
                    max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

    # overwrite since special base model prefix is used
    @is_pt_flax_cross_test
    def test_save_load_bf16_to_base_pt(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = FLAX_MODEL_MAPPING[config.__class__]

        for model_class in self.all_model_classes:
            if model_class == base_class:
                continue

            model = model_class(config)
            model.params = model.to_bf16(model.params)
            base_params_from_head = flatten_dict(unfreeze(model.params))

            # convert Flax model to PyTorch model
            pt_model_class = getattr(transformers, model_class.__name__[4:])  # Skip the "Flax" at the beginning
            pt_model = pt_model_class(config).eval()
            pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)

            # check that all base model weights are loaded correctly
            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_model.save_pretrained(tmpdirname)
                base_model = base_class.from_pretrained(tmpdirname, from_pt=True)

                base_params = flatten_dict(unfreeze(base_model.params))

                for key in base_params_from_head.keys():
                    max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")


@require_sentencepiece
@require_tokenizers
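Taken together, the change means a Flax checkpoint whose weights were cast to bfloat16 can now be loaded into the corresponding PyTorch model instead of failing inside torch.from_numpy. A rough end-to-end sketch of that path (the checkpoint name is only illustrative; assumes transformers with both the torch and flax dependencies installed):

import tempfile

from transformers import BertModel, FlaxBertModel

# any Flax model works here; "bert-base-cased" is just an example checkpoint
flax_model = FlaxBertModel.from_pretrained("bert-base-cased")
flax_model.params = flax_model.to_bf16(flax_model.params)  # force bf16 weights

with tempfile.TemporaryDirectory() as tmpdirname:
    flax_model.save_pretrained(tmpdirname)
    # with this fix, the bf16 weights are cast to float32 during conversion
    pt_model = BertModel.from_pretrained(tmpdirname, from_flax=True)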
