From 64609c43b07464ec614a4a7a54a923aacfbc78d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Merc=C3=A8=20Mart=C3=ADn=20Prats?= Date: Fri, 7 Jun 2019 02:06:48 +0200 Subject: [PATCH] Fixing local predictions when no input_fields are defined --- bigml/ensemble.py | 4 +++- bigml/modelfields.py | 10 ++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/bigml/ensemble.py b/bigml/ensemble.py index 542564ff..ed523f56 100644 --- a/bigml/ensemble.py +++ b/bigml/ensemble.py @@ -131,6 +131,7 @@ def __init__(self, ensemble, self.importance = {} query_string = ONLY_MODEL no_check_fields = False + self.input_fields = [] if isinstance(ensemble, list): if all([isinstance(model, Model) for model in ensemble]): models = ensemble @@ -144,6 +145,7 @@ def __init__(self, ensemble, raise ValueError('Failed to verify the list of models.' ' Check your model id values: %s' % str(exc)) + else: ensemble = self.get_ensemble_resource(ensemble) self.resource_id = get_ensemble_id(ensemble) @@ -166,7 +168,7 @@ def __init__(self, ensemble, self.objective_id = ensemble['object'].get("objective_field") query_string = EXCLUDE_FIELDS no_check_fields = True - self.input_fields = ensemble['object'].get('input_fields') + self.input_fields = ensemble['object'].get('input_fields') number_of_models = len(models) if max_models is None: diff --git a/bigml/modelfields.py b/bigml/modelfields.py index 135ea953..a2ff8671 100644 --- a/bigml/modelfields.py +++ b/bigml/modelfields.py @@ -154,8 +154,14 @@ def __init__(self, fields, objective_id=None, data_locale=None, self.inverted_fields = invert_dictionary(fields) self.fields = {} self.fields.update(fields) - if not self.input_fields: - self.input_fields = self.fields.keys() + if not (hasattr(self, "input_fields") and self.input_fields): + self.input_fields = [field_id for field_id, field in \ + sorted( \ + [(field_id, field) for field_id, + field in self.fields.items()], + key=lambda(x): x[1].get("column_number")) + if not self.objective_id or + field_id != self.objective_id] self.model_fields = {} self.model_fields.update( dict([(field_id, field) for field_id, field in \