diff --git a/python/tvm/auto_scheduler/cost_model/xgb_model.py b/python/tvm/auto_scheduler/cost_model/xgb_model.py
index eb14dff0815c..f42648288bfa 100644
--- a/python/tvm/auto_scheduler/cost_model/xgb_model.py
+++ b/python/tvm/auto_scheduler/cost_model/xgb_model.py
@@ -88,7 +88,14 @@ class XGBModel(PythonBasedModel):
     their predictions.
     """

-    def __init__(self, verbose_eval=25, num_warmup_sample=100, seed=None):
+    def __init__(
+        self,
+        verbose_eval=25,
+        num_warmup_sample=100,
+        seed=None,
+        model_file=None,
+        adapative_training=False,
+    ):
         global xgb
         try:
             if xgb is None:
@@ -116,12 +123,15 @@ def __init__(self, verbose_eval=25, num_warmup_sample=100, seed=None):
         self.plan_size = 32
         self.num_warmup_sample = num_warmup_sample
         self.verbose_eval = verbose_eval
+        self.model_file = model_file
+        self.adapative_training = adapative_training

         super().__init__()

         # cache measurement input/result pairs and extracted features
         self.inputs = []
         self.results = []
+        self.last_train_length = 0
         self.inputs_feature_cache = []

     def update(self, inputs, results):
@@ -141,6 +151,15 @@ def update(self, inputs, results):
         self.inputs.extend(inputs)
         self.results.extend(results)

+        if (
+            self.adapative_training
+            and len(self.inputs) - self.last_train_length < self.last_train_length / 5
+        ):
+            # Set a training threshold relative to `last_train_length` to reduce the
+            # training overhead when there are too many logs
+            return
+        self.last_train_length = len(self.inputs)
+
         # extract feature
         n_cached = len(self.inputs_feature_cache)
         features, normalized_throughputs, task_ids = get_per_store_features_from_measure_pairs(
@@ -176,6 +195,10 @@ def update(self, inputs, results):
             ],
         )

+        # Update the model file if it has been set
+        if self.model_file:
+            self.save(self.model_file)
+
     def predict(self, task, states):
         """Predict the scores of states
         Parameters
diff --git a/python/tvm/auto_scheduler/task_scheduler.py b/python/tvm/auto_scheduler/task_scheduler.py
index ab83ff40c461..975306f7be54 100644
--- a/python/tvm/auto_scheduler/task_scheduler.py
+++ b/python/tvm/auto_scheduler/task_scheduler.py
@@ -47,6 +47,7 @@ def make_search_policies(
     verbose,
     load_model_file=None,
     load_log_file=None,
+    adapative_training=False,
 ):
     """Make a list of search policies for a list of search tasks.
     It creates one policy per task.
@@ -70,6 +71,9 @@
     load_log_file: Optional[str]
         Load measurement records from this file. If it is not None, the status of the
         task scheduler, search policies and cost models will be restored according to this file.
+    adapative_training: bool = False
+        Option used by XGBModel to reduce the model training frequency when there are
+        too many logs.

     Returns
     -------
...
     if isinstance(search_policy, str):
         policy_type, model_type = search_policy.split(".")
         if model_type == "xgb":
-            cost_model = XGBModel(num_warmup_sample=len(tasks) * num_measures_per_round)
-            if load_model_file:
+            cost_model = XGBModel(
+                num_warmup_sample=len(tasks) * num_measures_per_round,
+                model_file=load_model_file,
+                adapative_training=adapative_training,
+            )
+            if load_model_file and os.path.isfile(load_model_file):
                 logger.info("TaskScheduler: Load pretrained model...")
                 cost_model.load(load_model_file)
             elif load_log_file:
+                logger.info("TaskScheduler: Reload measured states and train the model...")
                 cost_model.update_from_file(load_log_file)
         elif model_type == "random":
             cost_model = RandomModel()
diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc
index 47b9fb60aab4..a5d4958af769 100755
--- a/src/auto_scheduler/feature.cc
+++ b/src/auto_scheduler/feature.cc
@@ -1462,12 +1462,18 @@ void GetPerStoreFeaturesFromMeasurePairs(const Array<MeasureInput>& inputs,
     if (find_res == task_cache.end()) {
       if (inputs[i]->task->compute_dag.defined()) {  // the measure input is complete
         task = inputs[i]->task;
-      } else {  // the measure input is incomplete
-        // rebuild task for incomplete measure pairs read from file
-        Array<te::Tensor> tensors = (*workload_key_to_tensors)(workload_key);
-        task = SearchTask(ComputeDAG(tensors), workload_key, inputs[i]->task->target,
-                          inputs[i]->task->target_host, inputs[i]->task->hardware_params,
-                          inputs[i]->task->layout_rewrite_option);
+      } else {
+        // The measure input is incomplete; rebuild task for incomplete measure pairs read from file
+        try {
+          Array<te::Tensor> tensors = (*workload_key_to_tensors)(workload_key);
+          task = SearchTask(ComputeDAG(tensors), workload_key, inputs[i]->task->target,
+                            inputs[i]->task->target_host, inputs[i]->task->hardware_params,
+                            inputs[i]->task->layout_rewrite_option);
+        } catch (std::exception& e) {
+          // Cannot build ComputeDAG from the workload key; the task may not have been
+          // registered in this search round
+          continue;
+        }
       }
       task_id = task_cache.size();
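
For anyone who wants to try the new knobs directly, below is a minimal, hypothetical usage sketch (not part of the patch). It relies only on the two keyword arguments this diff adds to `XGBModel`; the `TaskScheduler` plumbing that forwards `adapative_training` down to `make_search_policies` is assumed to be wired up by the caller, and the model file path is illustrative.

```python
from tvm.auto_scheduler.cost_model.xgb_model import XGBModel

# Construct the cost model with the new options added in this diff:
# - model_file: the trained model is saved here after every completed update()
# - adapative_training: update() returns early (skips retraining) until the
#   measurement log has grown by at least last_train_length / 5, i.e. ~20%
model = XGBModel(
    num_warmup_sample=100,
    model_file="xgb_cost_model.bin",  # hypothetical path
    adapative_training=True,
)
```

The 20% growth threshold means retraining happens at geometrically spaced log sizes, so the cumulative training cost grows roughly linearly with the total number of measurement records, instead of quadratically when the model is retrained on the full log after every measured batch.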