Add Flax example tests #14599

Merged · 13 commits · Dec 6, 2021
31 changes: 31 additions & 0 deletions examples/flax/language-modeling/run_clm_flax.py
@@ -21,6 +21,7 @@
"""
# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.

import json
import logging
import math
import os
@@ -219,6 +220,13 @@ def write_eval_metric(summary_writer, eval_metrics, step):
        summary_writer.scalar(f"eval_{metric_name}", value, step)


def save_metrics(split, output_dir, metrics):
    metrics = {f"{split}_{metric_name}": value for metric_name, value in metrics.items()}
    path = os.path.join(output_dir, f"{split}_results.json")
    with open(path, "w") as f:
        json.dump(metrics, f, indent=4, sort_keys=True)


def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.array]:
@@ -672,6 +680,29 @@ def eval_step(params, batch):
                    if training_args.push_to_hub:
                        repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)

    # Eval after training
    if training_args.do_eval:
        eval_metrics = []
        eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
        eval_steps = len(eval_dataset) // eval_batch_size
        for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
            # Model forward
            batch = shard(next(eval_loader))
            metrics = p_eval_step(state.params, batch)
            eval_metrics.append(metrics)

        # normalize eval metrics
        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_map(lambda x: jnp.mean(x).item(), eval_metrics)

        try:
            eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
        except OverflowError:
            eval_metrics["perplexity"] = float("inf")

        if jax.process_index() == 0:
            save_metrics("eval", training_args.output_dir, eval_metrics)


if __name__ == "__main__":
    main()
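For reference, the perplexity added at the end of the eval block is just exp of the mean eval loss, with the OverflowError guard mapping a degenerate loss to infinity. A minimal, self-contained sketch of that conversion (the loss values below are invented for illustration):

import math

def to_perplexity(loss):
    # Mirrors the try/except around math.exp in the eval block above.
    try:
        return math.exp(loss)
    except OverflowError:
        return float("inf")

print(to_perplexity(3.2))     # ~24.5, a plausible causal-LM eval loss
print(to_perplexity(1000.0))  # inf, since exp(1000) overflows a Python float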
12 changes: 12 additions & 0 deletions examples/flax/summarization/run_summarization_flax.py
@@ -18,6 +18,7 @@
"""
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.

import json
import logging
import os
import sys
@@ -278,6 +279,13 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
        summary_writer.scalar(f"eval_{metric_name}", value, step)


def save_metrics(split, output_dir, metrics):
    metrics = {f"{split}_{metric_name}": value for metric_name, value in metrics.items()}
    path = os.path.join(output_dir, f"{split}_results.json")
    with open(path, "w") as f:
        json.dump(metrics, f, indent=4, sort_keys=True)


def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.array]:
@@ -816,6 +824,10 @@ def generate_step(params, batch):
        desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
        logger.info(desc)

        # save final metrics in json
        if jax.process_index() == 0:
            save_metrics("test", training_args.output_dir, rouge_metrics)


if __name__ == "__main__":
    main()
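Note that save_metrics("test", ...) prefixes every ROUGE score with "test_", so the file written here contains keys like test_rouge1 rather than rouge1. A tiny sketch of just that key transformation, with invented scores:

# Hypothetical ROUGE scores, for illustration only.
rouge_metrics = {"rouge1": 24.1, "rouge2": 7.3, "rougeL": 19.0, "rougeLsum": 19.2}
prefixed = {f"test_{name}": value for name, value in rouge_metrics.items()}
print(sorted(prefixed))
# ['test_rouge1', 'test_rouge2', 'test_rougeL', 'test_rougeLsum']

These are exactly the keys that test_run_summarization in the new test file asserts thresholds against via get_results(tmp_dir, split="test").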
154 changes: 154 additions & 0 deletions examples/flax/test_examples.py
@@ -0,0 +1,154 @@
# coding=utf-8
# Copyright 2021 HuggingFace Inc..
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import json
import logging
import os
import sys
from unittest.mock import patch

from transformers.testing_utils import CaptureLogger, TestCasePlus, get_gpu_count, slow


SRC_DIRS = [
    os.path.join(os.path.dirname(__file__), dirname)
    for dirname in [
        "text-classification",
        "language-modeling",
        "summarization",
    ]
]
sys.path.extend(SRC_DIRS)


if SRC_DIRS is not None:
    import run_clm_flax
    import run_flax_glue
    import run_summarization_flax


logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()


def get_setup_file():
    parser = argparse.ArgumentParser()
    parser.add_argument("-f")
    args = parser.parse_args()
    return args.f


def get_results(output_dir, split="eval"):
    results = {}
    path = os.path.join(output_dir, f"{split}_results.json")
    if os.path.exists(path):
        with open(path, "r") as f:
            results = json.load(f)
    else:
        raise ValueError(f"can't find {path}")
    return results


class ExamplesTests(TestCasePlus):
    def test_run_glue(self):
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_glue.py
            --model_name_or_path distilbert-base-uncased
            --output_dir {tmp_dir}
            --train_file ./tests/fixtures/tests_samples/MRPC/train.csv
            --validation_file ./tests/fixtures/tests_samples/MRPC/dev.csv
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=1
            --learning_rate=1e-4
            --max_train_steps=10
            --num_warmup_steps=2
            --seed=42
            --max_length=128
        """.split()

        with patch.object(sys, "argv", testargs):
            run_flax_glue.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_accuracy"], 0.75)

    def test_run_clm(self):
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_clm_flax.py
            --model_name_or_path distilgpt2
            --train_file ./tests/fixtures/sample_text.txt
            --validation_file ./tests/fixtures/sample_text.txt
            --do_train
            --do_eval
            --block_size 128
            --per_device_train_batch_size 5
            --per_device_eval_batch_size 5
            --num_train_epochs 2
            --logging_steps 2 --eval_steps 2
            --output_dir {tmp_dir}
            --overwrite_output_dir
        """.split()

        # if torch.cuda.device_count() > 1:
        # Skipping because there are not enough batches to train the model + would need a drop_last to work.
        # return

        with patch.object(sys, "argv", testargs):
            run_clm_flax.main()
            result = get_results(tmp_dir)
            self.assertLess(result["eval_perplexity"], 100)

    # @slow
    def test_run_summarization(self):
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_summarization.py
            --model_name_or_path t5-small
            --train_file tests/fixtures/tests_samples/xsum/sample.json
            --validation_file tests/fixtures/tests_samples/xsum/sample.json
            --test_file tests/fixtures/tests_samples/xsum/sample.json
            --output_dir {tmp_dir}
            --overwrite_output_dir
            --max_steps=50
            --warmup_steps=8
            --do_train
            --do_eval
            --do_predict
            --learning_rate=2e-4
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=1
            --predict_with_generate
        """.split()

        with patch.object(sys, "argv", testargs):
            run_summarization_flax.main()
            result = get_results(tmp_dir, split="test")
            self.assertGreaterEqual(result["test_rouge1"], 10)
            self.assertGreaterEqual(result["test_rouge2"], 2)
            self.assertGreaterEqual(result["test_rougeL"], 7)
            self.assertGreaterEqual(result["test_rougeLsum"], 7)
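Each test drives the corresponding example script in-process: it builds a fake command line, patches sys.argv for the duration of the call, invokes the script's main() directly, then reads the JSON metrics back from the temporary output directory. A minimal, self-contained sketch of the sys.argv patching pattern (the script name and flags here are placeholders, not real example arguments):

import sys
from unittest.mock import patch

def main():
    # Stand-in for an example script's entry point; the real scripts parse
    # sys.argv with argparse or HfArgumentParser.
    print("argv seen by the script:", sys.argv[1:])

testargs = ["run_dummy.py", "--output_dir", "/tmp/out", "--do_eval"]
with patch.object(sys, "argv", testargs):
    main()
# argv seen by the script: ['--output_dir', '/tmp/out', '--do_eval']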
12 changes: 12 additions & 0 deletions examples/flax/text-classification/run_flax_glue.py
@@ -15,6 +15,7 @@
# limitations under the License.
""" Finetuning a 🤗 Flax Transformers model for sequence classification on GLUE."""
import argparse
import json
import logging
import os
import random
@@ -259,6 +260,13 @@ def glue_eval_data_collator(dataset: Dataset, batch_size: int):
        yield batch


def save_metrics(split, output_dir, metrics):
    metrics = {f"{split}_{metric_name}": value for metric_name, value in metrics.items()}
    path = os.path.join(output_dir, f"{split}_results.json")
    with open(path, "w") as f:
        json.dump(metrics, f, indent=4, sort_keys=True)


def main():
    args = parse_args()

@@ -522,6 +530,10 @@ def eval_step(state, batch):
            if args.push_to_hub:
                repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)

    # save the eval metrics in json
    if jax.process_index() == 0:
        save_metrics("eval", args.output_dir, eval_metric)


if __name__ == "__main__":
    main()
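Taken together, the changes define a small contract: each example script writes {split}_results.json with split-prefixed metric keys via save_metrics, and the test suite's get_results reads that file back so the tests can assert thresholds. A rough round-trip sketch of that contract, restating the two helpers added above and using an invented accuracy value:

import json
import os
import tempfile

def save_metrics(split, output_dir, metrics):
    # Same behaviour as the helper added to the example scripts.
    metrics = {f"{split}_{metric_name}": value for metric_name, value in metrics.items()}
    path = os.path.join(output_dir, f"{split}_results.json")
    with open(path, "w") as f:
        json.dump(metrics, f, indent=4, sort_keys=True)

def get_results(output_dir, split="eval"):
    # Same behaviour as the helper in test_examples.py.
    path = os.path.join(output_dir, f"{split}_results.json")
    if not os.path.exists(path):
        raise ValueError(f"can't find {path}")
    with open(path, "r") as f:
        return json.load(f)

tmp_dir = tempfile.mkdtemp()
save_metrics("eval", tmp_dir, {"accuracy": 0.81})  # hypothetical value
assert get_results(tmp_dir)["eval_accuracy"] >= 0.75  # the bound test_run_glue checks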