forked from huggingface/transformers
-
Notifications
You must be signed in to change notification settings - Fork 0
/
pt_to_tf.py
243 lines (210 loc) · 11.3 KB
/
pt_to_tf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from argparse import ArgumentParser, Namespace
from importlib import import_module
import numpy as np
from datasets import load_dataset
from huggingface_hub import Repository, upload_file
from .. import AutoConfig, AutoFeatureExtractor, AutoTokenizer, is_tf_available, is_torch_available
from ..utils import logging
from . import BaseTransformersCLICommand
# Both frameworks are imported conditionally, only when the corresponding availability check succeeds.
if is_tf_available():
    import tensorflow as tf

    # Turn off TensorFloat-32 execution right after importing TF -- presumably so that TF computations keep full
    # float32 precision and stay comparable with the PyTorch outputs (TODO confirm).
    tf.config.experimental.enable_tensor_float_32_execution(False)

if is_torch_available():
    import torch

# Output differences at or above this value make the conversion fail.
MAX_ERROR = 5e-5  # larger error tolerance than in our internal tests, to avoid flaky user-facing errors
# File name of the TF weights, used both inside the local repository directory and as the path in the Hub repository.
TF_WEIGHTS_NAME = "tf_model.h5"
def convert_command_factory(args: Namespace):
    """
    Factory function that builds the command converting a PyTorch checkpoint into a TensorFlow 2 checkpoint.

    Args:
        args: Parsed command-line arguments; `model_name`, `local_dir` and `no_pr` are read from it.

    Returns: PTtoTFCommand
    """
    return PTtoTFCommand(
        model_name=args.model_name,
        local_dir=args.local_dir,
        no_pr=args.no_pr,
    )
class PTtoTFCommand(BaseTransformersCLICommand):
    """
    CLI command that cross-loads the PyTorch weights of a model repository into its TensorFlow counterpart, checks
    that both frameworks produce matching outputs, stores the TF weights and, unless disabled, uploads them in a PR.
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register this command to argparse so it's available for the transformers-cli

        Args:
            parser: Root parser to register command-specific arguments
        """
        train_parser = parser.add_parser(
            "pt-to-tf",
            help=(
                "CLI tool to run convert a transformers model from a PyTorch checkpoint to a TensorFlow checkpoint."
                " Can also be used to validate existing weights without opening PRs, with --no-pr."
            ),
        )
        train_parser.add_argument(
            "--model-name",
            type=str,
            required=True,
            help="The model name, including owner/organization, as seen on the hub.",
        )
        train_parser.add_argument(
            "--local-dir",
            type=str,
            default="",
            help="Optional local directory of the model repository. Defaults to /tmp/{model_name}",
        )
        train_parser.add_argument(
            "--no-pr", action="store_true", help="Optional flag to NOT open a PR with converted weights."
        )
        train_parser.set_defaults(func=convert_command_factory)

    @staticmethod
    def compare_pt_tf_models(pt_model, pt_input, tf_model, tf_input):
        """
        Compares the TF and the PT models, given their inputs, returning a tuple with the maximum observed difference
        and its source.

        Args:
            pt_model: PyTorch model, called as `pt_model(**pt_input, output_hidden_states=True)`.
            pt_input: Keyword inputs for the PyTorch model.
            tf_model: TensorFlow model, called as `tf_model(**tf_input, output_hidden_states=True)`.
            tf_input: Keyword inputs for the TensorFlow model.

        Returns:
            `(max_difference, max_difference_source)`: the largest absolute element-wise difference found between
            matching output tensors, and the name/index path of the output where it was observed (`0` and `""` when
            no positive difference was found).

        Raises:
            ValueError: If the two model outputs do not expose the same set of keys.
        """
        # Inference only: the outputs are detached below anyway, so there is no need to record an autograd graph.
        with torch.no_grad():
            pt_outputs = pt_model(**pt_input, output_hidden_states=True)
        tf_outputs = tf_model(**tf_input, output_hidden_states=True)

        # 1. All keys must be the same
        if set(pt_outputs.keys()) != set(tf_outputs.keys()):
            raise ValueError("The model outputs have different attributes, aborting.")

        # 2. For each key, ALL values must be the same
        def _compare_pt_tf_values(pt_out, tf_out, attr_name=""):
            max_difference = 0
            max_difference_source = ""

            # If the current attribute is a tensor, it is a leaf and we make the comparison. Otherwise, we will dig in
            # recursively, keeping the name of the attribute.
            if isinstance(pt_out, torch.Tensor):
                difference = np.max(np.abs(pt_out.detach().numpy() - tf_out.numpy()))
                # NOTE(review): a NaN difference compares False here and is therefore never reported.
                if difference > max_difference:
                    max_difference = difference
                    max_difference_source = attr_name
            else:
                root_name = attr_name
                for i, pt_item in enumerate(pt_out):
                    # If it is a named attribute, we keep the name. Otherwise, just its index.
                    if isinstance(pt_item, str):
                        branch_name = root_name + pt_item
                        tf_item = tf_out[pt_item]
                        pt_item = pt_out[pt_item]
                    else:
                        branch_name = root_name + f"[{i}]"
                        tf_item = tf_out[i]
                    difference, difference_source = _compare_pt_tf_values(pt_item, tf_item, branch_name)
                    if difference > max_difference:
                        max_difference = difference
                        max_difference_source = difference_source

            return max_difference, max_difference_source

        return _compare_pt_tf_values(pt_outputs, tf_outputs)

    def __init__(self, model_name: str, local_dir: str, no_pr: bool, *args):
        """
        Args:
            model_name: The model name, including owner/organization, as seen on the hub.
            local_dir: Local directory of the model repository; when empty, `/tmp/{model_name}` is used.
            no_pr: When `True`, no PR with the converted weights is opened.
            *args: Extra positional arguments, accepted and ignored.
        """
        self._logger = logging.get_logger("transformers-cli/pt_to_tf")
        self._model_name = model_name
        self._local_dir = local_dir if local_dir else os.path.join("/tmp", model_name)
        self._no_pr = no_pr

    def get_text_inputs(self):
        """
        Tokenizes a fixed two-sentence batch with the tokenizer found in the local directory.

        Returns:
            `(pt_input, tf_input)`: the same padded batch as PyTorch and as TensorFlow tensors.
        """
        tokenizer = AutoTokenizer.from_pretrained(self._local_dir)
        sample_text = ["Hi there!", "I am a batch with more than one row and different input lengths."]
        # Padding is requested below, so a pad token is needed: fall back to the EOS token when none is defined.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        pt_input = tokenizer(sample_text, return_tensors="pt", padding=True, truncation=True)
        tf_input = tokenizer(sample_text, return_tensors="tf", padding=True, truncation=True)
        return pt_input, tf_input

    def get_audio_inputs(self):
        """
        Builds a two-sample audio batch from the `hf-internal-testing/librispeech_asr_dummy` dataset, processed with
        the feature extractor found in the local directory.

        Returns:
            `(pt_input, tf_input)`: the same padded batch as PyTorch and as TensorFlow tensors.
        """
        processor = AutoFeatureExtractor.from_pretrained(self._local_dir)
        num_samples = 2
        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        # Sorting by id before selecting keeps the chosen samples the same across runs.
        speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
        raw_samples = [x["array"] for x in speech_samples]
        pt_input = processor(raw_samples, return_tensors="pt", padding=True)
        tf_input = processor(raw_samples, return_tensors="tf", padding=True)
        return pt_input, tf_input

    def get_image_inputs(self):
        """
        Builds a two-sample image batch from the `cifar10` test split, processed with the feature extractor found in
        the local directory.

        Returns:
            `(pt_input, tf_input)`: the same batch as PyTorch and as TensorFlow tensors.
        """
        feature_extractor = AutoFeatureExtractor.from_pretrained(self._local_dir)
        num_samples = 2
        ds = load_dataset("cifar10", "plain_text", split="test")[:num_samples]["img"]
        pt_input = feature_extractor(images=ds, return_tensors="pt")
        tf_input = feature_extractor(images=ds, return_tensors="tf")
        return pt_input, tf_input

    def run(self):
        """
        Runs the conversion: fetches the repository, cross-loads the PyTorch weights into the TF class, validates the
        outputs, saves the TF weights (if they don't exist yet), validates the saved weights and, unless `no_pr` was
        set, uploads them in a new PR.

        Raises:
            ValueError: If more than one architecture is listed in the config, if the model modality can't be
                detected from `main_input_name`, or if the maximum output difference between the PyTorch and the TF
                model reaches `MAX_ERROR`.
        """
        # Fetch remote data
        # TODO: implement a solution to pull a specific PR/commit, so we can use this CLI to validate pushes.
        repo = Repository(local_dir=self._local_dir, clone_from=self._model_name)
        repo.git_pull()  # in case the repo already exists locally, but with an older commit

        # Load config and get the appropriate architecture -- the latter is needed to convert the head's weights
        config = AutoConfig.from_pretrained(self._local_dir)
        architectures = config.architectures
        transformers_module = import_module("transformers")
        if not architectures:  # No architecture defined (missing or empty list) -- use auto classes
            pt_class = transformers_module.AutoModel
            tf_class = transformers_module.TFAutoModel
            self._logger.warning("No detected architecture, using auto classes")
        else:  # Architecture defined -- use it
            if len(architectures) > 1:
                raise ValueError(f"More than one architecture was found, aborting. (architectures = {architectures})")
            pt_class = getattr(transformers_module, architectures[0])
            tf_class = getattr(transformers_module, "TF" + architectures[0])
            self._logger.warning(f"Detected architecture: {architectures[0]}")

        # Load models and acquire a basic input for its modality.
        pt_model = pt_class.from_pretrained(self._local_dir)
        main_input_name = pt_model.main_input_name
        if main_input_name == "input_ids":
            pt_input, tf_input = self.get_text_inputs()
        elif main_input_name == "pixel_values":
            pt_input, tf_input = self.get_image_inputs()
        elif main_input_name == "input_features":
            pt_input, tf_input = self.get_audio_inputs()
        else:
            raise ValueError(f"Can't detect the model modality (`main_input_name` = {main_input_name})")
        tf_from_pt_model = tf_class.from_pretrained(self._local_dir, from_pt=True)

        # Extra input requirements, in addition to the input modality
        if hasattr(pt_model, "encoder") and hasattr(pt_model, "decoder"):
            # One decoder start token per row, matching the two-sample batches built by the `get_*_inputs` methods.
            decoder_input_ids = np.asarray([[1], [1]], dtype=int) * pt_model.config.decoder_start_token_id
            pt_input.update({"decoder_input_ids": torch.tensor(decoder_input_ids)})
            tf_input.update({"decoder_input_ids": tf.convert_to_tensor(decoder_input_ids)})

        # Confirms that cross loading PT weights into TF worked.
        crossload_diff, diff_source = self.compare_pt_tf_models(pt_model, pt_input, tf_from_pt_model, tf_input)
        if crossload_diff >= MAX_ERROR:
            raise ValueError(
                "The cross-loaded TF model has different outputs, something went wrong! (max difference ="
                f" {crossload_diff:.3e}, observed in {diff_source})"
            )

        # Save the weights in a TF format (if they don't exist) and confirms that the results are still good
        tf_weights_path = os.path.join(self._local_dir, TF_WEIGHTS_NAME)
        if not os.path.exists(tf_weights_path):
            tf_from_pt_model.save_weights(tf_weights_path)
        del tf_from_pt_model  # will no longer be used, and may have a large memory footprint
        tf_model = tf_class.from_pretrained(self._local_dir)
        converted_diff, diff_source = self.compare_pt_tf_models(pt_model, pt_input, tf_model, tf_input)
        if converted_diff >= MAX_ERROR:
            raise ValueError(
                "The converted TF model has different outputs, something went wrong! (max difference ="
                f" {converted_diff:.3e}, observed in {diff_source})"
            )

        if not self._no_pr:
            self._logger.warning("Uploading the weights into a new PR...")
            # TODO: remove try/except when the upload to PR feature is released
            # (https://github.com/huggingface/huggingface_hub/pull/884)
            try:
                hub_pr_url = upload_file(
                    path_or_fileobj=tf_weights_path,
                    path_in_repo=TF_WEIGHTS_NAME,
                    repo_id=self._model_name,
                    create_pr=True,
                    pr_commit_summary="Add TF weights",
                    pr_commit_description=(
                        f"Validated by the `pt_to_tf` CLI. Max crossload output difference={crossload_diff:.3e};"
                        f" Max converted output difference={converted_diff:.3e}."
                    ),
                )
            except TypeError:
                # `upload_file` rejected the PR-related keyword arguments: fall back to manual instructions.
                self._logger.warning(
                    f"You can now open a PR in https://huggingface.co/{self._model_name}/discussions, manually"
                    f" uploading the file in {tf_weights_path}"
                )
            else:
                self._logger.warning(f"PR open in {hub_pr_url}")