diff --git a/openml/entities/dataset.py b/openml/entities/dataset.py index c5c9f6bca..c3f14bcad 100644 --- a/openml/entities/dataset.py +++ b/openml/entities/dataset.py @@ -122,8 +122,7 @@ def decode_arff(fh): return decode_arff(fh) ############################################################################ - # pandas related stuff... - def get_dataset(self, target=None, include_row_id=False, + def get_dataset(self, target=None, target_dtype=int, include_row_id=False, include_ignore_attributes=False, return_categorical_indicator=False, return_attribute_names=False): @@ -176,7 +175,7 @@ def get_dataset(self, target=None, include_row_id=False, try: x = data[:,~targets] - y = data[:,targets].astype(np.int32) + y = data[:,targets].astype(target_dtype) if len(y.shape) == 2 and y.shape[1] == 1: y = y[:,0] @@ -191,7 +190,7 @@ def get_dataset(self, target=None, include_row_id=False, raise e if scipy.sparse.issparse(y): - y = np.asarray(y.todense()).astype(np.int32).flatten() + y = np.asarray(y.todense()).astype(target_dtype).flatten() rval.append(x) rval.append(y) diff --git a/openml/entities/task.py b/openml/entities/task.py index 3ea72f59d..5609d8419 100644 --- a/openml/entities/task.py +++ b/openml/entities/task.py @@ -44,7 +44,14 @@ def get_dataset(self): def get_X_and_Y(self): dataset = self.get_dataset() # Replace with retrieve from cache - X_and_Y = dataset.get_dataset(target=self.target_feature) + if 'Supervised Classification'.lower() in self.task_type.lower(): + target_dtype = int + elif 'Supervised Regression'.lower() in self.task_type.lower(): + target_dtype = float + else: + raise NotImplementedError(self.task_type) + X_and_Y = dataset.get_dataset(target=self.target_feature, + target_dtype=target_dtype) return X_and_Y def evaluate(self, algo):