Skip to content

Commit

Permalink
Rename filters as fields
Browse files Browse the repository at this point in the history
  • Loading branch information
vcfgv committed Aug 26, 2021
1 parent 6ba355d commit 35d2720
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 29 deletions.
8 changes: 4 additions & 4 deletions mars/core/entity/executable.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,12 @@ def fetch_log(self,
return fetch_log(self, session=session,
offsets=offsets, sizes=sizes)[0]

def _fetch_infos(self, filters=None, session=None, **kw):
def _fetch_infos(self, fields=None, session=None, **kw):
from ...deploy.oscar.session import fetch_infos

session = _get_session(self, session)
self._check_session(session, 'fetch_infos')
return fetch_infos(self, filters=filters, session=session, **kw)
return fetch_infos(self, fields=fields, session=session, **kw)

def _attach_session(self, session: SessionType):
if session not in self._executed_sessions:
Expand Down Expand Up @@ -240,12 +240,12 @@ def _fetch(self, session: SessionType = None, **kw):
self._check_session(session, 'fetch')
return fetch(*self, session=session, **kw)

def _fetch_infos(self, filters=None, session=None, **kw):
def _fetch_infos(self, fields=None, session=None, **kw):
from ...deploy.oscar.session import fetch_infos

session = _get_session(self, session)
self._check_session(session, 'fetch_infos')
return fetch_infos(*self, filters=filters, session=session, **kw)
return fetch_infos(*self, fields=fields, session=session, **kw)

def fetch(self, session: SessionType = None, **kw):
if len(self) == 0:
Expand Down
2 changes: 1 addition & 1 deletion mars/dataframe/contrib/raydataset/mldataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def to_ray_mldataset(df,
# chunk1 for addr1,
# chunk2 & chunk3 for addr2,
# chunk4 for addr1
fetched_infos: Dict[str, List] = df.fetch_infos(filters=['band', 'object_id'])
fetched_infos: Dict[str, List] = df.fetch_infos(fields=['band', 'object_id'])
chunk_addr_refs: List[Tuple[Tuple, 'ray.ObjectRef']] = [(band, object_id) for band, object_id in
zip(fetched_infos['band'],
fetched_infos['object_id'])]
Expand Down
4 changes: 2 additions & 2 deletions mars/dataframe/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,8 +528,8 @@ def fetch(self, session=None, **kw):
session=session, **kw))
return pd.concat(batches) if len(batches) > 1 else batches[0]

def fetch_infos(self, filters=None, session=None, **kw):
return self._fetch_infos(filters=filters, session=session, **kw)
def fetch_infos(self, fields=None, session=None, **kw):
return self._fetch_infos(fields=fields, session=session, **kw)


class IndexData(HasShapeTileableData, _ToPandasMixin):
Expand Down
42 changes: 21 additions & 21 deletions mars/deploy/oscar/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,16 +421,16 @@ def fetch(self, *tileables, **kwargs) -> list:
"""

@abstractmethod
def fetch_infos(self, *tileables, filters, **kwargs) -> list:
def fetch_infos(self, *tileables, fields, **kwargs) -> list:
"""
Fetch infos of tileables.
Parameters
----------
tileables
Tileables.
filters
List of filters
fields
List of fields
kwargs
Returns
Expand Down Expand Up @@ -884,16 +884,16 @@ async def fetch(self, *tileables, **kwargs) -> list:
result.append(self._process_result(tileable, merged))
return result

async def fetch_infos(self, *tileables, filters, **kwargs) -> list:
async def fetch_infos(self, *tileables, fields, **kwargs) -> list:
available_fields = {'object_id', 'level', 'memory_size', 'store_size', 'band'}
if filters is None:
filters = available_fields
if fields is None:
fields = available_fields
else:
for filter_name in filters:
if filter_name not in available_fields: # pragma: no cover
for field_name in fields:
if field_name not in available_fields: # pragma: no cover
raise TypeError(f'`fetch_infos` got unexpected '
f'filter name: {filter_name}')
filters = set(filters)
f'field name: {field_name}')
fields = set(fields)

if kwargs: # pragma: no cover
unexpected_keys = ', '.join(list(kwargs.keys()))
Expand Down Expand Up @@ -942,17 +942,17 @@ async def fetch_infos(self, *tileables, filters, **kwargs) -> list:
band = chunk_to_band[fetch_info.chunk]
# Currently there's only one item in the returned List from storage_api.get_infos()
data = fetch_info.data[0]
if 'object_id' in filters:
if 'object_id' in fields:
fetched['object_id'].append(data.object_id)
if 'level' in filters:
if 'level' in fields:
fetched['level'].append(data.level)
if 'memory_size' in filters:
if 'memory_size' in fields:
fetched['memory_size'].append(data.memory_size)
if 'store_size' in filters:
if 'store_size' in fields:
fetched['store_size'].append(data.store_size)
# data.band misses ip info, e.g. 'numa-0'
# while band doesn't, e.g. (address0, 'numa-0')
if 'band' in filters:
if 'band' in fields:
fetched['band'].append(band)
result.append(fetched)

Expand Down Expand Up @@ -1337,8 +1337,8 @@ def fetch(self, *tileables, **kwargs) -> list:
return asyncio.run_coroutine_threadsafe(coro, self._loop).result()

@implements(AbstractSyncSession.fetch_infos)
def fetch_infos(self, *tileables, filters, **kwargs) -> list:
coro = _fetch_infos(*tileables, filters=filters, session=self._isolated_session, **kwargs)
def fetch_infos(self, *tileables, fields, **kwargs) -> list:
coro = _fetch_infos(*tileables, fields=fields, session=self._isolated_session, **kwargs)
return asyncio.run_coroutine_threadsafe(coro, self._loop).result()

@implements(AbstractSyncSession.decref)
Expand Down Expand Up @@ -1484,12 +1484,12 @@ async def _fetch(tileable: TileableType,
async def _fetch_infos(tileable: TileableType,
*tileables: Tuple[TileableType],
session: _IsolatedSession = None,
filters: List[str] = None,
fields: List[str] = None,
**kwargs):
if isinstance(tileable, tuple) and len(tileables) == 0:
tileable, tileables = tileable[0], tileable[1:]
session = _get_isolated_session(session)
data = await session.fetch_infos(tileable, *tileables, filters=filters, **kwargs)
data = await session.fetch_infos(tileable, *tileables, fields=fields, **kwargs)
return data[0] if len(tileables) == 0 else data


Expand All @@ -1510,7 +1510,7 @@ def fetch(tileable: TileableType,

def fetch_infos(tileable: TileableType,
*tileables: Tuple[TileableType],
filters: List[str],
fields: List[str],
session: SyncSession = None,
**kwargs):
if isinstance(tileable, tuple) and len(tileables) == 0:
Expand All @@ -1520,7 +1520,7 @@ def fetch_infos(tileable: TileableType,
if session is None: # pragma: no cover
raise ValueError('No session found')
session = _ensure_sync(session)
return session.fetch_infos(tileable, *tileables, filters=filters, **kwargs)
return session.fetch_infos(tileable, *tileables, fields=fields, **kwargs)


def fetch_log(*tileables: TileableType,
Expand Down
3 changes: 2 additions & 1 deletion mars/deploy/oscar/tests/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ async def test_fetch_infos(create_cluster):
assert 'store_size' in fetched_infos
assert 'band' in fetched_infos

fetch_infos((df, df), fields=None)
results_infos = mr.ExecutableTuple([df, df]).execute()._fetch_infos()
assert len(results_infos) == 2
assert 'object_id' in results_infos[0]
Expand Down Expand Up @@ -309,7 +310,7 @@ def test_no_default_session():
execute(b, show_progress=False)

np.testing.assert_array_equal(fetch(b), raw + 1)
fetch_infos(b, filters=None)
fetch_infos(b, fields=None)
assert get_default_async_session() is not None
stop_server()
assert get_default_async_session() is None
Expand Down

0 comments on commit 35d2720

Please sign in to comment.