Add Flax example tests #14599

Merged: 13 commits, Dec 6, 2021
63 changes: 63 additions & 0 deletions .circleci/config.yml
@@ -613,6 +613,69 @@ jobs:
- store_artifacts:
path: ~/transformers/reports

run_examples_flax:
working_directory: ~/transformers
docker:
- image: circleci/python:3.7
environment:
OMP_NUM_THREADS: 1
TRANSFORMERS_IS_CI: yes
resource_class: xlarge
parallelism: 1
steps:
- checkout
- restore_cache:
keys:
- v0.4-flax_examples-{{ checksum "setup.py" }}
- v0.4-{{ checksum "setup.py" }}
- run: pip install --upgrade pip
- run: sudo pip install .[flax,testing,sentencepiece]
- run: pip install -r examples/flax/_tests_requirements.txt
- save_cache:
key: v0.4-flax_examples-{{ checksum "setup.py" }}
paths:
- '~/.cache/pip'
- run: python utils/tests_fetcher.py --filters examples tests | tee test_preparation.txt
- store_artifacts:
path: ~/transformers/test_preparation.txt
- run: |
if [ -f test_list.txt ]; then
python -m pytest -n 8 --dist=loadfile -s --make-reports=examples_flax ./examples/flax/ | tee tests_output.txt
fi
- store_artifacts:
path: ~/transformers/flax_examples_output.txt
- store_artifacts:
path: ~/transformers/reports

run_examples_flax_all:
working_directory: ~/transformers
docker:
- image: circleci/python:3.7
environment:
OMP_NUM_THREADS: 1
TRANSFORMERS_IS_CI: yes
resource_class: xlarge
parallelism: 1
steps:
- checkout
- restore_cache:
keys:
- v0.4-flax_examples-{{ checksum "setup.py" }}
- v0.4-{{ checksum "setup.py" }}
- run: pip install --upgrade pip
- run: sudo pip install .[flax,testing,sentencepiece]
- run: pip install -r examples/flax/_tests_requirements.txt
- save_cache:
key: v0.4-flax_examples-{{ checksum "setup.py" }}
paths:
- '~/.cache/pip'
- run: |
TRANSFORMERS_IS_CI=1 python -m pytest -n 8 --dist=loadfile -s --make-reports=examples_flax ./examples/flax/ | tee examples_output.txt
- store_artifacts:
path: ~/transformers/flax_examples_output.txt
- store_artifacts:
path: ~/transformers/reports

run_tests_hub:
working_directory: ~/transformers
docker:
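
Both new jobs boil down to the same pytest invocation over examples/flax/. A rough local equivalent, sketched in Python rather than the shell step above and assuming the [flax,testing,sentencepiece] extras plus examples/flax/_tests_requirements.txt are already installed (--make-reports is a custom option added by the repository's conftest, and -n/--dist come from pytest-xdist in the testing extras):

import pytest

# Run the Flax example tests across 8 workers and write per-test reports,
# roughly what the run_examples_flax_all job above does in CI.
exit_code = pytest.main(
    ["-n", "8", "--dist=loadfile", "-s", "--make-reports=examples_flax", "examples/flax/"]
)
raise SystemExit(exit_code)
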
7 changes: 7 additions & 0 deletions examples/flax/_tests_requirements.txt
@@ -0,0 +1,7 @@
datasets >= 1.1.3
pytest
conllu
nltk
rouge-score
seqeval
tensorboard
27 changes: 27 additions & 0 deletions examples/flax/language-modeling/run_clm_flax.py
@@ -21,6 +21,7 @@
"""
# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.

import json
import logging
import math
import os
@@ -672,6 +673,32 @@ def eval_step(params, batch):
if training_args.push_to_hub:
repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)

# Eval after training
if training_args.do_eval:
eval_metrics = []
eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
eval_steps = len(eval_dataset) // eval_batch_size
for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
# Model forward
batch = shard(next(eval_loader))
metrics = p_eval_step(state.params, batch)
eval_metrics.append(metrics)

# normalize eval metrics
eval_metrics = get_metrics(eval_metrics)
eval_metrics = jax.tree_map(lambda x: jnp.mean(x).item(), eval_metrics)

try:
eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
except OverflowError:
eval_metrics["perplexity"] = float("inf")

if jax.process_index() == 0:
eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
path = os.path.join(training_args.output_dir, "eval_results.json")
with open(path, "w") as f:
json.dump(eval_metrics, f, indent=4, sort_keys=True)


if __name__ == "__main__":
main()
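
The new block gathers the per-batch metrics with get_metrics, averages them with jax.tree_map, and reports exp(loss) as perplexity, falling back to infinity when the loss is too large for exp(). A minimal, dependency-free sketch of that last step, with hypothetical loss values standing in for the gathered metrics:

import math

# Hypothetical per-batch eval losses, standing in for the metrics collected
# from p_eval_step above.
batch_losses = [2.31, 2.28, 2.35]
eval_loss = sum(batch_losses) / len(batch_losses)  # mean over eval batches

try:
    perplexity = math.exp(eval_loss)  # same exp(loss) definition as the script
except OverflowError:  # an extreme loss would overflow exp()
    perplexity = float("inf")

print({"eval_loss": round(eval_loss, 4), "eval_perplexity": round(perplexity, 4)})
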
42 changes: 41 additions & 1 deletion examples/flax/language-modeling/run_mlm_flax.py
@@ -20,7 +20,9 @@
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
https://huggingface.co/models?filter=masked-lm
"""
import json
import logging
import math
import os
import sys
import time
@@ -271,7 +273,7 @@ def write_eval_metric(summary_writer, eval_metrics, step):
summary_writer.scalar(f"eval_{metric_name}", value, step)


if __name__ == "__main__":
def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
@@ -700,3 +702,41 @@ def eval_step(params, batch):
tokenizer.save_pretrained(training_args.output_dir)
if training_args.push_to_hub:
repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)

# Eval after training
if training_args.do_eval:
num_eval_samples = len(tokenized_datasets["validation"])
eval_samples_idx = jnp.arange(num_eval_samples)
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)

eval_metrics = []
for _, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
model_inputs = data_collator(samples, pad_to_multiple_of=16)

# Model forward
model_inputs = shard(model_inputs.data)
metrics = p_eval_step(state.params, model_inputs)
eval_metrics.append(metrics)

# normalize eval metrics
eval_metrics = get_metrics(eval_metrics)
eval_metrics = jax.tree_map(lambda metric: jnp.sum(metric).item(), eval_metrics)
eval_normalizer = eval_metrics.pop("normalizer")
eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)

try:
perplexity = math.exp(eval_metrics["loss"])
except OverflowError:
perplexity = float("inf")
eval_metrics["perplexity"] = perplexity

if jax.process_index() == 0:
eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
path = os.path.join(training_args.output_dir, "eval_results.json")
with open(path, "w") as f:
json.dump(eval_metrics, f, indent=4, sort_keys=True)


if __name__ == "__main__":
main()
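
Unlike the causal-LM script, the MLM eval sums each metric over all batches and divides by the popped normalizer, i.e. a token-weighted average rather than a per-batch mean. A small sketch of that reduction with hypothetical numbers, assuming eval_step returns a summed loss and accuracy together with a "normalizer" counting the masked tokens in the batch:

import math

# Hypothetical per-batch metrics in the shape the normalizer-based division implies.
batch_metrics = [
    {"loss": 230.4, "accuracy": 61.0, "normalizer": 100.0},
    {"loss": 118.7, "accuracy": 33.0, "normalizer": 50.0},
]

# Sum every metric over batches (jax.tree_map(jnp.sum, ...) in the script),
# then divide by the total number of masked tokens.
totals = {key: sum(m[key] for m in batch_metrics) for key in batch_metrics[0]}
normalizer = totals.pop("normalizer")
eval_metrics = {key: value / normalizer for key, value in totals.items()}

try:
    eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
except OverflowError:
    eval_metrics["perplexity"] = float("inf")
print(eval_metrics)
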
38 changes: 34 additions & 4 deletions examples/flax/language-modeling/run_t5_mlm_flax.py
@@ -20,6 +20,7 @@
https://huggingface.co/models?filter=t5
"""
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
import json
import logging
import os
import sys
@@ -401,7 +402,7 @@ def write_eval_metric(summary_writer, eval_metrics, step):
summary_writer.scalar(f"eval_{metric_name}", value, step)


if __name__ == "__main__":
def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
@@ -522,9 +523,7 @@ def write_eval_metric(summary_writer, eval_metrics, step):
model_args.config_name, cache_dir=model_args.cache_dir, vocab_size=len(tokenizer)
)
elif model_args.model_name_or_path:
config = T5Config.from_pretrained(
model_args.model_name_or_path, cache_dir=model_args.cache_dir, vocab_size=len(tokenizer)
)
config = T5Config.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
else:
config = CONFIG_MAPPING[model_args.model_type]()
logger.warning("You are instantiating a new config instance from scratch.")
@@ -617,6 +616,7 @@ def group_texts(examples):
model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
)
else:
config.vocab_size = len(tokenizer)
model = FlaxT5ForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))

# Data collator
@@ -808,3 +808,33 @@ def eval_step(params, batch):
tokenizer.save_pretrained(training_args.output_dir)
if training_args.push_to_hub:
repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)

# Eval after training
if training_args.do_eval:
num_eval_samples = len(tokenized_datasets["validation"])
eval_samples_idx = jnp.arange(num_eval_samples)
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)

eval_metrics = []
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
model_inputs = data_collator(samples)

# Model forward
model_inputs = shard(model_inputs.data)
metrics = p_eval_step(state.params, model_inputs)
eval_metrics.append(metrics)

# get eval metrics
eval_metrics = get_metrics(eval_metrics)
eval_metrics = jax.tree_map(lambda metric: jnp.mean(metric).item(), eval_metrics)

if jax.process_index() == 0:
eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
path = os.path.join(training_args.output_dir, "eval_results.json")
with open(path, "w") as f:
json.dump(eval_metrics, f, indent=4, sort_keys=True)


if __name__ == "__main__":
main()
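
As in the other scripts, every eval batch is passed through shard() before the pmapped eval step so the leading batch axis is split across local devices. A minimal NumPy sketch of that reshape, assuming 8 local devices and a hypothetical global batch of 16 sequences of length 512:

import numpy as np

n_devices = 8  # hypothetical local device count
batch = {"input_ids": np.zeros((16, 512), dtype=np.int32)}  # hypothetical global batch

# shard() from flax.training.common_utils reshapes each leaf from
# (global_batch, ...) to (n_devices, per_device_batch, ...) so pmap can
# map the eval step over devices.
sharded = {k: v.reshape((n_devices, -1) + v.shape[1:]) for k, v in batch.items()}
print(sharded["input_ids"].shape)  # (8, 2, 512)
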
53 changes: 53 additions & 0 deletions examples/flax/question-answering/run_qa.py
@@ -18,6 +18,7 @@
"""
# You can also adapt this script on your own question answering task. Pointers for this are left as comments.

import json
import logging
import os
import random
@@ -911,6 +912,58 @@ def eval_step(state, batch):
epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}"
# endregion

# Eval after training
if training_args.do_eval:
eval_metrics = {}
all_start_logits = []
all_end_logits = []

eval_loader = eval_data_collator(eval_dataset, eval_batch_size)
for batch in tqdm(eval_loader, total=len(eval_dataset) // eval_batch_size, desc="Evaluating ...", position=2):
_ = batch.pop("example_id")
_ = batch.pop("offset_mapping")
predictions = p_eval_step(state, batch)
start_logits = np.array([pred for pred in chain(*predictions[0])])
end_logits = np.array([pred for pred in chain(*predictions[1])])
all_start_logits.append(start_logits)
all_end_logits.append(end_logits)

# evaluate also on leftover examples (not divisible by batch_size)
num_leftover_samples = len(eval_dataset) % eval_batch_size

# make sure leftover batch is evaluated on one device
if num_leftover_samples > 0 and jax.process_index() == 0:
# take leftover samples
batch = eval_dataset[-num_leftover_samples:]
batch = {k: np.array(v) for k, v in batch.items()}
_ = batch.pop("example_id")
_ = batch.pop("offset_mapping")

predictions = eval_step(unreplicate(state), batch)
start_logits = np.array([pred for pred in predictions[0]])
end_logits = np.array([pred for pred in predictions[1]])
all_start_logits.append(start_logits)
all_end_logits.append(end_logits)

max_len = max([x.shape[1] for x in all_start_logits]) # Get the max_length of the tensor

# concatenate the numpy array
start_logits_concat = create_and_fill_np_array(all_start_logits, eval_dataset, max_len)
end_logits_concat = create_and_fill_np_array(all_end_logits, eval_dataset, max_len)

# delete the list of numpy arrays
del all_start_logits
del all_end_logits
outputs_numpy = (start_logits_concat, end_logits_concat)
prediction = post_processing_function(eval_examples, eval_dataset, outputs_numpy)
eval_metrics = compute_metrics(prediction)

if jax.process_index() == 0:
eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
path = os.path.join(training_args.output_dir, "eval_results.json")
with open(path, "w") as f:
json.dump(eval_metrics, f, indent=4, sort_keys=True)


if __name__ == "__main__":
main()
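
Because eval batches can carry different padded sequence lengths, the collected start/end logits are padded to a common max_len and stacked into single (num_examples, max_len) arrays before post-processing. A small NumPy sketch of that pad-and-concatenate step, with hypothetical shapes standing in for what create_and_fill_np_array produces:

import numpy as np

# Hypothetical per-batch start logits with two different padded lengths,
# plus a short leftover batch evaluated on a single device.
all_start_logits = [np.ones((8, 100)), np.ones((8, 120)), np.ones((3, 100))]
num_examples = sum(x.shape[0] for x in all_start_logits)  # 19
max_len = max(x.shape[1] for x in all_start_logits)  # 120

# Fill with a large negative value so padded positions never win the argmax,
# then copy each batch into its slice of rows.
concat = np.full((num_examples, max_len), -100.0)
row = 0
for logits in all_start_logits:
    concat[row : row + logits.shape[0], : logits.shape[1]] = logits
    row += logits.shape[0]
print(concat.shape)  # (19, 120)
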
8 changes: 8 additions & 0 deletions examples/flax/summarization/run_summarization_flax.py
@@ -18,6 +18,7 @@
"""
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.

import json
import logging
import os
import sys
@@ -816,6 +817,13 @@ def generate_step(params, batch):
desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
logger.info(desc)

# save final metrics in json
if jax.process_index() == 0:
rouge_metrics = {f"test_{metric_name}": value for metric_name, value in rouge_metrics.items()}
path = os.path.join(training_args.output_dir, "test_results.json")
with open(path, "w") as f:
json.dump(rouge_metrics, f, indent=4, sort_keys=True)


if __name__ == "__main__":
main()
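
Every script now leaves an eval_results.json (test_results.json for summarization) in output_dir, which is presumably what the new example tests read back to assert on metric quality. A minimal sketch of consuming one of those files, with a hypothetical output directory and threshold:

import json
import os

output_dir = "/tmp/flax_clm_test_run"  # hypothetical --output_dir of a test run
with open(os.path.join(output_dir, "eval_results.json")) as f:
    results = json.load(f)

# An example test might, for instance, assert the tiny run reached a sane perplexity.
assert results["eval_perplexity"] < 100, results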