Merged
58 commits
f98ef14
BIG Reorganize examples (#4213)
julien-c May 7, 2020
3fdaddc
upstream
sshleifer May 7, 2020
4ed9067
docs
sshleifer May 7, 2020
84b5e23
Merge remote-tracking branch 'upstream/master'
sshleifer May 8, 2020
4d2e6ee
Merge remote-tracking branch 'upstream/master'
sshleifer May 8, 2020
94de64b
boom boom
sshleifer May 8, 2020
5775d7b
Merge remote-tracking branch 'upstream/master'
sshleifer May 10, 2020
784333e
Merge branch 'master' of github.com:sshleifer/transformers_fork
sshleifer May 10, 2020
71a6716
Merge remote-tracking branch 'upstream/master'
sshleifer May 14, 2020
3283f08
Merge remote-tracking branch 'upstream/master'
sshleifer Jul 2, 2020
219b91e
Merge remote-tracking branch 'upstream/master'
sshleifer Jul 12, 2020
97db343
Merge remote-tracking branch 'upstream/master'
sshleifer Jul 15, 2020
1d47552
Merge remote-tracking branch 'upstream/master'
sshleifer Jul 19, 2020
3c9e56e
Merge remote-tracking branch 'upstream/master'
sshleifer Jul 30, 2020
6e8244d
Merge remote-tracking branch 'upstream/master'
sshleifer Aug 20, 2020
7d26cbc
Merge remote-tracking branch 'upstream/master'
sshleifer Aug 21, 2020
efc3817
Merge remote-tracking branch 'upstream/master'
sshleifer Sep 7, 2020
bbfc09b
boom boom
sshleifer Sep 11, 2020
dbbbe38
boom boom
sshleifer Sep 11, 2020
3250876
boom boom
sshleifer Sep 11, 2020
85ddccb
boom boom
sshleifer Sep 11, 2020
b003026
boom boom
sshleifer Sep 11, 2020
c31f917
boom boom
sshleifer Sep 11, 2020
cf6ceed
boom boom
sshleifer Sep 11, 2020
d59f3cb
boom boom
sshleifer Sep 11, 2020
9590a01
boom boom
sshleifer Sep 11, 2020
186021c
boom boom
sshleifer Sep 11, 2020
854f4db
boom boom
sshleifer Sep 11, 2020
6f534e4
boom boom
sshleifer Sep 11, 2020
9189d68
boom boom
sshleifer Sep 11, 2020
4962cc6
boom boom
sshleifer Sep 11, 2020
5147b33
boom boom
sshleifer Sep 11, 2020
281b171
Merge branch 'master' into distro-cut
sshleifer Sep 13, 2020
1d2e040
Merge remote-tracking branch 'upstream/master'
sshleifer Sep 13, 2020
3736445
Merge branch 'master' into distro-cut
sshleifer Sep 13, 2020
e762119
boom boom
sshleifer Sep 13, 2020
8706259
boom boom
sshleifer Sep 13, 2020
9cb1174
boom boom
sshleifer Sep 13, 2020
ee1610f
boom boom
sshleifer Sep 13, 2020
fe2634e
argparse
sshleifer Sep 13, 2020
78d342d
boom boom
sshleifer Sep 13, 2020
2a9af47
boom boom
sshleifer Sep 13, 2020
ece95c0
boom boom
sshleifer Sep 13, 2020
69c828c
boom boom
sshleifer Sep 13, 2020
7dd89b4
boom boom
sshleifer Sep 13, 2020
212d356
boom boom
sshleifer Sep 13, 2020
ef228aa
no ddp
sshleifer Sep 13, 2020
b5c9f0f
remove superfluous
sshleifer Sep 13, 2020
514f523
boom boom
sshleifer Sep 13, 2020
86d12f5
boom boom
sshleifer Sep 13, 2020
559cf29
boom boom
sshleifer Sep 13, 2020
9873a46
boom boom
sshleifer Sep 13, 2020
b18e94a
boom boom
sshleifer Sep 13, 2020
4e94436
boom boom
sshleifer Sep 13, 2020
1561e49
boom boom
sshleifer Sep 13, 2020
28ec293
boom boom
sshleifer Sep 13, 2020
75c9b04
boom boom
sshleifer Sep 13, 2020
236381f
Rename helper script
sshleifer Sep 13, 2020
46 changes: 46 additions & 0 deletions examples/seq2seq/aggregate_distributed_results.py
@@ -0,0 +1,46 @@
from pathlib import Path

import fire


try:
from .utils import calculate_bleu, calculate_rouge, load_json, save_json, write_txt_file
except ImportError:
from utils import calculate_bleu, calculate_rouge, load_json, save_json, write_txt_file


def combine_partial_results(
result_dir: str, save_dir: str = None, save_prefix=None, calc_bleu=False, just_metrics=False
):
"""Write first n lines of each file f in src_dir to dest_dir/f """
src_dir = Path(result_dir)
save_dir = Path(save_dir)
save_dir.mkdir(exist_ok=True)
paths_to_combine = list(src_dir.glob("rank*.json"))
records = []
for partial_result in paths_to_combine:
records.extend(load_json(partial_result))
preds = [x["pred"] for x in records]
labels = [x["label"] for x in records]
score_fn = calculate_bleu if calc_bleu else calculate_rouge
metrics = score_fn(preds, labels)
    save_json(metrics, save_dir.joinpath("metrics.json"))  # better would be {prefix}_{rouge|bleu}.json
print(metrics)
if just_metrics:
return

if save_prefix is None:
save_prefix = "generated"
print("using generated as prefix")

tgt_path = save_dir.joinpath(f"{save_prefix}.target")
write_txt_file(labels, tgt_path)
pred_path = save_dir.joinpath(f"{save_prefix}.pred_target")
write_txt_file(preds, pred_path)
if "source" in records[0]:
src_path = save_dir.joinpath(f"{save_prefix}.source")
write_txt_file([x["source"] for x in records], src_path)


if __name__ == "__main__":
fire.Fire(combine_partial_results)
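
For context, a minimal sketch (paths are assumptions) of calling this directly from Python rather than via fire. Each rank_*.json written by run_distributed_eval.py is a list of {"pred": ..., "label": ..., optional "source": ...} records, which is what combine_partial_results reads back via load_json:

from aggregate_distributed_results import combine_partial_results

combine_partial_results(
    result_dir="tmp_gen",        # assumed: directory holding rank_0_output.json, rank_1_output.json, ...
    save_dir="tmp_gen/merged",   # assumed output directory
    save_prefix="test",
    calc_bleu=False,             # ROUGE for summarization; pass True for translation
)
# writes metrics.json, test.target, test.pred_target (and test.source if sources were saved)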
139 changes: 139 additions & 0 deletions examples/seq2seq/run_distributed_eval.py
@@ -0,0 +1,139 @@
import argparse
import warnings
from logging import getLogger
from pathlib import Path
from typing import Dict

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


logger = getLogger(__name__)

try:
from .utils import Seq2SeqDataset, parse_numeric_cl_kwargs, save_json, use_task_specific_params
except ImportError:
from utils import Seq2SeqDataset, parse_numeric_cl_kwargs, save_json, use_task_specific_params

DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def eval_data_dir(
data_dir,
save_dir: str,
model_name: str,
bs: int = 8,
max_source_length: int = 1024,
type_path="val",
n_obs=None,
fp16=False,
save_source=False,
num_beams: int = 4,
task="summarization",
local_rank=None,
**generate_kwargs,
) -> Dict:
"""Run evaluation on part of the data for one gpu and save to {save_dir}/rank_{rank}_output.json"""
model_name = str(model_name)
assert local_rank is not None
torch.distributed.init_process_group(backend="nccl", rank=local_rank)

save_dir = Path(save_dir)
save_path = save_dir.joinpath(f"rank_{local_rank}_output.json")
torch.cuda.set_device(local_rank)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).cuda()
if fp16:
model = model.half()

tokenizer = AutoTokenizer.from_pretrained(model_name)
logger.info(f"Inferred tokenizer type: {tokenizer.__class__}") # if this is wrong, check config.model_type.
use_task_specific_params(model, task) # update config with task specific params
ds = Seq2SeqDataset(
tokenizer,
data_dir,
max_source_length,
max_target_length=1024,
type_path=type_path,
n_obs=n_obs,
prefix=model.config.prefix,
)
sampler = ds.make_sortish_sampler(bs, distributed=True)
data_loader = DataLoader(ds, sampler=sampler, batch_size=bs, collate_fn=ds.collate_fn)
dec_kwargs = dict(skip_special_tokens=True, clean_up_tokenization_spaces=False) # tokenizer.decode
results = []
for batch in tqdm(data_loader):
summaries = model.generate(
input_ids=batch["input_ids"].to(model.device),
attention_mask=batch["attention_mask"].to(model.device),
num_beams=num_beams,
**generate_kwargs,
)
preds = tokenizer.batch_decode(summaries, **dec_kwargs)
labels = tokenizer.batch_decode(batch["labels"], **dec_kwargs)
if save_source:
docs = tokenizer.batch_decode(batch["input_ids"], **dec_kwargs)
for i in range(len(labels)):
label, pred = labels[i], preds[i]
if save_source:
results.append(dict(pred=pred, label=label, source=docs[i]))
else:
results.append(dict(pred=pred, label=label))
save_json(results, save_path)
return results


def run_generate():
parser = argparse.ArgumentParser(
epilog="Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate"
)
parser.add_argument("--input_path", type=str, help="like cnn_dm/test.source")
parser.add_argument(
"--model_name",
type=str,
help="like facebook/bart-large-cnn,t5-base, etc.",
default="sshleifer/distilbart-xsum-12-3",
)
parser.add_argument("--save_dir", type=str, help="where to save", default="tmp_gen")
parser.add_argument("--prefix", type=str, default="test", help="which subset to evaluate typically train/val/test")
parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test.target")
parser.add_argument("--score_path", type=str, required=False, default="metrics.json", help="where to save metrics")
parser.add_argument("--task", type=str, default="summarization", help="used for task_specific_params + metrics")
parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
parser.add_argument(
"--local_rank", type=int, default=-1, required=False, help="should be passed by distributed.launch"
)

parser.add_argument(
"--n_obs", type=int, default=None, required=False, help="How many observations. Defaults to all."
)
parser.add_argument("--fp16", action="store_true")
parser.add_argument("--save_source", action="store_true")

args, rest = parser.parse_known_args()
parsed = parse_numeric_cl_kwargs(rest)
if parsed:
print(f"parsed the following generate kwargs: {parsed}")
Path(args.save_dir).mkdir(exist_ok=True)
if args.reference_path is None and Path(args.score_path).exists():
warnings.warn(f"score_path {args.score_path} will be overwritten unless you type ctrl-c.")
eval_data_dir(
args.input_path,
args.save_dir,
args.model_name,
        type_path=args.prefix,
        bs=args.bs,
fp16=args.fp16,
task=args.task,
local_rank=args.local_rank,
n_obs=args.n_obs,
save_source=args.save_source,
**parsed,
)


if __name__ == "__main__":
# Usage for MT:
run_generate()
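
A hedged end-to-end sketch, assuming two GPUs and a data directory laid out the way Seq2SeqDataset expects ({type_path}.source / {type_path}.target files). torch.distributed.launch spawns one process per GPU and passes --local_rank to each; any extra numeric flags fall through parse_known_args into model.generate:

import subprocess
import sys

data_dir = "cnn_dm"  # assumed: contains test.source and test.target
subprocess.run(
    [
        sys.executable, "-m", "torch.distributed.launch", "--nproc_per_node=2",
        "run_distributed_eval.py",
        "--model_name", "sshleifer/distilbart-xsum-12-3",
        "--input_path", data_dir,
        "--save_dir", "tmp_gen",
        "--prefix", "test",
        "--bs", "8",
        "--fp16",
    ],
    check=True,
)
# each rank writes tmp_gen/rank_<i>_output.json, ready for aggregate_distributed_results.py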
7 changes: 7 additions & 0 deletions examples/seq2seq/utils.py
@@ -387,3 +387,10 @@ def parse_numeric_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float]]:

result[unparsed_args[i][2:]] = value
return result


def write_txt_file(ordered_tgt, path):
    with Path(path).open("w") as f:
        for ln in ordered_tgt:
            f.write(ln + "\n")