47 changes: 24 additions & 23 deletions integrations/transformers/run_distill_qa.py
@@ -19,8 +19,8 @@
# limitations under the License.

"""
Example script for integrating SparseML with the transformers library to perform model distillation.
This script is adapted from Hugging Face's implementation for Question Answering on the SQuAD dataset.
Hugging Face's original implementation is regularly updated and can be found at https://github.com/huggingface/transformers/blob/master/examples/question-answering/run_qa.py
This script will:
- Load transformer-based models
@@ -54,12 +54,12 @@
[--onnx_export_path] \
[--layers_to_keep] \

Train, prune, and evaluate a transformer-based question answering model on SQuAD.
-h, --help show this help message and exit
--teacher_model_name_or_path The name or path of the model which will be used for distillation.
Note, this model needs to be trained for the QA task already.
--student_model_name_or_path The name or path of the model which will be trained using distillation.
--temperature Hyperparameter which controls how much the teacher and student logits are softened during distillation.
--distill_hardness Hyperparameter which controls how much of the loss comes from the teacher vs. the training labels (see the distillation-loss sketch after this argument list).
--model_name_or_path The path to the transformers model you wish to train
or the name of the pretrained language model you wish
@@ -72,21 +72,21 @@
or not. Default is false.
--do_eval Boolean denoting if the model should be evaluated
or not. Default is false.
--per_device_train_batch_size Size of each training batch based on samples per GPU.
12 will fit in an 11GB GPU, 16 in a 16GB one.
--per_device_eval_batch_size Size of each evaluation batch based on samples per GPU.
12 will fit in an 11GB GPU, 16 in a 16GB one.
--learning_rate Initial learning rate float value. ex: 3e-5.
--max_seq_length Int for the max sequence length to be parsed as a context
window. ex: 384 tokens.
--output_dir Path to which model checkpoints and outputs should be saved.
--overwrite_output_dir Boolean to define if the output directory should be overwritten. Default is false.
--cache_dir Directory in which cached transformer files (datasets, models,
tokenizers) are saved for fast loading.
--preprocessing_num_workers The number of CPU workers used to preprocess datasets.
--seed Int which determines the random seed for training/shuffling.
--nm_prune_config Path to the Neural Magic prune configuration file. Examples can
be found in prune_config_files but are customized for bert-base-uncased.
--do_onnx_export Boolean denoting if the model should be exported to ONNX.
--onnx_export_path Path to which the ONNX model will be exported. ex: onnx-export
--layers_to_keep Number of layers to keep from the original model. Layers are dropped before training.
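
Note (illustrative, not part of this diff): --temperature and --distill_hardness combine in the standard knowledge-distillation loss. A minimal sketch, assuming the trainer mixes a temperature-scaled KL term against the teacher with the ordinary label loss; the function name and signature are hypothetical, not this script's API:

import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, label_loss, temperature, distill_hardness):
    # Soften both distributions with the temperature before comparing them;
    # the temperature**2 factor keeps gradient scale comparable across temperatures.
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * (temperature ** 2)
    # distill_hardness sets the share of the loss coming from the teacher
    # versus the ground-truth training labels.
    return distill_hardness * soft_loss + (1.0 - distill_hardness) * label_loss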
@@ -611,7 +611,7 @@ def prepare_validation_features(examples):
]
return tokenized_examples

transformers.utils.logging.set_verbosity_info()
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
@@ -639,7 +639,7 @@ def prepare_validation_features(examples):
)

logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)

@@ -690,10 +690,10 @@ def prepare_validation_features(examples):

student_model_parameters = filter(lambda p: p.requires_grad, student_model.parameters())
params = sum([np.prod(p.size()) for p in student_model_parameters])
logger.info("Student Model has %s parameters", params)
teacher_model_parameters = filter(lambda p: p.requires_grad, teacher_model.parameters())
params = sum([np.prod(p.size()) for p in teacher_model_parameters])
logger.info("Teacher Model has %s parameters", params)
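
Note (illustrative, not part of this diff): the counts above multiply out each tensor's shape with np.prod. An equivalent pure-PyTorch sketch:

def count_trainable_params(model):
    # p.numel() returns the element count directly, no NumPy round-trip needed.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)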
# Tokenizer check: this script requires a fast tokenizer.
if not isinstance(tokenizer, PreTrainedTokenizerFast):
raise ValueError(
@@ -710,7 +710,7 @@ def prepare_validation_features(examples):
context_column_name = "context" if "context" in column_names else column_names[1]
answer_column_name = "answers" if "answers" in column_names else column_names[2]

pad_on_right = tokenizer.padding_side == "right"
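
Note (illustrative, not part of this diff): pad_on_right decides which segment the tokenizer treats as second, which matters because only the context may be truncated. A minimal sketch of the usual Hugging Face QA pattern, using variables defined in this script:

tokenized = tokenizer(
    examples[question_column_name if pad_on_right else context_column_name],
    examples[context_column_name if pad_on_right else question_column_name],
    truncation="only_second" if pad_on_right else "only_first",  # never cut the question away
    max_length=data_args.max_seq_length,
    padding="max_length",
)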

data_collator = (
default_data_collator
@@ -744,15 +744,16 @@ def prepare_validation_features(examples):
)
####################################################################################
# Start SparseML Integration
####################################################################################
- optim = load_optimizer(student_model, TrainingArguments)
- steps_per_epoch = math.ceil(len(datasets["train"]) / (training_args.per_device_train_batch_size*training_args._n_gpu))
- manager = ScheduledModifierManager.from_yaml(data_args.nm_prune_config)
- training_args.num_train_epochs = float(manager.modifiers[0].end_epoch)
- optim = ScheduledOptimizer(optim, student_model, manager, steps_per_epoch=steps_per_epoch, loggers=None)
+ if training_args.do_train:
+     optim = load_optimizer(student_model, TrainingArguments)
+     steps_per_epoch = math.ceil(len(train_dataset) / (training_args.per_device_train_batch_size * training_args._n_gpu))
+     manager = ScheduledModifierManager.from_yaml(data_args.nm_prune_config)
+     training_args.num_train_epochs = float(manager.modifiers[0].end_epoch)
+     optim = ScheduledOptimizer(optim, student_model, manager, steps_per_epoch=steps_per_epoch, loggers=None)
####################################################################################
# End SparseML Integration
####################################################################################
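
Note (illustrative, not part of this diff): ScheduledOptimizer wraps the base optimizer so the recipe's pruning modifiers advance once per optimizer step. A minimal sketch of the lifecycle, assuming the SparseML API used above; the training loop is illustrative:

for epoch in range(int(training_args.num_train_epochs)):
    for batch in train_dataloader:  # hypothetical DataLoader over train_dataset
        loss = student_model(**batch).loss
        loss.backward()
        optim.step()       # applies gradients and steps the pruning schedule
        optim.zero_grad()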
# Initialize our Trainer
trainer = DistillQuestionAnsweringTrainer(
model=student_model,
@@ -764,7 +765,7 @@ def prepare_validation_features(examples):
data_collator=data_collator,
post_process_function=post_processing_function,
compute_metrics=compute_metrics,
- optimizers=(optim, None),
+ optimizers=(optim, None) if training_args.do_train else (None, None),
teacher=teacher_model,
distill_hardness = model_args.distill_hardness,
temperature = model_args.temperature,
45 changes: 23 additions & 22 deletions integrations/transformers/run_qa.py
@@ -19,8 +19,8 @@
# limitations under the License.

"""
Example script for integrating SparseML with the transformers library.
This script is adapted from Hugging Face's implementation for Question Answering on the SQuAD dataset.
Hugging Face's original implementation is regularly updated and can be found at https://github.com/huggingface/transformers/blob/master/examples/question-answering/run_qa.py
This script will:
- Load transformer-based models
@@ -50,7 +50,7 @@
[--do_onnx_export]
[--onnx_export_path]

Train, prune, and evaluate a transformer-based question answering model on SQuAD.
-h, --help show this help message and exit
--model_name_or_path MODEL The path to the transformers model you wish to train
or the name of the pretrained language model you wish
@@ -63,21 +63,21 @@
or not. Default is false.
--do_eval Boolean denoting if the model should be evaluated
or not. Default is false.
--per_device_train_batch_size Size of each training batch based on samples per GPU.
12 will fit in an 11GB GPU, 16 in a 16GB one.
--per_device_eval_batch_size Size of each evaluation batch based on samples per GPU.
12 will fit in an 11GB GPU, 16 in a 16GB one.
--learning_rate Initial learning rate float value. ex: 3e-5.
--max_seq_length Int for the max sequence length to be parsed as a context
window. ex: 384 tokens (see the windowing sketch after this argument list).
--output_dir Path to which model checkpoints and outputs should be saved.
--overwrite_output_dir Boolean to define if the output directory should be overwritten. Default is false.
--cache_dir Directory in which cached transformer files (datasets, models,
tokenizers) are saved for fast loading.
--preprocessing_num_workers The number of CPU workers used to preprocess datasets.
--seed Int which determines the random seed for training/shuffling.
--nm_prune_config Path to the Neural Magic prune configuration file. Examples can
be found in prune_config_files but are customized for bert-base-uncased
(a hypothetical recipe sketch follows the usage example below).
--do_onnx_export Boolean denoting if the model should be exported to ONNX.
--onnx_export_path Path to which the ONNX model will be exported. ex: onnx-export
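
Note (illustrative, not part of this diff): --max_seq_length windows long contexts into overlapping features instead of truncating answers away. A minimal sketch of the mechanism, assuming standard fast-tokenizer behavior; the strings and model name are illustrative:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
features = tokenizer(
    "What does pruning remove?",                       # question
    "Pruning removes low-magnitude weights. " * 200,   # context longer than one window
    max_length=384,
    stride=128,                      # overlap between adjacent windows
    truncation="only_second",        # truncate the context only, never the question
    return_overflowing_tokens=True,  # emit one feature per window
)
print(len(features["input_ids"]))    # number of overlapping windows produced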

@@ -101,7 +101,7 @@
--seed 42 \
--nm_prune_config prune_config_files/95sparsity1epoch.yaml \
--do_onnx_export \
--onnx_export_path 95sparsity1epoch/
"""
import collections
import json
@@ -590,7 +590,7 @@ def prepare_validation_features(examples):

return tokenized_examples

transformers.utils.logging.set_verbosity_info()
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
@@ -618,7 +618,7 @@ def prepare_validation_features(examples):
)

logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)

@@ -663,7 +663,7 @@ def prepare_validation_features(examples):

model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
logger.info("Model has %s parameters", params)
# Tokenizer check: this script requires a fast tokenizer.
if not isinstance(tokenizer, PreTrainedTokenizerFast):
raise ValueError(
@@ -679,7 +679,7 @@ def prepare_validation_features(examples):
question_column_name = "question" if "question" in column_names else column_names[0]
context_column_name = "context" if "context" in column_names else column_names[1]
answer_column_name = "answers" if "answers" in column_names else column_names[2]
pad_on_right = tokenizer.padding_side == "right"

if training_args.do_train:
train_dataset = datasets["train"].map(
@@ -714,12 +714,13 @@ def prepare_validation_features(examples):

####################################################################################
# Start SparseML Integration
####################################################################################
- optim = load_optimizer(model, TrainingArguments)
- steps_per_epoch = math.ceil(len(datasets["train"]) / (training_args.per_device_train_batch_size*training_args._n_gpu))
- manager = ScheduledModifierManager.from_yaml(data_args.nm_prune_config)
- training_args.num_train_epochs = float(manager.max_epochs)
- optim = ScheduledOptimizer(optim, model, manager, steps_per_epoch=steps_per_epoch, loggers=None)
+ if training_args.do_train:
+     optim = load_optimizer(model, TrainingArguments)
+     steps_per_epoch = math.ceil(len(train_dataset) / (training_args.per_device_train_batch_size * training_args._n_gpu))
+     manager = ScheduledModifierManager.from_yaml(data_args.nm_prune_config)
+     training_args.num_train_epochs = float(manager.max_epochs)
+     optim = ScheduledOptimizer(optim, model, manager, steps_per_epoch=steps_per_epoch, loggers=None)
####################################################################################
# End SparseML Integration
####################################################################################
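
Note (illustrative, not part of this diff): steps_per_epoch is the dataset size divided by the effective batch size across GPUs, rounded up. A worked example with the documented batch size, assuming 2 GPUs and SQuAD v1.1's train split (windowing can increase the example count):

import math

num_examples = 87_599               # SQuAD v1.1 train examples, before windowing
per_device_train_batch_size = 12    # the documented 11GB-GPU setting
n_gpu = 2                           # assumption for this example

effective_batch = per_device_train_batch_size * n_gpu  # 24 samples per optimizer step
steps_per_epoch = math.ceil(num_examples / effective_batch)
print(steps_per_epoch)              # 3650 steps; the pruning schedule advances once per step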
@@ -734,7 +735,7 @@ def prepare_validation_features(examples):
data_collator=data_collator,
post_process_function=post_processing_function,
compute_metrics=compute_metrics,
- optimizers=(optim, None),
+ optimizers=(optim, None) if training_args.do_train else (None, None),
)

# Training
@@ -765,7 +766,7 @@ def prepare_validation_features(examples):
####################################################################################
if data_args.do_onnx_export:
logger.info("*** Export to ONNX ***")
print("Exporting onnx model")
print("Exporting onnx model")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
exporter = ModuleExporter(
model, output_dir='onnx-export'
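
Note (illustrative, not part of this diff): a sketch of how the truncated ModuleExporter call above typically completes, assuming SparseML's export API; the dummy-input shapes are illustrative:

import torch
from sparseml.pytorch.utils import ModuleExporter

exporter = ModuleExporter(model, output_dir="onnx-export")
sample_batch = {                    # dummy batch matching the model's forward signature
    "input_ids": torch.zeros((1, 384), dtype=torch.long),
    "attention_mask": torch.zeros((1, 384), dtype=torch.long),
}
exporter.export_onnx(sample_batch=sample_batch)  # writes model.onnx under output_dir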