From 6a1355d9438e4b6629c27f7ce001e5e14c93f0e6 Mon Sep 17 00:00:00 2001
From: Tuan Nguyen
Date: Wed, 14 Feb 2024 23:47:26 -0500
Subject: [PATCH 1/4] Evaluation script for summarization task w/ Rouge score

---
 .../accelerate_default_config.4gpus.yaml      |  16 +
 .../cnn_dailymail/rouge_accelerate.py         | 278 ++++++++++++++++++
 .../cnn_dailymail/rouge_accelerate.sh         |  15 +
 3 files changed, 309 insertions(+)
 create mode 100644 src/sparseml/experimental/evaluation/summarization/cnn_dailymail/accelerate_default_config.4gpus.yaml
 create mode 100644 src/sparseml/experimental/evaluation/summarization/cnn_dailymail/rouge_accelerate.py
 create mode 100755 src/sparseml/experimental/evaluation/summarization/cnn_dailymail/rouge_accelerate.sh

diff --git a/src/sparseml/experimental/evaluation/summarization/cnn_dailymail/accelerate_default_config.4gpus.yaml b/src/sparseml/experimental/evaluation/summarization/cnn_dailymail/accelerate_default_config.4gpus.yaml
new file mode 100644
index 00000000000..e5f0253a00f
--- /dev/null
+++ b/src/sparseml/experimental/evaluation/summarization/cnn_dailymail/accelerate_default_config.4gpus.yaml
@@ -0,0 +1,16 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: MULTI_GPU
+downcast_bf16: 'no'
+gpu_ids: all
+machine_rank: 0
+main_training_function: main
+mixed_precision: 'no'
+num_machines: 1
+num_processes: 4
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/src/sparseml/experimental/evaluation/summarization/cnn_dailymail/rouge_accelerate.py b/src/sparseml/experimental/evaluation/summarization/cnn_dailymail/rouge_accelerate.py
new file mode 100644
index 00000000000..002f7332523
--- /dev/null
+++ b/src/sparseml/experimental/evaluation/summarization/cnn_dailymail/rouge_accelerate.py
@@ -0,0 +1,278 @@
+import argparse
+import json
+import os
+
+import datasets
+import torch
+from torch.utils.data import DataLoader, Dataset
+from tqdm import tqdm
+from transformers import AutoModelForCausalLM, AutoTokenizer, default_data_collator
+
+import evaluate
+import nltk
+from accelerate import Accelerator
+from lm_eval.utils import stop_sequences_criteria
+
+
+nltk.download("punkt")
+
+
+ARTICLE_TEMPLATE = "Article:\n{article}"
+
+SUMMARY_TEMPLATE = "\n\n### Summarization:\n"
+
+
+def load_model(model_path):
+    return AutoModelForCausalLM.from_pretrained(model_path)
+
+
+def load_tokenizer(model_path):
+    tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
+    tokenizer.pad_token = tokenizer.eos_token
+    tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
+    return tokenizer
+
+
+def postprocess_text(preds, labels, first_k_preds):
+    preds = [pred.strip() for pred in preds]
+    labels = [label.strip() for label in labels]
+
+    # ROUGE expects a newline after each sentence
+    preds = ["\n".join(nltk.sent_tokenize(pred)[:first_k_preds]) for pred in preds]
+    labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
+
+    return preds, labels
+
+
+def main(model_path, batch, dataset_path, dataset_name):
+    model = load_model(model_path)
+    tokenizer = load_tokenizer(model_path)
+
+    accelerator = Accelerator()  # also required by the single-process debug path below
+
+    with accelerator.main_process_first():
+        dataset = datasets.load_dataset(dataset_path, dataset_name, split="validation")
+        if args.samples > 0:
+            dataset = dataset.shuffle(seed=42).select(range(args.samples))
+
+    result_path = os.path.join(model_path, args.output_dir)
+    if not os.path.exists(result_path):
+        os.makedirs(result_path)
+
+    if args.generation == "lm-eval-harness":
+        # Similar to the default decoding strategy used by
+        # lm-evaluation-harness
+        gen_kwargs = {
+            "do_sample": False,
+            "temperature": 1.0,  # To disable warning
+            "top_p": 1.0,  # To disable warning
+            "bos_token_id": 1,
+            "eos_token_id": 2,
+            "pad_token_id": 0,
+            "max_new_tokens": 512,
+        }
+    elif args.generation == "top_k":
+        # Similar to GPT-2 decoding strategy used for summarization
+        # (see their paper, section 3.6)
+        gen_kwargs = {
+            "do_sample": True,
+            "top_k": args.top_k,
+            "max_new_tokens": args.max_new_tokens,
+        }
+    else:
+        raise ValueError(f"Unknown decoding strategy: {args.generation}")
+
+    def _process_sample(sample):
+        article = ARTICLE_TEMPLATE.format(article=sample["article"])
+        tok_summary = tokenizer(SUMMARY_TEMPLATE)
+
+        # Exclude the BOS from the tokenized summary
+        tok_summary = {k: tok_summary[k][1:] for k in tok_summary}
+
+        max_tok_article = args.max_input_length - len(tok_summary["input_ids"])
+        tok_article = tokenizer(
+            article, max_length=max_tok_article, truncation=True, padding="max_length"
+        )
+
+        model_inputs = {k: tok_article[k] + tok_summary[k] for k in tok_article}
+
+        prompt_length = len(model_inputs["input_ids"])
+        highlights = tokenizer(
+            sample["highlights"],
+            max_length=prompt_length,
+            truncation=True,
+            padding="max_length",
+        )
+        model_inputs["tok_highlights"] = highlights["input_ids"]
+
+        # Using "label" for sample ID since it will be recognized and reserved by
+        # the default data collator used below
+        model_inputs["label"] = hash(sample["id"])
+
+        return model_inputs
+
+    tokenized_dataset = dataset.map(_process_sample, batched=False, num_proc=16)
+    remove_columns = dataset.column_names
+    tokenized_dataset = tokenized_dataset.remove_columns(remove_columns)
+    tokenized_dataset.set_format("torch")
+
+    data_collator = default_data_collator
+    dataloader = DataLoader(
+        tokenized_dataset,
+        batch_size=batch,
+        shuffle=False,
+        num_workers=16,
+        pin_memory=True,
+        collate_fn=data_collator,
+    )
+    if accelerator is not None:
+        model, dataloader = accelerator.prepare(model, dataloader)
+
+    if accelerator.is_main_process:
+        saved_preds = {"ids": [], "predictions": [], "highlights": []}
+        rouge_score = evaluate.load("rouge")
+
+    model.eval()
+    for step, batch in enumerate(tqdm(dataloader)):
+        labels = batch["labels"]
+        with torch.no_grad():
+            if args.generation == "lm-eval-harness":
+                stop = ["\n\n", "Article:"]
+                initial_decoder_input_length = batch["input_ids"].shape[1]
+                batch_size = batch["input_ids"].shape[0]
+                stopping_criteria = stop_sequences_criteria(
+                    tokenizer, stop, initial_decoder_input_length, batch_size
+                )
+            else:
+                stopping_criteria = None
+
+            prompt_length = batch["input_ids"].shape[1]
+            if args.use_accelerate:
+                generated_tokens = accelerator.unwrap_model(model).generate(
+                    batch["input_ids"],
+                    attention_mask=batch["attention_mask"],
+                    stopping_criteria=stopping_criteria,
+                    **gen_kwargs,
+                )
+                generated_tokens = accelerator.pad_across_processes(
+                    generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
+                )
+                highlights = batch["tok_highlights"]
+
+                generated_tokens = accelerator.gather(generated_tokens).cpu().numpy()
+                highlights = accelerator.gather(highlights).cpu().numpy()
+                labels = accelerator.gather(labels).cpu().numpy()
+            else:
+                # Code path for debugging only with 1 GPU
+                batch = {k: batch[k].to(model.device) for k in batch.keys()}
+                generated_tokens = model.generate(
+                    batch["input_ids"],
+                    attention_mask=batch["attention_mask"],
+                    stopping_criteria=stopping_criteria,
+                    **gen_kwargs,
+                )
+                highlights = batch["tok_highlights"]
+
+                generated_tokens = generated_tokens.cpu().numpy()
+                highlights = highlights.cpu().numpy()
+                labels = labels.cpu().numpy()
+        batch = None
+        torch.cuda.empty_cache()
+
+        if isinstance(generated_tokens, tuple):
+            generated_tokens = generated_tokens[0]
+
+        generated_summary_tokens = generated_tokens[:, prompt_length:]
+        decoded_preds = tokenizer.batch_decode(
+            generated_summary_tokens, skip_special_tokens=True
+        )
+        decoded_highlights = tokenizer.batch_decode(
+            highlights, skip_special_tokens=True
+        )
+        decoded_preds, decoded_highlights = postprocess_text(
+            decoded_preds, decoded_highlights, args.first_k_preds
+        )
+
+        assert len(labels) == len(decoded_preds) == len(decoded_highlights)
+
+        if accelerator.is_main_process:
+            saved_preds["ids"] += labels.tolist()
+            saved_preds["predictions"] += decoded_preds
+            saved_preds["highlights"] += decoded_highlights
+
+    if accelerator.is_main_process:
+        results = rouge_score.compute(
+            predictions=saved_preds["predictions"], references=saved_preds["highlights"]
+        )
+        print(f"Rouge score: {results}")
+
+        with open(os.path.join(result_path, f"predictions.json"), "w") as f:
+            json.dump(saved_preds, f)
+
+        result_file_name = (
+            f"rouge_{args.samples}samples.json"
+            if args.samples > 0
+            else f"rouge_full_validation.json"
+        )
+        results.update(
+            {
+                "generation": args.generation,
+                "generation_config": gen_kwargs,
+                "prompt": ARTICLE_TEMPLATE + SUMMARY_TEMPLATE,
+            }
+        )
+        result_file_path = os.path.join(result_path, result_file_name)
+        assert not os.path.exists(
+            result_file_path
+        ), f"File {result_file_path} already exists! Results will not be saved."
+        with open(result_file_path, "w") as f:
+            json.dump(results, f)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Compute ROUGE score")
+    parser.add_argument(
+        "--use-accelerate",
+        action="store_true",
+        default=False,
+        help="Use accelerate. Default: False",
+    )
+    parser.add_argument("--model-path", type=str, help="model path")
+    parser.add_argument(
+        "--output-dir", type=str, default="rouge", help="Output directory"
+    )
+    parser.add_argument(
+        "--max-new-tokens", type=int, default=512, help="Max new tokens"
+    )
+    parser.add_argument(
+        "--max-input-length",
+        type=int,
+        default=2048,
+        help="Max tokenized input length to model",
+    )
+    parser.add_argument(
+        "--first-k-preds", type=int, default=None, help="First K predicted sentences"
+    )
+    parser.add_argument("--batch", type=int, default=8, help="Batch size")
+    parser.add_argument(
+        "--samples", type=int, default=-1, help="Number of samples. Defaults to all."
+    )
+    parser.add_argument(
+        "--generation",
+        type=str,
+        default="lm-eval-harness",
+        help="Generation strategies: lm-eval-harness, top_k",
+    )
+    parser.add_argument(
+        "--top-k", type=int, default=10, help="top_k in the top_k strategy"
+    )
+    parser.add_argument(
+        "--dataset-path", type=str, default="cnn_dailymail", help="dataset path"
+    )
+    parser.add_argument(
+        "--dataset-name", type=str, default="3.0.0", help="dataset name"
+    )
+
+    args = parser.parse_args()
+
+    main(args.model_path, args.batch, args.dataset_path, args.dataset_name)
diff --git a/src/sparseml/experimental/evaluation/summarization/cnn_dailymail/rouge_accelerate.sh b/src/sparseml/experimental/evaluation/summarization/cnn_dailymail/rouge_accelerate.sh
new file mode 100755
index 00000000000..a927ccd2b0c
--- /dev/null
+++ b/src/sparseml/experimental/evaluation/summarization/cnn_dailymail/rouge_accelerate.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+
+export CUDA_VISIBLE_DEVICES=0,1,2,3
+NPROC=$(($(echo $CUDA_VISIBLE_DEVICES | grep -o "," | wc -l)+1))
+
+SRC_ROOT=$HOME/work/llama2.cnn_dailymail.eval/src/my_scripts
+
+source $SRC_ROOT/start_here.sh
+
+for MODEL_NAME in sparse_ft@SRCcerebras50@lr1e-4@WD0.0@B8@GrAcc8@W0.1@ep2@GPUs7@ID15577
+do
+    M=$HOME/models/llama2/cnn_dailymail/llama-recipes/sparse_finetuned/$MODEL_NAME
+    accelerate launch --config_file $SRC_ROOT/accelerate_default_config.${NPROC}gpus.yaml $SRC_ROOT/rouge_accelerate.py --model-path $M --batch 2 --samples 16 --generation top_k --top-k 2 --max-new-tokens 100 --first-k-preds 3 --use-accelerate --output-dir rouge
+done
+

From 491428f00421c2074636657ffe68547a7648fdba Mon Sep 17 00:00:00 2001
From: Tuan Nguyen
Date: Thu, 15 Feb 2024 10:40:17 -0500
Subject: [PATCH 2/4] Copyright

---
 .../cnn_dailymail/rouge_accelerate.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/src/sparseml/experimental/evaluation/summarization/cnn_dailymail/rouge_accelerate.py b/src/sparseml/experimental/evaluation/summarization/cnn_dailymail/rouge_accelerate.py
index 002f7332523..6e3905a66e3 100644
--- a/src/sparseml/experimental/evaluation/summarization/cnn_dailymail/rouge_accelerate.py
+++ b/src/sparseml/experimental/evaluation/summarization/cnn_dailymail/rouge_accelerate.py
@@ -1,3 +1,18 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
 import argparse
 import json
 import os

From 4b54a30e3117532c4db6cb85d113fc90cb1298cf Mon Sep 17 00:00:00 2001
From: Tuan Nguyen
Date: Thu, 15 Feb 2024 11:00:35 -0500
Subject: [PATCH 3/4] Formatting

---
 .../summarization/cnn_dailymail/rouge_accelerate.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/sparseml/experimental/evaluation/summarization/cnn_dailymail/rouge_accelerate.py b/src/sparseml/experimental/evaluation/summarization/cnn_dailymail/rouge_accelerate.py
index 6e3905a66e3..d0b5e0d2510 100644
--- a/src/sparseml/experimental/evaluation/summarization/cnn_dailymail/rouge_accelerate.py
+++ b/src/sparseml/experimental/evaluation/summarization/cnn_dailymail/rouge_accelerate.py
@@ -19,7 +19,7 @@
 
 import datasets
 import torch
-from torch.utils.data import DataLoader, Dataset
+from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformers import AutoModelForCausalLM, AutoTokenizer, default_data_collator
 
@@ -221,13 +221,13 @@ def _process_sample(sample):
         )
         print(f"Rouge score: {results}")
 
-        with open(os.path.join(result_path, f"predictions.json"), "w") as f:
+        with open(os.path.join(result_path, "predictions.json"), "w") as f:
             json.dump(saved_preds, f)
 
         result_file_name = (
             f"rouge_{args.samples}samples.json"
             if args.samples > 0
-            else f"rouge_full_validation.json"
+            else "rouge_full_validation.json"
         )
         results.update(
             {

From fec5fd9deb6a38b2c6711cd6fd935edd9a08a2ec Mon Sep 17 00:00:00 2001
From: Tuan Nguyen
Date: Mon, 25 Mar 2024 16:56:03 -0400
Subject: [PATCH 4/4] Apply recipe for quantization models

---
 .../cnn_dailymail/rouge_accelerate.py | 42 +++++++------------
 1 file changed, 16 insertions(+), 26 deletions(-)

diff --git a/src/sparseml/experimental/evaluation/summarization/cnn_dailymail/rouge_accelerate.py b/src/sparseml/experimental/evaluation/summarization/cnn_dailymail/rouge_accelerate.py
index d0b5e0d2510..7e86b2269e8 100644
--- a/src/sparseml/experimental/evaluation/summarization/cnn_dailymail/rouge_accelerate.py
+++ b/src/sparseml/experimental/evaluation/summarization/cnn_dailymail/rouge_accelerate.py
@@ -1,33 +1,19 @@
-# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
 import argparse
 import json
 import os
 
 import datasets
-import torch
-from torch.utils.data import DataLoader
-from tqdm import tqdm
-from transformers import AutoModelForCausalLM, AutoTokenizer, default_data_collator
-
 import evaluate
 import nltk
+import torch
 from accelerate import Accelerator
 from lm_eval.utils import stop_sequences_criteria
-
+from sparseml.pytorch.model_load.helpers import (
+    RECIPE_FILE_NAME, apply_recipe_structure_to_model)
+from torch.utils.data import DataLoader, Dataset
+from tqdm import tqdm
+from transformers import (AutoModelForCausalLM, AutoTokenizer,
+                          default_data_collator)
 
 nltk.download("punkt")
 
@@ -38,7 +24,13 @@
 
 
 def load_model(model_path):
-    return AutoModelForCausalLM.from_pretrained(model_path)
+    model = AutoModelForCausalLM.from_pretrained(model_path)
+    input_recipe_path = os.path.join(model_path, RECIPE_FILE_NAME)
+    if os.path.exists(input_recipe_path):
+        apply_recipe_structure_to_model(
+            model=model, recipe_path=input_recipe_path, model_path=model_path
+        )
+    return model
 
 
 def load_tokenizer(model_path):
@@ -75,8 +67,6 @@ def main(model_path, batch, dataset_path, dataset_name):
         os.makedirs(result_path)
 
     if args.generation == "lm-eval-harness":
-        # Similar to the default decoding strategy used by
-        # lm-evaluation-harness
         gen_kwargs = {
             "do_sample": False,
             "temperature": 1.0,  # To disable warning
@@ -221,13 +211,13 @@ def _process_sample(sample):
         )
         print(f"Rouge score: {results}")
 
-        with open(os.path.join(result_path, "predictions.json"), "w") as f:
+        with open(os.path.join(result_path, f"predictions.json"), "w") as f:
             json.dump(saved_preds, f)
 
         result_file_name = (
             f"rouge_{args.samples}samples.json"
             if args.samples > 0
-            else "rouge_full_validation.json"
+            else f"rouge_full_validation.json"
         )
         results.update(
             {
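
A quick single-GPU smoke test of the script above, run without the accelerate
launcher, might look like the following sketch. The model path is a placeholder,
every flag shown is defined in the script's argument parser, and --use-accelerate
is omitted so the single-process debug code path is exercised:

    python rouge_accelerate.py \
        --model-path /path/to/finetuned/model \
        --batch 2 \
        --samples 16 \
        --generation top_k \
        --top-k 2 \
        --max-new-tokens 100 \
        --first-k-preds 3 \
        --output-dir rouge

With --samples 16, the script prints the ROUGE results and writes them to
rouge_16samples.json inside the rouge/ directory under the model path, next to
predictions.json.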