Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 0 additions & 13 deletions openml/flows/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,19 +310,6 @@ def _from_dict(cls, xml_dict):
arguments['model'] = None
flow = cls(**arguments)

# try to parse to a model because not everything that can be
# deserialized has to come from scikit-learn. If it can't be
# serialized, but comes from scikit-learn this is worth an exception
if (
arguments['external_version'].startswith('sklearn==')
or ',sklearn==' in arguments['external_version']
):
from .sklearn_converter import flow_to_sklearn
model = flow_to_sklearn(flow)
else:
model = None
flow.model = model

return flow

def publish(self):
Expand Down
18 changes: 17 additions & 1 deletion openml/flows/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,36 @@
import openml.utils


def get_flow(flow_id):
def get_flow(flow_id, reinstantiate=False):
    """Download the OpenML flow for a given flow ID.

    Parameters
    ----------
    flow_id : int
        The OpenML flow id.

    reinstantiate : bool
        Whether to reinstantiate the flow to a sklearn model.
        Note that this can only be done with sklearn flows, i.e. when
        the flow's ``external_version`` starts with ``sklearn==`` or
        contains ``,sklearn==``.

    Returns
    -------
    flow : OpenMLFlow
        the flow

    Raises
    ------
    ValueError
        If ``reinstantiate`` is True and the flow's ``external_version``
        does not identify it as a sklearn flow.
    """
    flow_id = int(flow_id)
    flow_xml = openml._api_calls._perform_api_call("flow/%d" % flow_id)

    flow_dict = xmltodict.parse(flow_xml)
    flow = OpenMLFlow._from_dict(flow_dict)

    if reinstantiate:
        # NOTE(review): fall back to '' so that a flow without an
        # external_version is rejected with the ValueError below rather
        # than an AttributeError on None -- confirm whether _from_dict
        # can actually yield None here.
        external_version = flow.external_version or ''
        if not (external_version.startswith('sklearn==') or
                ',sklearn==' in external_version):
            raise ValueError('Only sklearn flows can be reinstantiated')
        flow.model = openml.flows.flow_to_sklearn(flow)

    return flow


Expand Down
7 changes: 4 additions & 3 deletions tests/test_flows/test_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,9 +275,9 @@ def test_existing_flow_exists(self):
for classifier in [nb, complicated]:
flow = openml.flows.sklearn_to_flow(classifier)
flow, _ = self._add_sentinel_to_flow_name(flow, None)
#publish the flow
# publish the flow
flow = flow.publish()
#redownload the flow
# redownload the flow
flow = openml.flows.get_flow(flow.flow_id)

# check if flow exists can find it
Expand Down Expand Up @@ -329,7 +329,8 @@ def test_sklearn_to_upload_to_flow(self):
# Check whether we can load the flow again
# Remove the sentinel from the name again so that we can reinstantiate
# the object again
new_flow = openml.flows.get_flow(flow_id=flow.flow_id)
new_flow = openml.flows.get_flow(flow_id=flow.flow_id,
reinstantiate=True)

local_xml = flow._to_xml()
server_xml = new_flow._to_xml()
Expand Down
3 changes: 2 additions & 1 deletion tests/test_runs/test_run_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,8 @@ def test_get_run_trace(self):
flow_exists = openml.flows.flow_exists(flow.name, flow.external_version)
self.assertIsInstance(flow_exists, int)
self.assertGreater(flow_exists, 0)
downloaded_flow = openml.flows.get_flow(flow_exists)
downloaded_flow = openml.flows.get_flow(flow_exists,
reinstantiate=True)
setup_exists = openml.setups.setup_exists(downloaded_flow)
self.assertIsInstance(setup_exists, int)
self.assertGreater(setup_exists, 0)
Expand Down