Add Flax example tests (#14599)
* add test for glue

* add tests for clm

* fix clm test

* add summarization tests

* more tests

* fix few tests

* add test for t5 mlm

* fix t5 mlm test

* fix tests for multi device

* cleanup

* ci job

* fix metric file name

* make t5 more robust
patil-suraj authored Dec 6, 2021
1 parent 803a8cd commit c5bd732
Showing 11 changed files with 553 additions and 6 deletions.
63 changes: 63 additions & 0 deletions .circleci/config.yml
@@ -613,6 +613,69 @@ jobs:
        - store_artifacts:
            path: ~/transformers/reports

    run_examples_flax:
        working_directory: ~/transformers
        docker:
            - image: circleci/python:3.7
        environment:
            OMP_NUM_THREADS: 1
            TRANSFORMERS_IS_CI: yes
        resource_class: xlarge
        parallelism: 1
        steps:
            - checkout
            - restore_cache:
                keys:
                    - v0.4-flax_examples-{{ checksum "setup.py" }}
                    - v0.4-{{ checksum "setup.py" }}
            - run: pip install --upgrade pip
            - run: sudo pip install .[flax,testing,sentencepiece]
            - run: pip install -r examples/flax/_tests_requirements.txt
            - save_cache:
                key: v0.4-flax_examples-{{ checksum "setup.py" }}
                paths:
                    - '~/.cache/pip'
            - run: python utils/tests_fetcher.py --filters examples tests | tee test_preparation.txt
            - store_artifacts:
                path: ~/transformers/test_preparation.txt
            - run: |
                if [ -f test_list.txt ]; then
                    python -m pytest -n 8 --dist=loadfile -s --make-reports=examples_flax ./examples/flax/ | tee tests_output.txt
                fi
            - store_artifacts:
                path: ~/transformers/flax_examples_output.txt
            - store_artifacts:
                path: ~/transformers/reports

    run_examples_flax_all:
        working_directory: ~/transformers
        docker:
            - image: circleci/python:3.7
        environment:
            OMP_NUM_THREADS: 1
            TRANSFORMERS_IS_CI: yes
        resource_class: xlarge
        parallelism: 1
        steps:
            - checkout
            - restore_cache:
                keys:
                    - v0.4-flax_examples-{{ checksum "setup.py" }}
                    - v0.4-{{ checksum "setup.py" }}
            - run: pip install --upgrade pip
            - run: sudo pip install .[flax,testing,sentencepiece]
            - run: pip install -r examples/flax/_tests_requirements.txt
            - save_cache:
                key: v0.4-flax_examples-{{ checksum "setup.py" }}
                paths:
                    - '~/.cache/pip'
            - run: |
                TRANSFORMERS_IS_CI=1 python -m pytest -n 8 --dist=loadfile -s --make-reports=examples_flax ./examples/flax/ | tee examples_output.txt
            - store_artifacts:
                path: ~/transformers/flax_examples_output.txt
            - store_artifacts:
                path: ~/transformers/reports

    run_tests_hub:
        working_directory: ~/transformers
        docker:
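The first job runs the example tests only when utils/tests_fetcher.py emits a test_list.txt for the changed files; run_examples_flax_all always runs the full suite. Below is a rough local equivalent of the pytest invocation used in both jobs, a sketch assuming the flax/testing extras plus examples/flax/_tests_requirements.txt are installed and that it is executed from the repository root (the transformers-specific --make-reports option is omitted):

# Sketch of a local run mirroring the CI command above; not the CI job itself.
import os

import pytest

os.environ["TRANSFORMERS_IS_CI"] = "1"

# -n / --dist come from pytest-xdist, which the "testing" extra pulls in.
exit_code = pytest.main(["-n", "8", "--dist=loadfile", "-s", "examples/flax/"])
raise SystemExit(exit_code)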
7 changes: 7 additions & 0 deletions examples/flax/_tests_requirements.txt
@@ -0,0 +1,7 @@
datasets >= 1.1.3
pytest
conllu
nltk
rouge-score
seqeval
tensorboard
27 changes: 27 additions & 0 deletions examples/flax/language-modeling/run_clm_flax.py
@@ -21,6 +21,7 @@
"""
# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.

import json
import logging
import math
import os
@@ -672,6 +673,32 @@ def eval_step(params, batch):
                    if training_args.push_to_hub:
                        repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)

    # Eval after training
    if training_args.do_eval:
        eval_metrics = []
        eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
        eval_steps = len(eval_dataset) // eval_batch_size
        for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
            # Model forward
            batch = shard(next(eval_loader))
            metrics = p_eval_step(state.params, batch)
            eval_metrics.append(metrics)

        # normalize eval metrics
        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_map(lambda x: jnp.mean(x).item(), eval_metrics)

        try:
            eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
        except OverflowError:
            eval_metrics["perplexity"] = float("inf")

        if jax.process_index() == 0:
            eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
            path = os.path.join(training_args.output_dir, "eval_results.json")
            with open(path, "w") as f:
                json.dump(eval_metrics, f, indent=4, sort_keys=True)


if __name__ == "__main__":
    main()
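The block above only writes eval_results.json; the test that consumes it lives in one of the other changed files, which is not shown in this view. As a hypothetical sketch of how such a test could drive the script end to end (checkpoint name, fixture paths and the perplexity bound are illustrative, not taken from this commit):

# Hypothetical sketch, not part of this commit: drive run_clm_flax.main() with patched
# argv and assert on the eval_results.json written by the block above.
import json
import os
import sys
from unittest.mock import patch

import run_clm_flax  # assumes examples/flax/language-modeling is on sys.path


def test_run_clm(tmp_path):
    output_dir = str(tmp_path)
    testargs = f"""
        run_clm_flax.py
        --model_name_or_path distilgpt2
        --train_file ./tests/fixtures/sample_text.txt
        --validation_file ./tests/fixtures/sample_text.txt
        --do_train
        --do_eval
        --num_train_epochs 2
        --per_device_train_batch_size 4
        --output_dir {output_dir}
        --overwrite_output_dir
    """.split()

    with patch.object(sys, "argv", testargs):
        run_clm_flax.main()

    with open(os.path.join(output_dir, "eval_results.json")) as f:
        results = json.load(f)
    assert results["eval_perplexity"] < 100  # loose sanity bound, purely illustrative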
42 changes: 41 additions & 1 deletion examples/flax/language-modeling/run_mlm_flax.py
@@ -20,7 +20,9 @@
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
https://huggingface.co/models?filter=masked-lm
"""
import json
import logging
import math
import os
import sys
import time
@@ -271,7 +273,7 @@ def write_eval_metric(summary_writer, eval_metrics, step):
summary_writer.scalar(f"eval_{metric_name}", value, step)


if __name__ == "__main__":
def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
@@ -700,3 +702,41 @@ def eval_step(params, batch):
                    tokenizer.save_pretrained(training_args.output_dir)
                    if training_args.push_to_hub:
                        repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)

    # Eval after training
    if training_args.do_eval:
        num_eval_samples = len(tokenized_datasets["validation"])
        eval_samples_idx = jnp.arange(num_eval_samples)
        eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)

        eval_metrics = []
        for _, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
            samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
            model_inputs = data_collator(samples, pad_to_multiple_of=16)

            # Model forward
            model_inputs = shard(model_inputs.data)
            metrics = p_eval_step(state.params, model_inputs)
            eval_metrics.append(metrics)

        # normalize eval metrics
        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_map(lambda metric: jnp.sum(metric).item(), eval_metrics)
        eval_normalizer = eval_metrics.pop("normalizer")
        eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)

        try:
            perplexity = math.exp(eval_metrics["loss"])
        except OverflowError:
            perplexity = float("inf")
        eval_metrics["perplexity"] = perplexity

        if jax.process_index() == 0:
            eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
            path = os.path.join(training_args.output_dir, "eval_results.json")
            with open(path, "w") as f:
                json.dump(eval_metrics, f, indent=4, sort_keys=True)


if __name__ == "__main__":
    main()
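Unlike the CLM script, the MLM evaluation sums per-batch losses and then divides by the popped "normalizer" (the count of masked tokens), so perplexity is computed per token rather than per batch. A minimal sketch of that arithmetic, with made-up numbers and assuming each eval step returns a token-summed loss and a token count:

# Minimal sketch of the normalization above; the two example batches and their
# values are invented for illustration.
import jax
import jax.numpy as jnp

batch_metrics = [
    {"loss": jnp.array(12.0), "normalizer": jnp.array(6.0)},  # 6 masked tokens, summed loss 12
    {"loss": jnp.array(4.0), "normalizer": jnp.array(2.0)},   # 2 masked tokens, summed loss 4
]

# Sum across batches (what get_metrics plus the jnp.sum tree_map achieve above)
summed = {k: sum(m[k] for m in batch_metrics) for k in batch_metrics[0]}
normalizer = summed.pop("normalizer")                        # 8 tokens in total
per_token = jax.tree_map(lambda x: x / normalizer, summed)   # {'loss': 2.0} per-token loss
print(per_token)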
38 changes: 34 additions & 4 deletions examples/flax/language-modeling/run_t5_mlm_flax.py
@@ -20,6 +20,7 @@
https://huggingface.co/models?filter=t5
"""
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
import json
import logging
import os
import sys
@@ -401,7 +402,7 @@ def write_eval_metric(summary_writer, eval_metrics, step):
summary_writer.scalar(f"eval_{metric_name}", value, step)


if __name__ == "__main__":
def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
@@ -522,9 +523,7 @@ def write_eval_metric(summary_writer, eval_metrics, step):
model_args.config_name, cache_dir=model_args.cache_dir, vocab_size=len(tokenizer)
)
elif model_args.model_name_or_path:
config = T5Config.from_pretrained(
model_args.model_name_or_path, cache_dir=model_args.cache_dir, vocab_size=len(tokenizer)
)
config = T5Config.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
else:
config = CONFIG_MAPPING[model_args.model_type]()
logger.warning("You are instantiating a new config instance from scratch.")
@@ -617,6 +616,7 @@ def group_texts(examples):
model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
)
else:
config.vocab_size = len(tokenizer)
model = FlaxT5ForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))

# Data collator
@@ -808,3 +808,33 @@ def eval_step(params, batch):
                    tokenizer.save_pretrained(training_args.output_dir)
                    if training_args.push_to_hub:
                        repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)

    # Eval after training
    if training_args.do_eval:
        num_eval_samples = len(tokenized_datasets["validation"])
        eval_samples_idx = jnp.arange(num_eval_samples)
        eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)

        eval_metrics = []
        for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
            samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
            model_inputs = data_collator(samples)

            # Model forward
            model_inputs = shard(model_inputs.data)
            metrics = p_eval_step(state.params, model_inputs)
            eval_metrics.append(metrics)

        # get eval metrics
        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_map(lambda metric: jnp.mean(metric).item(), eval_metrics)

        if jax.process_index() == 0:
            eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
            path = os.path.join(training_args.output_dir, "eval_results.json")
            with open(path, "w") as f:
                json.dump(eval_metrics, f, indent=4, sort_keys=True)


if __name__ == "__main__":
    main()
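The other change in this file, matching the "make t5 more robust" commit message, stops forcing vocab_size=len(tokenizer) onto a pretrained config and only sets it when the model is created from scratch: a pretrained checkpoint's embeddings already have config.vocab_size rows, so overriding that from the tokenizer can put the config out of sync with the saved weights. A short illustration of the two paths (the checkpoint name is chosen for the example, not taken from this diff):

# Illustration only; this is not code from the diff above.
from transformers import AutoTokenizer, T5Config

tokenizer = AutoTokenizer.from_pretrained("t5-small")

# Pretrained path: keep the checkpoint's own vocab_size so it matches the saved embeddings.
config_pretrained = T5Config.from_pretrained("t5-small")

# From-scratch path: a brand-new model can safely take its vocabulary size from the tokenizer.
config_scratch = T5Config()
config_scratch.vocab_size = len(tokenizer)

print(config_pretrained.vocab_size, len(tokenizer))  # these two can differ for pretrained checkpoints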
53 changes: 53 additions & 0 deletions examples/flax/question-answering/run_qa.py
@@ -18,6 +18,7 @@
"""
# You can also adapt this script on your own question answering task. Pointers for this are left as comments.

import json
import logging
import os
import random
@@ -911,6 +912,58 @@ def eval_step(state, batch):
        epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}"
    # endregion

    # Eval after training
    if training_args.do_eval:
        eval_metrics = {}
        all_start_logits = []
        all_end_logits = []

        eval_loader = eval_data_collator(eval_dataset, eval_batch_size)
        for batch in tqdm(eval_loader, total=len(eval_dataset) // eval_batch_size, desc="Evaluating ...", position=2):
            _ = batch.pop("example_id")
            _ = batch.pop("offset_mapping")
            predictions = p_eval_step(state, batch)
            start_logits = np.array([pred for pred in chain(*predictions[0])])
            end_logits = np.array([pred for pred in chain(*predictions[1])])
            all_start_logits.append(start_logits)
            all_end_logits.append(end_logits)

        # evaluate also on leftover examples (not divisible by batch_size)
        num_leftover_samples = len(eval_dataset) % eval_batch_size

        # make sure leftover batch is evaluated on one device
        if num_leftover_samples > 0 and jax.process_index() == 0:
            # take leftover samples
            batch = eval_dataset[-num_leftover_samples:]
            batch = {k: np.array(v) for k, v in batch.items()}
            _ = batch.pop("example_id")
            _ = batch.pop("offset_mapping")

            predictions = eval_step(unreplicate(state), batch)
            start_logits = np.array([pred for pred in predictions[0]])
            end_logits = np.array([pred for pred in predictions[1]])
            all_start_logits.append(start_logits)
            all_end_logits.append(end_logits)

        max_len = max([x.shape[1] for x in all_start_logits])  # Get the max_length of the tensor

        # concatenate the numpy array
        start_logits_concat = create_and_fill_np_array(all_start_logits, eval_dataset, max_len)
        end_logits_concat = create_and_fill_np_array(all_end_logits, eval_dataset, max_len)

        # delete the list of numpy arrays
        del all_start_logits
        del all_end_logits
        outputs_numpy = (start_logits_concat, end_logits_concat)
        prediction = post_processing_function(eval_examples, eval_dataset, outputs_numpy)
        eval_metrics = compute_metrics(prediction)

        if jax.process_index() == 0:
            eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
            path = os.path.join(training_args.output_dir, "eval_results.json")
            with open(path, "w") as f:
                json.dump(eval_metrics, f, indent=4, sort_keys=True)


if __name__ == "__main__":
    main()
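Because different batches can be padded to different sequence lengths, the per-batch logits above are padded to a common max_len before being concatenated; create_and_fill_np_array is defined elsewhere in the script and does not appear in this diff. A rough stand-in for that padding step, using a made-up helper rather than the project's function:

# Hedged sketch of the pad-and-concatenate step; concat_padded is a hypothetical helper.
import numpy as np


def concat_padded(logit_chunks, num_examples, max_len, pad_value=-100.0):
    out = np.full((num_examples, max_len), pad_value, dtype=np.float32)
    row = 0
    for chunk in logit_chunks:
        n, seq_len = chunk.shape
        out[row : row + n, :seq_len] = chunk  # shorter chunks keep pad_value in the tail
        row += n
    return out


chunks = [np.ones((4, 10), dtype=np.float32), np.zeros((3, 12), dtype=np.float32)]
padded = concat_padded(chunks, num_examples=7, max_len=12)
print(padded.shape)  # (7, 12)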
8 changes: 8 additions & 0 deletions examples/flax/summarization/run_summarization_flax.py
@@ -18,6 +18,7 @@
"""
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.

import json
import logging
import os
import sys
@@ -816,6 +817,13 @@ def generate_step(params, batch):
        desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
        logger.info(desc)

        # save final metrics in json
        if jax.process_index() == 0:
            rouge_metrics = {f"test_{metric_name}": value for metric_name, value in rouge_metrics.items()}
            path = os.path.join(training_args.output_dir, "test_results.json")
            with open(path, "w") as f:
                json.dump(rouge_metrics, f, indent=4, sort_keys=True)


if __name__ == "__main__":
    main()
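This leaves a test_results.json next to the model outputs, analogous to the eval_results.json files written by the other scripts. A small hypothetical helper for reading such a file back in a test (directory, key name and threshold are illustrative only):

# Hypothetical sketch, not part of this commit.
import json
import os


def get_results(output_dir, results_file="test_results.json"):
    """Load the metrics json written by an example script."""
    with open(os.path.join(output_dir, results_file)) as f:
        return json.load(f)


results = get_results("./summarization_output")  # illustrative path
assert results["test_rouge1"] >= 10              # loose sanity bound on a toy run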
