diff --git a/examples/pytorch/language-modeling/run_mlm.py b/examples/pytorch/language-modeling/run_mlm.py
index f75935c03521..9cc3cbdfc6b4 100755
--- a/examples/pytorch/language-modeling/run_mlm.py
+++ b/examples/pytorch/language-modeling/run_mlm.py
@@ -117,6 +117,12 @@ class DataTrainingArguments:
             "for more information"
         },
     )
+    recipe_args: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Recipe arguments to be overwritten"
+        },
+    )
     onnx_export_path: Optional[str] = field(
         default=None, metadata={"help": "The filename and path which will be where onnx model is outputed"}
     )
@@ -508,6 +514,7 @@ def group_texts(examples):
         tokenizer=tokenizer,
         data_collator=data_collator,
         compute_metrics=compute_metrics,
+        recipe_args=data_args.recipe_args
     )
 
     # Apply recipes to the model. This is necessary given that
diff --git a/examples/pytorch/question-answering/run_qa.py b/examples/pytorch/question-answering/run_qa.py
index 15f1ba036f4c..517fe8ccdfe7 100755
--- a/examples/pytorch/question-answering/run_qa.py
+++ b/examples/pytorch/question-answering/run_qa.py
@@ -101,6 +101,12 @@ class DataTrainingArguments:
             "for more information"
         },
     )
+    recipe_args: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Recipe arguments to be overwritten"
+        },
+    )
     onnx_export_path: Optional[str] = field(
         default=None, metadata={"help": "The filename and path which will be where onnx model is outputed"}
     )
@@ -595,6 +601,7 @@ def compute_metrics(p: EvalPrediction):
         data_collator=data_collator,
         post_process_function=post_processing_function,
         compute_metrics=compute_metrics,
+        recipe_args=data_args.recipe_args
     )
 
     # Apply recipes to the model. This is necessary given that
diff --git a/examples/pytorch/text-classification/run_glue.py b/examples/pytorch/text-classification/run_glue.py
index c6b2e56f6237..16f2ed356207 100755
--- a/examples/pytorch/text-classification/run_glue.py
+++ b/examples/pytorch/text-classification/run_glue.py
@@ -80,6 +80,12 @@ class DataTrainingArguments:
             "for more information"
         },
     )
+    recipe_args: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Recipe arguments to be overwritten"
+        },
+    )
     onnx_export_path: Optional[str] = field(
         default=None, metadata={"help": "The filename and path which will be where onnx model is outputed"}
     )
@@ -504,8 +510,8 @@ def compute_metrics(p: EvalPrediction):
         tokenizer=tokenizer,
         data_collator=data_collator,
         compute_metrics=compute_metrics,
+        recipe_args=data_args.recipe_args
     )
-
     # Apply recipes to the model. This is necessary given that
     # sparsification methods such as QAT modified the model graph with their own learnable
     # parameters. They are also restored/loaded to the model.
diff --git a/examples/pytorch/token-classification/run_ner.py b/examples/pytorch/token-classification/run_ner.py
index eac5267efe94..71ba57809af4 100755
--- a/examples/pytorch/token-classification/run_ner.py
+++ b/examples/pytorch/token-classification/run_ner.py
@@ -99,6 +99,12 @@ class DataTrainingArguments:
             "for more information"
         },
     )
+    recipe_args: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Recipe arguments to be overwritten"
+        },
+    )
     onnx_export_path: Optional[str] = field(
         default=None, metadata={"help": "The filename and path which will be where onnx model is outputed"}
     )
@@ -475,6 +481,7 @@ def compute_metrics(p):
         tokenizer=tokenizer,
         data_collator=data_collator,
         compute_metrics=compute_metrics,
+        recipe_args=data_args.recipe_args
     )
 
     # Apply recipes to the model. This is necessary given that
diff --git a/src/transformers/sparse.py b/src/transformers/sparse.py
index 386bc93de764..415ae39e3335 100644
--- a/src/transformers/sparse.py
+++ b/src/transformers/sparse.py
@@ -1,8 +1,9 @@
 import collections
 import inspect
+import json
 import math
 import os
-from typing import Optional
+from typing import Dict, Optional
 
 import numpy
 import torch
@@ -27,7 +28,7 @@ class SparseMLTrainer(Trainer):
     :param args, kwargs: arguments passed into parent class
     """
 
-    def __init__(self, model_name_or_path, recipes, teacher=None, *args, **kwargs):
+    def __init__(self, model_name_or_path, recipes, teacher=None, recipe_args=None, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.model_name_or_path = str(model_name_or_path)
         self.recipes = [recipe for recipe in recipes if recipe]
@@ -36,10 +37,17 @@ def __init__(self, model_name_or_path, recipes, teacher=None, *args, **kwargs):
             self.teacher.eval()
         self.criterion = torch.nn.CrossEntropyLoss()
 
+        if recipe_args is not None:
+            recipe_args = json.loads(recipe_args)
+            if not isinstance(recipe_args, Dict):
+                raise ValueError("Cannot convert recipe arguments into dictionary")
+        else:
+            recipe_args = {}
+
         manager = None
         modifiers = []
         for recipe in self.recipes:
-            manager = ScheduledModifierManager.from_yaml(recipe, modifiers)
+            manager = ScheduledModifierManager.from_yaml(recipe, modifiers, **recipe_args)
             modifiers = manager.modifiers
         self.manager = manager
 