
[Wav2Vec2] Add New Wav2Vec2 Translation #14392

Merged · 14 commits · Nov 17, 2021
@@ -0,0 +1,356 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Wav2Vec2 checkpoint."""


import argparse

import fairseq
import torch
from torch import nn

from transformers import (
    MBart50Tokenizer,
    MBartConfig,
    MBartForCausalLM,
    SpeechEncoderDecoderConfig,
    SpeechEncoderDecoderModel,
    Wav2Vec2Config,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2Model,
    logging,
)


logging.set_verbosity_info()
logger = logging.get_logger(__name__)

MAPPING = {
    "post_extract_proj": "feature_projection.projection",
    "encoder.pos_conv.0": "encoder.pos_conv_embed.conv",
    "self_attn.k_proj": "encoder.layers.*.attention.k_proj",
    "self_attn.v_proj": "encoder.layers.*.attention.v_proj",
    "self_attn.q_proj": "encoder.layers.*.attention.q_proj",
    "self_attn.out_proj": "encoder.layers.*.attention.out_proj",
    "self_attn_layer_norm": "encoder.layers.*.layer_norm",
    "fc1": "encoder.layers.*.feed_forward.intermediate_dense",
    "fc2": "encoder.layers.*.feed_forward.output_dense",
    "final_layer_norm": "encoder.layers.*.final_layer_norm",
    "encoder.layer_norm": "encoder.layer_norm",
    "w2v_model.layer_norm": "feature_projection.layer_norm",
    "quantizer.weight_proj": "quantizer.weight_proj",
    "quantizer.vars": "quantizer.codevectors",
    "project_q": "project_q",
    "final_proj": "project_hid",
    "w2v_encoder.proj": "lm_head",
    "mask_emb": "masked_spec_embed",
}
TOP_LEVEL_KEYS = [
    "lm_head",
    "quantizer.weight_proj",
    "quantizer.codevectors",
    "project_q",
    "project_hid",
]
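
# Wildcard entries ("*") are filled with the encoder layer index when the weights are
# loaded below. For example, the fairseq key "encoder.layers.3.self_attn.k_proj.weight"
# matches the "self_attn.k_proj" entry above, yielding the HF target
# "encoder.layers.3.attention.k_proj" with weight_type "weight".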


def set_recursively(hf_pointer, key, value, full_name, weight_type):
    for attribute in key.split("."):
        hf_pointer = getattr(hf_pointer, attribute)

    if weight_type is not None:
        hf_shape = getattr(hf_pointer, weight_type).shape
    else:
        hf_shape = hf_pointer.shape

    assert (
        hf_shape == value.shape
    ), f"Shape of hf {key + '.' + weight_type if weight_type is not None else key} is {hf_shape}, but should be {value.shape} for {full_name}"

    if weight_type == "weight":
        hf_pointer.weight.data = value
    elif weight_type == "weight_g":
        hf_pointer.weight_g.data = value
    elif weight_type == "weight_v":
        hf_pointer.weight_v.data = value
    elif weight_type == "bias":
        hf_pointer.bias.data = value
    else:
        hf_pointer.data = value

    logger.info(f"{key + '.' + weight_type if weight_type is not None else key} was initialized from {full_name}.")


def recursively_load_weights_wav2vec2(fairseq_model, hf_model):
    unused_weights = []
    fairseq_dict = fairseq_model.state_dict()

    feature_extractor = hf_model.feature_extractor
    adaptor = hf_model.adaptor

    for name, value in fairseq_dict.items():
        is_used = False
        if "conv_layers" in name:
            load_conv_layer(
                name,
                value,
                feature_extractor,
                unused_weights,
                hf_model.config.feat_extract_norm == "group",
            )
            is_used = True
        elif any(x in name for x in ["adaptor", "w2v_encoder.proj.", "w2v_proj_ln."]):
            load_adaptor(
                name,
                value,
                adaptor,
                unused_weights,
            )
            is_used = True
        else:
            for key, mapped_key in MAPPING.items():
                if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]:
                    is_used = True
                    if "*" in mapped_key:
                        layer_index = name.split(key)[0].split(".")[-2]
                        mapped_key = mapped_key.replace("*", layer_index)
                    if "weight_g" in name:
                        weight_type = "weight_g"
                    elif "weight_v" in name:
                        weight_type = "weight_v"
                    elif "bias" in name:
                        weight_type = "bias"
                    elif "weight" in name:
                        weight_type = "weight"
                    else:
                        weight_type = None
                    set_recursively(hf_model, mapped_key, value, name, weight_type)
                continue
        if not is_used:
            unused_weights.append(name)

    logger.warning(f"Unused weights: {unused_weights}")


def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_group_norm):
    name = full_name.split("conv_layers.")[-1]
    items = name.split(".")
    layer_id = int(items[0])
    type_id = int(items[1])

    if type_id == 0:
        if "bias" in name:
            assert (
                value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape
            ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
            feature_extractor.conv_layers[layer_id].conv.bias.data = value
            logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
        elif "weight" in name:
            assert (
                value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape
            ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
            feature_extractor.conv_layers[layer_id].conv.weight.data = value
            logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
    elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm):
        if "bias" in name:
            assert (
                value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape
            ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape} was found."
            feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
            logger.info(f"Feat extract layer norm bias of layer {layer_id} was initialized from {full_name}.")
        elif "weight" in name:
            assert (
                value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape
            ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape} was found."
            feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
            logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
    else:
        unused_weights.append(full_name)


def load_adaptor(full_name, value, adaptor, unused_weights):
    name = full_name.split("adaptor.")[-1]
    items = name.split(".")

    if items[1].isdigit():
        layer_id = int(items[1])
    else:
        layer_id = None

    if "adaptor" not in full_name:
        if "proj_ln" in full_name:
            # has to be layer norm
            if "bias" in name:
                assert (
                    value.shape == adaptor.proj_layer_norm.bias.data.shape
                ), f"{full_name} has size {value.shape}, but {adaptor.proj_layer_norm.bias.data.shape} was found."
                adaptor.proj_layer_norm.bias.data = value
                logger.info(f"Adapter proj layer norm bias was initialized from {full_name}.")
            if "weight" in name:
                assert (
                    value.shape == adaptor.proj_layer_norm.weight.data.shape
                ), f"{full_name} has size {value.shape}, but {adaptor.proj_layer_norm.weight.data.shape} was found."
                adaptor.proj_layer_norm.weight.data = value
                logger.info(f"Adapter proj layer norm weight was initialized from {full_name}.")
        else:
            # has to be projection layer
            if "bias" in name:
                assert (
                    value.shape == adaptor.proj.bias.data.shape
                ), f"{full_name} has size {value.shape}, but {adaptor.proj.bias.data.shape} was found."
                adaptor.proj.bias.data = value
                logger.info(f"Adapter proj layer bias was initialized from {full_name}.")
            if "weight" in name:
                assert (
                    value.shape == adaptor.proj.weight.data.shape
                ), f"{full_name} has size {value.shape}, but {adaptor.proj.weight.data.shape} was found."
                adaptor.proj.weight.data = value
                logger.info(f"Adapter proj layer weight was initialized from {full_name}.")
    elif isinstance(layer_id, int):
        if "bias" in name:
            assert (
                value.shape == adaptor.layers[layer_id].conv.bias.data.shape
            ), f"{full_name} has size {value.shape}, but {adaptor.layers[layer_id].conv.bias.data.shape} was found."
            adaptor.layers[layer_id].conv.bias.data = value
            logger.info(f"Adapter layer {layer_id} bias was initialized from {full_name}.")
        elif "weight" in name:
            assert (
                value.shape == adaptor.layers[layer_id].conv.weight.data.shape
            ), f"{full_name} has size {value.shape}, but {adaptor.layers[layer_id].conv.weight.data.shape} was found."
            adaptor.layers[layer_id].conv.weight.data = value
            logger.info(f"Adapter layer {layer_id} weight was initialized from {full_name}.")
    else:
        unused_weights.append(full_name)


def make_linear_from_emb(emb):
    vocab_size, emb_size = emb.weight.shape
    lin_layer = nn.Linear(emb_size, vocab_size, bias=False)
    # the (vocab_size, emb_size) embedding matrix becomes the layer's weight, so the
    # layer maps emb_size -> vocab_size, i.e. it acts as an untied lm head
    lin_layer.weight.data = emb.weight.data
    return lin_layer
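
# Hypothetical usage sketch (the helper is not called anywhere in this script):
# lm_head = make_linear_from_emb(hf_decoder.get_input_embeddings())
# would produce a bias-free projection head sharing the decoder's embedding matrix.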


@torch.no_grad()
def convert_wav2vec2_checkpoint(
    checkpoint_path,
    pytorch_dump_folder_path,
    dict_path,
    config_yaml_path,
    encoder_config_path,
    decoder_config_path,
    add_adaptor,
    adaptor_kernel_size,
    adaptor_stride,
    decoder_start_token_id,
    encoder_output_dim,
):
    """
    Copy/paste/tweak model's weights to transformers design.
    """
    # load configs
    encoder_config = Wav2Vec2Config.from_pretrained(
        encoder_config_path,
        add_adaptor=add_adaptor,
        adaptor_stride=adaptor_stride,
        adaptor_kernel_size=adaptor_kernel_size,
        use_auth_token=True,
        output_hidden_size=encoder_output_dim,
    )
    decoder_config = MBartConfig.from_pretrained(decoder_config_path)

    # load model
    model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
        [checkpoint_path],
        arg_overrides={
            "config_yaml": config_yaml_path,
            "data": "/".join(dict_path.split("/")[:-1]),
            "w2v_path": checkpoint_path,
            "load_pretrained_decoder_from": None,
        },
    )
    model = model[0].eval()

    # load feature extractor
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(encoder_config_path, use_auth_token=True)

    # set weights for wav2vec2 encoder
    hf_encoder = Wav2Vec2Model(encoder_config)

    recursively_load_weights_wav2vec2(model.encoder, hf_encoder)

    # load decoder weights
    hf_decoder = MBartForCausalLM(decoder_config)
    missing_keys, unexpected_keys = hf_decoder.model.decoder.load_state_dict(model.decoder.state_dict(), strict=False)
    logger.warning(f"The following keys are missing when loading the decoder weights: {missing_keys}")
    logger.warning(f"The following keys are unexpected when loading the decoder weights: {unexpected_keys}")

    hf_wav2vec = SpeechEncoderDecoderModel(encoder=hf_encoder, decoder=hf_decoder)
    # the speech encoder and text decoder do not share an embedding matrix
    hf_wav2vec.config.tie_word_embeddings = False

    tokenizer = MBart50Tokenizer(dict_path)
    tokenizer.save_pretrained(pytorch_dump_folder_path)

    config = hf_wav2vec.config.to_dict()
    config["pad_token_id"] = tokenizer.pad_token_id
    config["bos_token_id"] = tokenizer.bos_token_id
    config["eos_token_id"] = tokenizer.eos_token_id
    config["tokenizer_class"] = "mbart50"
    config["feature_extractor_type"] = "wav2vec2"

    config["decoder_start_token_id"] = decoder_start_token_id

    hf_wav2vec.config = SpeechEncoderDecoderConfig.from_dict(config)

    hf_wav2vec.save_pretrained(pytorch_dump_folder_path)
    feature_extractor.save_pretrained(pytorch_dump_folder_path)
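
    # at this point the dump folder holds a complete, loadable checkpoint: the model
    # weights and config saved above, plus the tokenizer and feature extractor files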


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
    parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
    parser.add_argument("--dict_path", default=None, type=str, help="Path to dict of fine-tuned model")
    parser.add_argument("--config_yaml_path", default=None, type=str, help="Path to yaml file of fine-tuned model")
    parser.add_argument(
        "--encoder_config_path",
        default="facebook/wav2vec2-xls-r-2b",
        type=str,
        help="Path to hf encoder wav2vec2 checkpoint config",
    )
    parser.add_argument(
        "--decoder_config_path",
        default="facebook/mbart-large-50-one-to-many-mmt",
        type=str,
        help="Path to hf decoder checkpoint config",
    )
    # note: plain `type=bool` would treat any non-empty string (including "False") as True
    parser.add_argument(
        "--add_adaptor",
        default=True,
        type=lambda x: str(x).lower() in ("1", "true", "yes"),
        help="whether to add model adaptor layers",
    )
    parser.add_argument("--adaptor_stride", default=2, type=int, help="stride of adaptor layers")
    parser.add_argument("--adaptor_kernel_size", default=3, type=int, help="kernel size of adaptor layers")
    parser.add_argument("--encoder_output_dim", default=1024, type=int, help="encoder output dim")
    parser.add_argument("--start_token_id", default=250004, type=int, help="`decoder_start_token_id` of model config")

    args = parser.parse_args()
    convert_wav2vec2_checkpoint(
        args.checkpoint_path,
        args.pytorch_dump_folder_path,
        args.dict_path,
        args.config_yaml_path,
        encoder_config_path=args.encoder_config_path,
        decoder_config_path=args.decoder_config_path,
        add_adaptor=args.add_adaptor,
        adaptor_kernel_size=args.adaptor_kernel_size,
        adaptor_stride=args.adaptor_stride,
        decoder_start_token_id=args.start_token_id,
        encoder_output_dim=args.encoder_output_dim,
    )
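
# Example invocation (paths and the script name are placeholders):
#     python convert_wav2vec2_checkpoint.py \
#         --checkpoint_path /path/to/fairseq/checkpoint.pt \
#         --dict_path /path/to/dict.txt \
#         --config_yaml_path /path/to/config.yaml \
#         --pytorch_dump_folder_path ./wav2vec2-mbart50-converted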
@@ -223,7 +223,9 @@ def __init__(
         self.encoder.config = self.config.encoder
         self.decoder.config = self.config.decoder

-        if self.encoder.config.hidden_size != self.decoder.config.hidden_size:
+        # get encoder output hidden size
+        self.encoder_output_dim = getattr(config.encoder, "output_hidden_size", config.encoder.hidden_size)
+        if self.encoder_output_dim != self.decoder.config.hidden_size:
             # encoder outputs might need to be projected to different dimension for decoder
             self.enc_to_dec_proj = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size)

@@ -471,7 +473,7 @@ def forward(
         encoder_hidden_states = encoder_outputs[0]

         # project encoder_hidden_states
-        if self.encoder.config.hidden_size != self.decoder.config.hidden_size:
+        if self.encoder_output_dim != self.decoder.config.hidden_size:
             encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)

         # compute correct encoder attention mask