Add MultiBERTs conversion script (#13077)
* Init multibert checkpoint conversion script

* Rename conversion script

* Fix MultiBerts Conversion Script

* Apply suggestions from code review

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
3 people committed Sep 30, 2021
1 parent e1d1c7c commit 9a9805f
Showing 1 changed file with 128 additions and 0 deletions.
@@ -0,0 +1,128 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script can be used to convert a head-less TF 2.x MultiBERTs model to PyTorch, as published on the official GitHub:
https://github.com/tensorflow/models/tree/master/official/nlp/bert
"""

import argparse
import os

import tensorflow as tf
import torch

from transformers import BertConfig, BertForPreTraining
from transformers.utils import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)


def convert_multibert_checkpoint_to_pytorch(tf_checkpoint_path, config_path, save_path):
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")

    # Load weights from the TF checkpoint
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    config = BertConfig.from_pretrained(config_path)
    model = BertForPreTraining(config)

    # Collect every variable and record which encoder layers appear in the checkpoint
    layer_nums = []
    for full_name, shape in init_vars:
        array = tf.train.load_variable(tf_path, full_name)
        names.append(full_name)
        split_names = full_name.split("/")
        for name in split_names:
            if name.startswith("layer_"):
                layer_nums.append(int(name.split("_")[-1]))

        arrays.append(array)
    logger.info(f"Read a total of {len(arrays):,} variables")

    name_to_array = dict(zip(names, arrays))

    # Check that the number of encoder layers in the checkpoint matches the config
    assert config.num_hidden_layers == len(set(layer_nums))

    state_dict = model.state_dict()

    # Need to copy this over explicitly as it is a buffer, not a checkpoint weight
    position_ids = state_dict["bert.embeddings.position_ids"]
    new_state_dict = {"bert.embeddings.position_ids": position_ids}

    # Encoder layers
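    # The loop below renames each TF variable to its PyTorch counterpart before
    # copying the weights over. As an illustrative example (assuming the
    # standard BERT variable layout), the rules applied here map:
    #   "bert/encoder/layer_3/attention/output/dense/kernel"
    #     -> "bert.encoder.layer.3.attention.output.dense.weight"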
    for weight_name in names:
        pt_weight_name = weight_name.replace("kernel", "weight").replace("gamma", "weight").replace("beta", "bias")
        name_split = pt_weight_name.split("/")
        for name_idx, name in enumerate(name_split):
            if name.startswith("layer_"):
                name_split[name_idx] = name.replace("_", ".")

        if name_split[-1].endswith("embeddings"):
            name_split.append("weight")

        if name_split[0] == "cls":
            if name_split[-1] == "output_bias":
                name_split[-1] = "bias"
            if name_split[-1] == "output_weights":
                name_split[-1] = "weight"

        # TF stores dense kernels as (in_features, out_features); PyTorch expects the transpose
        if name_split[-1] == "weight" and name_split[-2] == "dense":
            name_to_array[weight_name] = name_to_array[weight_name].T

        pt_weight_name = ".".join(name_split)

        new_state_dict[pt_weight_name] = torch.from_numpy(name_to_array[weight_name])

    # The MLM decoder is tied to the input word embeddings; the bias is 1-D, so no transpose is needed
    new_state_dict["cls.predictions.decoder.weight"] = new_state_dict["bert.embeddings.word_embeddings.weight"].clone()
    new_state_dict["cls.predictions.decoder.bias"] = new_state_dict["cls.predictions.bias"].clone()

    # Load state dict
    model.load_state_dict(new_state_dict)

    # Save pretrained model
    logger.info(f"Saving pretrained model to {save_path}")
    model.save_pretrained(save_path)

    return model


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tf_checkpoint_path",
        type=str,
        default="./seed_0/bert.ckpt",
        required=False,
        help="Path to the TensorFlow 2.x checkpoint.",
    )
    parser.add_argument(
        "--bert_config_file",
        type=str,
        default="./bert_config.json",
        required=False,
        help="The config json file corresponding to the BERT model. This specifies the model architecture.",
    )
    parser.add_argument(
        "--save_path",
        type=str,
        required=True,
        help="Path to the output directory for the PyTorch model.",
    )
    args = parser.parse_args()

    convert_multibert_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.save_path)
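For reference, a typical invocation looks like the sketch below. The script filename and output directory are illustrative assumptions (the diff view above does not show the file's path in the repository):

python convert_multiberts_checkpoint_to_pytorch.py \
    --tf_checkpoint_path ./seed_0/bert.ckpt \
    --bert_config_file ./bert_config.json \
    --save_path ./multiberts_seed_0

The converted checkpoint can then be reloaded with BertForPreTraining.from_pretrained("./multiberts_seed_0").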
