Add MultiBERTs conversion script (#13077)

* Init multibert checkpoint conversion script
* Rename conversion script
* Fix MultiBerts Conversion Script
* Apply suggestions from code review

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
1 parent e1d1c7c · commit 9a9805f
Showing 1 changed file with 128 additions and 0 deletions.
128 changes: 128 additions & 0 deletions
src/transformers/models/bert/convert_multiberts_checkpoint_to_pytorch.py

@@ -0,0 +1,128 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" | ||
This script can be used to convert a head-less TF 2.x MultiBERTs model to PyTorch, as published on the official GitHub: | ||
https://github.com/tensorflow/models/tree/master/official/nlp/bert | ||
""" | ||

import argparse
import os

import tensorflow as tf
import torch

from transformers import BertConfig, BertForPreTraining
from transformers.utils import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)


def convert_multibert_checkpoint_to_pytorch(tf_checkpoint_path, config_path, save_path):
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")

    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    config = BertConfig.from_pretrained(config_path)
    model = BertForPreTraining(config)
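
    # Collect the encoder layer indices ("layer_0", "layer_1", ...) that appear in
    # the variable names, so the checkpoint depth can be checked against the config.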
    layer_nums = []
    for full_name, shape in init_vars:
        array = tf.train.load_variable(tf_path, full_name)
        names.append(full_name)
        split_names = full_name.split("/")
        for name in split_names:
            if name.startswith("layer_"):
                layer_nums.append(int(name.split("_")[-1]))

        arrays.append(array)
    logger.info(f"Read a total of {len(arrays):,} variables")

    name_to_array = dict(zip(names, arrays))

    # Check that the number of encoder layers matches the config
    assert config.num_hidden_layers == len(set(layer_nums))

    state_dict = model.state_dict()

    # position_ids is a buffer rather than a learned weight, so copy it over explicitly
    position_ids = state_dict["bert.embeddings.position_ids"]
    new_state_dict = {"bert.embeddings.position_ids": position_ids}

    # Encoder layers: translate each TF variable name to its PyTorch counterpart.
    # TF uses "kernel"/"gamma"/"beta" where PyTorch BERT uses "weight"/"weight"/"bias"
    # (dense kernels and LayerNorm parameters, respectively).
    for weight_name in names:
        pt_weight_name = weight_name.replace("kernel", "weight").replace("gamma", "weight").replace("beta", "bias")
        name_split = pt_weight_name.split("/")
        for name_idx, name in enumerate(name_split):
            if name.startswith("layer_"):
                # "layer_3" -> "layer.3" so it matches the PyTorch module path
                name_split[name_idx] = name.replace("_", ".")

        if name_split[-1].endswith("embeddings"):
            name_split.append("weight")

        if name_split[0] == "cls":
            if name_split[-1] == "output_bias":
                name_split[-1] = "bias"
            if name_split[-1] == "output_weights":
                name_split[-1] = "weight"
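
        # TF stores dense kernels as [in_features, out_features], while torch.nn.Linear
        # weights are [out_features, in_features], so dense kernels are transposed.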
        if name_split[-1] == "weight" and name_split[-2] == "dense":
            name_to_array[weight_name] = name_to_array[weight_name].T

        pt_weight_name = ".".join(name_split)

        new_state_dict[pt_weight_name] = torch.from_numpy(name_to_array[weight_name])
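
    # The MLM decoder weight is tied to the word embeddings, so the TF checkpoint has
    # no separate decoder matrix; the decoder bias mirrors cls.predictions.bias.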
new_state_dict["cls.predictions.decoder.weight"] = new_state_dict["bert.embeddings.word_embeddings.weight"].clone() | ||
new_state_dict["cls.predictions.decoder.bias"] = new_state_dict["cls.predictions.bias"].clone().T | ||
# Load State Dict | ||
model.load_state_dict(new_state_dict) | ||
|
||
# Save PreTrained | ||
logger.info(f"Saving pretrained model to {save_path}") | ||
model.save_pretrained(save_path) | ||
|
||
return model | ||
|
||
|
||

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tf_checkpoint_path",
        type=str,
        default="./seed_0/bert.ckpt",
        required=False,
        help="Path to the TensorFlow 2.x checkpoint.",
    )
    parser.add_argument(
        "--bert_config_file",
        type=str,
        default="./bert_config.json",
        required=False,
        help="The config json file corresponding to the BERT model. This specifies the model architecture.",
    )
    parser.add_argument(
        "--save_path",
        type=str,
        required=True,
        help="Directory where the converted PyTorch model is saved.",
    )
    args = parser.parse_args()

    convert_multibert_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.save_path)