Skip to content

Commit

Permalink
[Fix][Autoscheduler] Cost model enhancement & bug fix for graph debug …
Browse files Browse the repository at this point in the history
…runtime (apache#7197)

* Enhancement for autoscheduler cost model

* Bug fix for graph_runtime_debug

* Update

* Lint fix

* Update

* Update

* Add file exist check for cost model load

* Update

* Update

* Lint fix

* Update

* Bug fix
  • Loading branch information
jcf94 authored and electriclilies committed Feb 18, 2021
1 parent 6f0b7c1 commit 52c9767
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 9 deletions.
25 changes: 24 additions & 1 deletion python/tvm/auto_scheduler/cost_model/xgb_model.py
Expand Up @@ -88,7 +88,14 @@ class XGBModel(PythonBasedModel):
their predictions.
"""

def __init__(self, verbose_eval=25, num_warmup_sample=100, seed=None):
def __init__(
self,
verbose_eval=25,
num_warmup_sample=100,
seed=None,
model_file=None,
adapative_training=False,
):
global xgb
try:
if xgb is None:
Expand Down Expand Up @@ -116,12 +123,15 @@ def __init__(self, verbose_eval=25, num_warmup_sample=100, seed=None):
self.plan_size = 32
self.num_warmup_sample = num_warmup_sample
self.verbose_eval = verbose_eval
self.model_file = model_file
self.adapative_training = adapative_training

super().__init__()

# cache measurement input/result pairs and extracted features
self.inputs = []
self.results = []
self.last_train_length = 0
self.inputs_feature_cache = []

def update(self, inputs, results):
Expand All @@ -141,6 +151,15 @@ def update(self, inputs, results):
self.inputs.extend(inputs)
self.results.extend(results)

if (
self.adapative_training
and len(self.inputs) - self.last_train_length < self.last_train_length / 5
):
# Set a training threshold related to `last_train_length` to reduce the training
# overhead when there're too many logs
return
self.last_train_length = len(self.inputs)

# extract feature
n_cached = len(self.inputs_feature_cache)
features, normalized_throughputs, task_ids = get_per_store_features_from_measure_pairs(
Expand Down Expand Up @@ -176,6 +195,10 @@ def update(self, inputs, results):
],
)

# Update the model file if it has been set
if self.model_file:
self.save(self.model_file)

def predict(self, task, states):
"""Predict the scores of states
Parameters
Expand Down
13 changes: 11 additions & 2 deletions python/tvm/auto_scheduler/task_scheduler.py
Expand Up @@ -47,6 +47,7 @@ def make_search_policies(
verbose,
load_model_file=None,
load_log_file=None,
adapative_training=False,
):
"""Make a list of search policies for a list of search tasks.
It creates one policy per task.
Expand All @@ -70,6 +71,9 @@ def make_search_policies(
load_log_file: Optional[str]
Load measurement records from this file. If it is not None, the status of the
task scheduler, search policies and cost models will be restored according to this file.
adapative_training: bool = False
Option used for XGBModel, which will reduce the model training frequency when there're too
many logs.
Returns
-------
Expand All @@ -82,11 +86,16 @@ def make_search_policies(
if isinstance(search_policy, str):
policy_type, model_type = search_policy.split(".")
if model_type == "xgb":
cost_model = XGBModel(num_warmup_sample=len(tasks) * num_measures_per_round)
if load_model_file:
cost_model = XGBModel(
num_warmup_sample=len(tasks) * num_measures_per_round,
model_file=load_model_file,
adapative_training=adapative_training,
)
if load_model_file and os.path.isfile(load_model_file):
logger.info("TaskScheduler: Load pretrained model...")
cost_model.load(load_model_file)
elif load_log_file:
logger.info("TaskScheduler: Reload measured states and train the model...")
cost_model.update_from_file(load_log_file)
elif model_type == "random":
cost_model = RandomModel()
Expand Down
18 changes: 12 additions & 6 deletions src/auto_scheduler/feature.cc
Expand Up @@ -1462,12 +1462,18 @@ void GetPerStoreFeaturesFromMeasurePairs(const Array<MeasureInput>& inputs,
if (find_res == task_cache.end()) {
if (inputs[i]->task->compute_dag.defined()) { // the measure input is complete
task = inputs[i]->task;
} else { // the measure input is incomplete
// rebuild task for incomplete measure pairs read from file
Array<te::Tensor> tensors = (*workload_key_to_tensors)(workload_key);
task = SearchTask(ComputeDAG(tensors), workload_key, inputs[i]->task->target,
inputs[i]->task->target_host, inputs[i]->task->hardware_params,
inputs[i]->task->layout_rewrite_option);
} else {
// The measure input is incomplete, rebuild task for incomplete measure pairs read from file
try {
Array<te::Tensor> tensors = (*workload_key_to_tensors)(workload_key);
task = SearchTask(ComputeDAG(tensors), workload_key, inputs[i]->task->target,
inputs[i]->task->target_host, inputs[i]->task->hardware_params,
inputs[i]->task->layout_rewrite_option);
} catch (std::exception& e) {
// Cannot build ComputeDAG from workload key, the task may have not been registered in
// this search round
continue;
}
}
task_id = task_cache.size();

Expand Down

0 comments on commit 52c9767

Please sign in to comment.