CLI: add stricter automatic checks to pt-to-tf #17588

Merged: 9 commits merged on Jun 8, 2022
Changes from 4 commits
3 changes: 2 additions & 1 deletion docker/transformers-all-latest-gpu/Dockerfile
@@ -4,7 +4,8 @@ LABEL maintainer="Hugging Face"
ARG DEBIAN_FRONTEND=noninteractive

RUN apt update
-RUN apt install -y git libsndfile1-dev tesseract-ocr espeak-ng python3 python3-pip ffmpeg
+RUN apt install -y git libsndfile1-dev tesseract-ocr espeak-ng python3 python3-pip ffmpeg git-lfs
+RUN git lfs install
RUN python3 -m pip install --no-cache-dir --upgrade pip

ARG REF=main
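
A note on why git-lfs is needed here: the pt-to-tf CLI below clones model repos with huggingface_hub's Repository helper, and the weight files on the Hub are stored through Git LFS, so the image needs a working `git lfs install`. A minimal sketch of the clone step the image must support (the checkpoint name is only an illustrative example):

from huggingface_hub import Repository

# Cloning a model repo pulls the large weight files through Git LFS;
# without git-lfs installed, the clone cannot fetch the actual weights.
repo = Repository(local_dir="/tmp/bert-base-cased", clone_from="bert-base-cased")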
89 changes: 73 additions & 16 deletions src/transformers/commands/pt_to_tf.py
@@ -14,13 +14,14 @@

import os
from argparse import ArgumentParser, Namespace
+from importlib import import_module

import numpy as np
from datasets import load_dataset

from huggingface_hub import Repository, upload_file

-from .. import AutoFeatureExtractor, AutoModel, AutoTokenizer, TFAutoModel, is_tf_available, is_torch_available
+from .. import AutoConfig, AutoFeatureExtractor, AutoTokenizer, is_tf_available, is_torch_available
from ..utils import logging
from . import BaseTransformersCLICommand

@@ -80,6 +81,51 @@ def register_subcommand(parser: ArgumentParser):
        )
        train_parser.set_defaults(func=convert_command_factory)

+    @staticmethod
+    def compare_pt_tf_models(pt_model, pt_input, tf_model, tf_input):
+        """
+        Compares the tf and the pt models, given their inputs, returning a tuple with the maximum observed difference
+        and its source.
+        """
+        pt_outputs = pt_model(**pt_input, output_hidden_states=True)
+        tf_outputs = tf_model(**tf_input, output_hidden_states=True)
+
+        # 1. All keys must be the same
+        if set(pt_outputs.keys()) != set(tf_outputs.keys()):
+            raise ValueError("The model outputs have different attributes, aborting.")
+
+        # 2. For each key, ALL values must be the same
+        def compate_pt_tf_values(pt_out, tf_out, attr_name=""):
Collaborator:
compare_pt_tf_values ..?

+            max_difference = 0
+            max_difference_source = ""
+
+            # If the current attribute is a tensor, it is a leaf and we make the comparison. Otherwise, we will dig in
+            # recursively, keeping the name of the attribute.
+            if isinstance(pt_out, (torch.Tensor)):
+                difference = np.max(np.abs(pt_out.detach().numpy() - tf_out.numpy()))
+                if difference > max_difference:
+                    max_difference = difference
+                    max_difference_source = attr_name
+            else:
+                root_name = attr_name
+                for i, pt_item in enumerate(pt_out):
+                    # If it is a named attribute, we keep the name. Otherwise, just its index.
+                    if isinstance(pt_item, str):
+                        branch_name = root_name + pt_item
Collaborator:
I feel that we will need to have something like f"{root_name}.{pt_item}", i.e. to include some kind of separator, so the resulting names will be more readable.

Member Author:
There is no need, the names are not nested (at the moment). As it is structured, it will print the variable as we would write in a Python terminal to get it, so we can copy-paste it for further inspection -- e.g. past_key_values[0][2]
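
To illustrate that copy-paste point with a hypothetical report (assuming pt_model, tf_model, and the inputs from the surrounding code, and a model that actually returns past_key_values): if the comparison flags past_key_values[0][2], the offending tensors can be pulled out exactly as written:

import numpy as np

pt_outputs = pt_model(**pt_input, output_hidden_states=True)
tf_outputs = tf_model(**tf_input, output_hidden_states=True)

# The reported source string is a valid accessor on both output objects.
pt_leaf = pt_outputs.past_key_values[0][2]
tf_leaf = tf_outputs.past_key_values[0][2]
print(np.max(np.abs(pt_leaf.detach().numpy() - tf_leaf.numpy())))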

+                        tf_item = tf_out[pt_item]
+                        pt_item = pt_out[pt_item]
+                    else:
+                        branch_name = root_name + f"[{i}]"
+                        tf_item = tf_out[i]
+                    difference, difference_source = compate_pt_tf_values(pt_item, tf_item, branch_name)
Collaborator (@ydshieh, Jun 7, 2022):
compare_pt_tf_values ..?

+                    if difference > max_difference:
+                        max_difference = difference
+                        max_difference_source = difference_source
+
+            return max_difference, max_difference_source
+
+        return compate_pt_tf_values(pt_outputs, tf_outputs)
Collaborator (@ydshieh, Jun 7, 2022):
compare_pt_tf_values 😄

Member Author (@gante, Jun 7, 2022):
Will rename to _compare_pt_tf_models (to avoid a name clash, as Matt mentioned)
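
For readers following the thread, a standalone sketch (not the PR code) of the naming scheme under discussion, with floats standing in for tensors:

def leaf_names(out, name=""):
    # Mirrors the recursion above: named attributes keep their name,
    # positional items append their index.
    if isinstance(out, float):  # stand-in for the torch.Tensor leaf check
        return [name]
    names = []
    for i, item in enumerate(out):
        if isinstance(item, str):  # dict-like iteration yields key names
            names += leaf_names(out[item], name + item)
        else:
            names += leaf_names(item, name + f"[{i}]")
    return names

print(leaf_names({"past_key_values": ((1.0, 2.0),)}))
# ['past_key_values[0][0]', 'past_key_values[0][1]']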


    def __init__(self, model_name: str, local_dir: str, no_pr: bool, *args):
        self._logger = logging.get_logger("transformers-cli/pt_to_tf")
        self._model_name = model_name
@@ -119,8 +165,22 @@ def run(self):
        repo = Repository(local_dir=self._local_dir, clone_from=self._model_name)
        repo.git_pull()  # in case the repo already exists locally, but with an older commit

+        # Load config and get the appropriate architecture -- the latter is needed to convert the head's weights
+        config = AutoConfig.from_pretrained(self._local_dir)
+        architectures = config.architectures
+        if architectures is None:  # No architecture defined -- use auto classes
+            pt_class = getattr(import_module("transformers"), "AutoModel")
+            tf_class = getattr(import_module("transformers"), "TFAutoModel")
+            self._logger.warn("No detected architecture, using auto classes")
+        else:  # Architecture defined -- use it
+            if len(architectures) > 1:
+                raise ValueError(f"More than one architecture was found, aborting. (architectures = {architectures})")
+            pt_class = getattr(import_module("transformers"), architectures[0])
+            tf_class = getattr(import_module("transformers"), "TF" + architectures[0])
self._logger.warn(f"Detected architecture: {architectures[0]}")

        # Load models and acquire a basic input for its modality.
-        pt_model = AutoModel.from_pretrained(self._local_dir)
+        pt_model = pt_class.from_pretrained(self._local_dir)
        main_input_name = pt_model.main_input_name
        if main_input_name == "input_ids":
            pt_input, tf_input = self.get_text_inputs()
@@ -130,7 +190,7 @@
            pt_input, tf_input = self.get_audio_inputs()
        else:
            raise ValueError(f"Can't detect the model modality (`main_input_name` = {main_input_name})")
-        tf_from_pt_model = TFAutoModel.from_pretrained(self._local_dir, from_pt=True)
+        tf_from_pt_model = tf_class.from_pretrained(self._local_dir, from_pt=True)

        # Extra input requirements, in addition to the input modality
        if hasattr(pt_model, "encoder") and hasattr(pt_model, "decoder"):
@@ -139,27 +199,24 @@
            tf_input.update({"decoder_input_ids": tf.convert_to_tensor(decoder_input_ids)})

        # Confirms that cross loading PT weights into TF worked.
-        pt_last_hidden_state = pt_model(**pt_input).last_hidden_state.detach().numpy()
-        tf_from_pt_last_hidden_state = tf_from_pt_model(**tf_input).last_hidden_state.numpy()
-        crossload_diff = np.max(np.abs(pt_last_hidden_state - tf_from_pt_last_hidden_state))
+        crossload_diff, diff_source = self.compare_pt_tf_models(pt_model, pt_input, tf_from_pt_model, tf_input)
        if crossload_diff >= MAX_ERROR:
            raise ValueError(
-                "The cross-loaded TF model has different last hidden states, something went wrong! (max difference ="
-                f" {crossload_diff})"
+                "The cross-loaded TF model has different outputs, something went wrong! (max difference ="
+                f" {crossload_diff:.3e}, observed in {diff_source})"
            )

        # Save the weights in a TF format (if they don't exist) and confirms that the results are still good
        tf_weights_path = os.path.join(self._local_dir, TF_WEIGHTS_NAME)
        if not os.path.exists(tf_weights_path):
            tf_from_pt_model.save_weights(tf_weights_path)
-        del tf_from_pt_model, pt_model  # will no longer be used, and may have a large memory footprint
-        tf_model = TFAutoModel.from_pretrained(self._local_dir)
-        tf_last_hidden_state = tf_model(**tf_input).last_hidden_state.numpy()
-        converted_diff = np.max(np.abs(pt_last_hidden_state - tf_last_hidden_state))
+        del tf_from_pt_model  # will no longer be used, and may have a large memory footprint
+        tf_model = tf_class.from_pretrained(self._local_dir)
+        converted_diff, diff_source = self.compare_pt_tf_models(pt_model, pt_input, tf_model, tf_input)
Member:
The function is called as compare_pt_tf_models here, but as @ydshieh mentioned it's defined as compate_pt_tf_models, so this bit will probably crash.

Collaborator:
It was my bad, my comment should be compate_pt_tf_values --> compare_pt_tf_values. Nothing wrong about compare_pt_tf_models.

        if converted_diff >= MAX_ERROR:
            raise ValueError(
-                "The converted TF model has different last hidden states, something went wrong! (max difference ="
-                f" {converted_diff})"
+                "The converted TF model has different outputs, something went wrong! (max difference ="
+                f" {converted_diff:.3e}, observed in {diff_source})"
            )

        if not self._no_pr:
@@ -174,8 +231,8 @@
                create_pr=True,
                pr_commit_summary="Add TF weights",
                pr_commit_description=(
-                    f"Validated by the `pt_to_tf` CLI. Max crossload hidden state difference={crossload_diff:.3e};"
-                    f" Max converted hidden state difference={converted_diff:.3e}."
+                    f"Validated by the `pt_to_tf` CLI. Max crossload output difference={crossload_diff:.3e};"
+                    f" Max converted output difference={converted_diff:.3e}."
                ),
            )
            self._logger.warn(f"PR open in {hub_pr_url}")