diff --git a/openml/flows/flow.py b/openml/flows/flow.py index 0c70fc9bc..83878ee51 100644 --- a/openml/flows/flow.py +++ b/openml/flows/flow.py @@ -310,19 +310,6 @@ def _from_dict(cls, xml_dict): arguments['model'] = None flow = cls(**arguments) - # try to parse to a model because not everything that can be - # deserialized has to come from scikit-learn. If it can't be - # serialized, but comes from scikit-learn this is worth an exception - if ( - arguments['external_version'].startswith('sklearn==') - or ',sklearn==' in arguments['external_version'] - ): - from .sklearn_converter import flow_to_sklearn - model = flow_to_sklearn(flow) - else: - model = None - flow.model = model - return flow def publish(self): diff --git a/openml/flows/functions.py b/openml/flows/functions.py index a3cf31880..9fdf09dc8 100644 --- a/openml/flows/functions.py +++ b/openml/flows/functions.py @@ -8,13 +8,23 @@ import openml.utils -def get_flow(flow_id): +def get_flow(flow_id, reinstantiate=False): """Download the OpenML flow for a given flow ID. Parameters ---------- flow_id : int The OpenML flow id. + + reinstantiate: bool + Whether to reinstantiate the flow to a sklearn model. 
+ Note that this can only be done with sklearn flows; otherwise + a ValueError is raised. + + Returns + ------- + flow : OpenMLFlow + the flow """ flow_id = int(flow_id) flow_xml = openml._api_calls._perform_api_call("flow/%d" % flow_id) @@ -22,6 +32,12 @@ def get_flow(flow_id): flow_dict = xmltodict.parse(flow_xml) flow = OpenMLFlow._from_dict(flow_dict) + if reinstantiate: + if not (flow.external_version.startswith('sklearn==') or + ',sklearn==' in flow.external_version): + raise ValueError('Only sklearn flows can be reinstantiated') + flow.model = openml.flows.flow_to_sklearn(flow) + return flow diff --git a/tests/test_flows/test_flow.py b/tests/test_flows/test_flow.py index 39c03fee1..af19628c0 100644 --- a/tests/test_flows/test_flow.py +++ b/tests/test_flows/test_flow.py @@ -275,9 +275,9 @@ def test_existing_flow_exists(self): for classifier in [nb, complicated]: flow = openml.flows.sklearn_to_flow(classifier) flow, _ = self._add_sentinel_to_flow_name(flow, None) - #publish the flow + # publish the flow flow = flow.publish() - #redownload the flow + # redownload the flow flow = openml.flows.get_flow(flow.flow_id) # check if flow exists can find it @@ -329,7 +329,8 @@ def test_sklearn_to_upload_to_flow(self): # Check whether we can load the flow again # Remove the sentinel from the name again so that we can reinstantiate # the object again - new_flow = openml.flows.get_flow(flow_id=flow.flow_id) + new_flow = openml.flows.get_flow(flow_id=flow.flow_id, + reinstantiate=True) local_xml = flow._to_xml() server_xml = new_flow._to_xml() diff --git a/tests/test_runs/test_run_functions.py b/tests/test_runs/test_run_functions.py index e1898be5a..0c983d861 100644 --- a/tests/test_runs/test_run_functions.py +++ b/tests/test_runs/test_run_functions.py @@ -627,7 +627,8 @@ def test_get_run_trace(self): flow_exists = openml.flows.flow_exists(flow.name, flow.external_version) self.assertIsInstance(flow_exists, int) self.assertGreater(flow_exists, 0) - downloaded_flow = 
openml.flows.get_flow(flow_exists) + downloaded_flow = openml.flows.get_flow(flow_exists, + reinstantiate=True) setup_exists = openml.setups.setup_exists(downloaded_flow) self.assertIsInstance(setup_exists, int) self.assertGreater(setup_exists, 0)