7 changes: 7 additions & 0 deletions examples/pytorch/language-modeling/run_mlm.py
@@ -117,6 +117,12 @@ class DataTrainingArguments:
"for more information"
},
)
recipe_args: Optional[str] = field(
default=None,
metadata={
"help": "Recipe arguments to be overwritten"
},
)
onnx_export_path: Optional[str] = field(
default=None, metadata={"help": "The filename and path which will be where onnx model is outputed"}
)
@@ -508,6 +514,7 @@ def group_texts(examples):
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        recipe_args=data_args.recipe_args
    )

    # Apply recipes to the model. This is necessary given that
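For orientation, the new dataclass field surfaces on the command line through HfArgumentParser, which maps dataclass fields to flags, so the value arrives in the script as a raw JSON string (for example --recipe_args '{"num_epochs": 5}'). Below is a minimal sketch of that step only; it uses a stripped-down stand-in for the class above and a hypothetical recipe variable name, and is not part of this diff.

# Sketch: how --recipe_args reaches the script as a raw JSON string.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class DataTrainingArguments:  # stripped-down stand-in for the class shown above
    recipe_args: Optional[str] = field(
        default=None,
        metadata={"help": "Recipe arguments to be overwritten, given as a JSON string"},
    )


parser = HfArgumentParser(DataTrainingArguments)
# Equivalent to: python run_mlm.py ... --recipe_args '{"num_epochs": 5}'
(data_args,) = parser.parse_args_into_dataclasses(args=["--recipe_args", '{"num_epochs": 5}'])
print(data_args.recipe_args)  # '{"num_epochs": 5}' -- still a string; SparseMLTrainer parses it later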
7 changes: 7 additions & 0 deletions examples/pytorch/question-answering/run_qa.py
@@ -101,6 +101,12 @@ class DataTrainingArguments:
"for more information"
},
)
recipe_args: Optional[str] = field(
default=None,
metadata={
"help": "Recipe arguments to be overwritten"
},
)
onnx_export_path: Optional[str] = field(
default=None, metadata={"help": "The filename and path which will be where onnx model is outputed"}
)
@@ -595,6 +601,7 @@ def compute_metrics(p: EvalPrediction):
        data_collator=data_collator,
        post_process_function=post_processing_function,
        compute_metrics=compute_metrics,
        recipe_args=data_args.recipe_args
    )

    # Apply recipes to the model. This is necessary given that
8 changes: 7 additions & 1 deletion examples/pytorch/text-classification/run_glue.py
@@ -80,6 +80,12 @@ class DataTrainingArguments:
"for more information"
},
)
recipe_args: Optional[str] = field(
default=None,
metadata={
"help": "Recipe arguments to be overwritten"
},
)
onnx_export_path: Optional[str] = field(
default=None, metadata={"help": "The filename and path which will be where onnx model is outputed"}
)
@@ -504,8 +510,8 @@ def compute_metrics(p: EvalPrediction):
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        recipe_args=data_args.recipe_args
    )

    # Apply recipes to the model. This is necessary given that
    # sparsification methods such as QAT modified the model graph with their own learnable
    # parameters. They are also restored/loaded to the model.
7 changes: 7 additions & 0 deletions examples/pytorch/token-classification/run_ner.py
@@ -99,6 +99,12 @@ class DataTrainingArguments:
"for more information"
},
)
recipe_args: Optional[str] = field(
default=None,
metadata={
"help": "Recipe arguments to be overwritten"
},
)
onnx_export_path: Optional[str] = field(
default=None, metadata={"help": "The filename and path which will be where onnx model is outputed"}
)
@@ -475,6 +481,7 @@ def compute_metrics(p):
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        recipe_args=data_args.recipe_args
    )

    # Apply recipes to the model. This is necessary given that
14 changes: 11 additions & 3 deletions src/transformers/sparse.py
@@ -1,8 +1,9 @@
import collections
import inspect
import json
import math
import os
from typing import Optional
from typing import Dict, Optional

import numpy
import torch
@@ -27,7 +28,7 @@ class SparseMLTrainer(Trainer):
    :param args, kwargs: arguments passed into parent class
    """

    def __init__(self, model_name_or_path, recipes, teacher=None, *args, **kwargs):
    def __init__(self, model_name_or_path, recipes, teacher=None, recipe_args=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model_name_or_path = str(model_name_or_path)
        self.recipes = [recipe for recipe in recipes if recipe]
@@ -36,10 +37,17 @@ def __init__(self, model_name_or_path, recipes, teacher=None, *args, **kwargs):
            self.teacher.eval()
        self.criterion = torch.nn.CrossEntropyLoss()

        if recipe_args is not None:
            recipe_args = json.loads(recipe_args)
            if not isinstance(recipe_args, Dict):
                raise ValueError("Cannot convert recipe arguments into dictionary")
        else:
            recipe_args = {}

        manager = None
        modifiers = []
        for recipe in self.recipes:
            manager = ScheduledModifierManager.from_yaml(recipe, modifiers)
            manager = ScheduledModifierManager.from_yaml(recipe, modifiers, **recipe_args)
            modifiers = manager.modifiers
        self.manager = manager

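To make the new constructor path concrete, here is a stdlib-only sketch of what the added lines do with that string; the override keys are hypothetical and depend on the variables a given recipe defines, and sparseml is deliberately not imported here.

import json
from typing import Any, Dict, Optional


def parse_recipe_args(recipe_args: Optional[str]) -> Dict[str, Any]:
    # Mirrors the constructor logic above: turn the JSON string into a dict of
    # overrides, or fall back to an empty dict when --recipe_args is not given.
    if recipe_args is None:
        return {}
    parsed = json.loads(recipe_args)
    if not isinstance(parsed, dict):
        raise ValueError("Cannot convert recipe arguments into dictionary")
    return parsed


overrides = parse_recipe_args('{"num_epochs": 5, "init_lr": 0.0005}')  # hypothetical keys
print(overrides)  # {'num_epochs': 5, 'init_lr': 0.0005}

# In SparseMLTrainer the resulting dict is splatted into the recipe load, i.e.
#     ScheduledModifierManager.from_yaml(recipe, modifiers, **overrides)

A JSON value that is not an object (for example a bare list) raises the same ValueError as in the trainer above.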