Skip to content

Commit

Permalink
CLI: Print all different tensors on exception (#17612)
Browse files Browse the repository at this point in the history
  • Loading branch information
gante committed Jun 8, 2022
1 parent e9d5138 commit 66e8656
Showing 1 changed file with 31 additions and 28 deletions.
59 changes: 31 additions & 28 deletions src/transformers/commands/pt_to_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,10 @@ def register_subcommand(parser: ArgumentParser):
train_parser.set_defaults(func=convert_command_factory)

@staticmethod
def compare_pt_tf_models(pt_model, pt_input, tf_model, tf_input):
def find_pt_tf_differences(pt_model, pt_input, tf_model, tf_input):
"""
Compares the TensorFlow and PyTorch models, given their inputs, returning a tuple with the maximum observed
difference and its source.
Compares the TensorFlow and PyTorch models, given their inputs, returning a dictionary with all tensor
differences.
"""
pt_outputs = pt_model(**pt_input, output_hidden_states=True)
tf_outputs = tf_model(**tf_input, output_hidden_states=True)
Expand All @@ -104,18 +104,14 @@ def compare_pt_tf_models(pt_model, pt_input, tf_model, tf_input):
f" {tf_out_attrs})"
)

# 2. For each output attribute, ALL values must be the same
def _compate_pt_tf_models(pt_out, tf_out, attr_name=""):
max_difference = 0
max_difference_source = ""
# 2. For each output attribute, computes the difference
def _find_pt_tf_differences(pt_out, tf_out, differences, attr_name=""):

# If the current attribute is a tensor, it is a leaf and we make the comparison. Otherwise, we will dig in
# recursivelly, keeping the name of the attribute.
if isinstance(pt_out, (torch.Tensor)):
difference = np.max(np.abs(pt_out.detach().numpy() - tf_out.numpy()))
if difference > max_difference:
max_difference = difference
max_difference_source = attr_name
if isinstance(pt_out, torch.Tensor):
tensor_difference = np.max(np.abs(pt_out.detach().numpy() - tf_out.numpy()))
differences[attr_name] = tensor_difference
else:
root_name = attr_name
for i, pt_item in enumerate(pt_out):
Expand All @@ -127,14 +123,11 @@ def _compate_pt_tf_models(pt_out, tf_out, attr_name=""):
else:
branch_name = root_name + f"[{i}]"
tf_item = tf_out[i]
difference, difference_source = _compate_pt_tf_models(pt_item, tf_item, branch_name)
if difference > max_difference:
max_difference = difference
max_difference_source = difference_source
differences = _find_pt_tf_differences(pt_item, tf_item, differences, branch_name)

return max_difference, max_difference_source
return differences

return _compate_pt_tf_models(pt_outputs, tf_outputs)
return _find_pt_tf_differences(pt_outputs, tf_outputs, {})

def __init__(self, model_name: str, local_dir: str, no_pr: bool, new_weights: bool, *args):
self._logger = logging.get_logger("transformers-cli/pt_to_tf")
Expand Down Expand Up @@ -213,11 +206,15 @@ def run(self):
tf_input.update({"decoder_input_ids": tf.convert_to_tensor(decoder_input_ids)})

# Confirms that cross loading PT weights into TF worked.
crossload_diff, diff_source = self.compare_pt_tf_models(pt_model, pt_input, tf_from_pt_model, tf_input)
if crossload_diff >= MAX_ERROR:
crossload_differences = self.find_pt_tf_differences(pt_model, pt_input, tf_from_pt_model, tf_input)
max_crossload_diff = max(crossload_differences.values())
if max_crossload_diff > MAX_ERROR:
raise ValueError(
"The cross-loaded TF model has different outputs, something went wrong! (max difference ="
f" {crossload_diff:.3e}, observed in {diff_source})"
"The cross-loaded TensorFlow model has different outputs, something went wrong! Exaustive list of"
f" maximum tensor differences above the error threshold ({MAX_ERROR}):\n"
+ "\n".join(
[f"{key}: {value:.3e}" for key, value in crossload_differences.items() if value > MAX_ERROR]
)
)

# Save the weights in a TF format (if needed) and confirms that the results are still good
Expand All @@ -226,11 +223,15 @@ def run(self):
tf_from_pt_model.save_weights(tf_weights_path)
del tf_from_pt_model # will no longer be used, and may have a large memory footprint
tf_model = tf_class.from_pretrained(self._local_dir)
converted_diff, diff_source = self.compare_pt_tf_models(pt_model, pt_input, tf_model, tf_input)
if converted_diff >= MAX_ERROR:
conversion_differences = self.find_pt_tf_differences(pt_model, pt_input, tf_model, tf_input)
max_conversion_diff = max(conversion_differences.values())
if max_conversion_diff > MAX_ERROR:
raise ValueError(
"The converted TF model has different outputs, something went wrong! (max difference ="
f" {converted_diff:.3e}, observed in {diff_source})"
"The converted TensorFlow model has different outputs, something went wrong! Exaustive list of maximum"
f" tensor differences above the error threshold ({MAX_ERROR}):\n"
+ "\n".join(
[f"{key}: {value:.3e}" for key, value in conversion_differences.items() if value > MAX_ERROR]
)
)

if not self._no_pr:
Expand All @@ -245,8 +246,10 @@ def run(self):
create_pr=True,
pr_commit_summary="Add TF weights",
pr_commit_description=(
f"Validated by the `pt_to_tf` CLI. Max crossload output difference={crossload_diff:.3e};"
f" Max converted output difference={converted_diff:.3e}."
"Model converted by the `transformers`' `pt_to_tf` CLI -- all converted model outputs and"
" hidden layers were validated against its Pytorch counterpart. Maximum crossload output"
f" difference={max_crossload_diff:.3e}; Maximum converted output"
f" difference={max_conversion_diff:.3e}."
),
)
self._logger.warn(f"PR open in {hub_pr_url}")
Expand Down

0 comments on commit 66e8656

Please sign in to comment.