Skip to content

Commit

Permalink
Filter out attributes.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jan 29, 2021
1 parent 4d66af8 commit f05ebbf
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 9 deletions.
19 changes: 13 additions & 6 deletions python-package/xgboost/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -1398,23 +1398,30 @@ async def _get_booster_future(self) -> "distributed.Future":
"""
booster = self.get_booster()
# pylint: disable=access-member-before-definition
if not hasattr(self, "_booster_future") or self._booster_future_id != id(booster):
self._booster_future = await self.client.scatter(booster)
self._booster_future_id: int = id(booster)
return self._booster_future
if not hasattr(self, "_booster_future") or self._booster_future[1] != id(booster):
self._booster_future: Tuple[
distributed.Future, int
] = (await self.client.scatter(booster), id(booster))
return self._booster_future[0]

def __await__(self) -> Awaitable[Any]:
# Generate a coroutine wrapper to make this class awaitable.
async def _() -> Awaitable[Any]:
return self
return self.client.sync(_).__await__()

def __getstate__(self) -> Dict:
this = self.__dict__.copy()
def _filter_serialization(self, this: Dict) -> Dict:
this = super()._filter_serialization(this)
if "_client" in this.keys():
del this["_client"]
if "_booster_future" in this.keys():
del this["_booster_future"]
return this

def __getstate__(self) -> Dict:
this = self.__dict__.copy()
return self._filter_serialization(this)

@property
def client(self) -> "distributed.Client":
'''The dask client used in this model.'''
Expand Down
9 changes: 7 additions & 2 deletions python-package/xgboost/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,10 @@ def _get_type(self) -> str:
)
return self._estimator_type # pylint: disable=no-member

def _filter_serialization(self, this: Dict) -> Dict:
"""Filter out attributes that should not be saved in model or pickle."""
return this

def save_model(self, fname: str):
"""Save the model to a file.
Expand All @@ -520,8 +524,9 @@ def save_model(self, fname: str):
Output file name
"""
this = self._filter_serialization(self.__dict__)
meta = dict()
for k, v in self.__dict__.items():
for k, v in this.items():
if k == '_le':
meta['_le'] = self._le.to_json()
continue
Expand All @@ -535,7 +540,7 @@ def save_model(self, fname: str):
json.dumps({k: v})
meta[k] = v
except TypeError:
warnings.warn(str(k) + ' is not saved in Scikit-Learn meta.')
warnings.warn(str(k) + ' is not saved in Scikit-Learn meta.', UserWarning)
meta['_estimator_type'] = self._get_type()
meta_str = json.dumps(meta)
self.get_booster().set_attr(scikit_learn=meta_str)
Expand Down
4 changes: 3 additions & 1 deletion tests/python/test_with_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -1136,7 +1136,7 @@ def test_sklearn_io(self, client: 'Client') -> None:
cls.fit(X, y)
predt_0 = cls.predict(X)

with tempfile.TemporaryDirectory() as tmpdir:
with tempfile.TemporaryDirectory() as tmpdir, pytest.warns(None) as record:
path = os.path.join(tmpdir, "model.pkl")
with open(path, "wb") as fd:
pickle.dump(cls, fd)
Expand Down Expand Up @@ -1164,6 +1164,8 @@ def test_sklearn_io(self, client: 'Client') -> None:

np.testing.assert_allclose(predt_0.compute(), predt_3)

assert len(record) == 0, record[0].message


class TestDaskCallbacks:
@pytest.mark.skipif(**tm.no_sklearn())
Expand Down

0 comments on commit f05ebbf

Please sign in to comment.