Optimize performance of transfer (#3091)
Xuye (Chris) Qin committed May 31, 2022
1 parent b670c47 commit 994aec1
Showing 12 changed files with 246 additions and 164 deletions.
101 changes: 101 additions & 0 deletions benchmarks/asv_bench/benchmarks/storage.py
@@ -0,0 +1,101 @@
# Copyright 1999-2022 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools

import numpy as np
import pandas as pd

import mars
import mars.remote as mr
from mars.core.context import get_context
from mars.utils import Timer, readable_size


def send_1_to_1(n: int = None):
ctx = get_context()
workers = ctx.get_worker_addresses()

worker_to_gen_data = {
w: mr.spawn(_gen_data, kwargs=dict(n=n, worker=w), expect_worker=w)
for i, w in enumerate(workers)
}
all_data = mars.execute(list(worker_to_gen_data.values()))
progress = 0.1
ctx.set_progress(progress)
infos = [d._fetch_infos(fields=["data_key", "store_size"]) for d in all_data]
data_size = infos[0]["store_size"][0]
worker_to_data_keys = dict(zip(workers, [info["data_key"][0] for info in infos]))

workers_to_durations = dict()
size = len(workers) * (len(workers) - 1)
for worker1, worker2 in itertools.permutations(workers, 2):
fetch_data = mr.spawn(
_fetch_data,
args=(worker_to_data_keys[worker1],),
kwargs=dict(worker=worker2),
expect_worker=worker2,
)
fetch_time = fetch_data.execute().fetch()
rate = readable_size(data_size / fetch_time)
workers_to_durations[worker1, worker2] = (
readable_size(data_size),
f"{rate}B/s",
)
progress += 0.9 / size
ctx.set_progress(min(progress, 1.0))
return workers_to_durations


def _gen_data(
n: int = None, worker: str = None, check_addr: bool = True
) -> pd.DataFrame:
if check_addr:
ctx = get_context()
assert ctx.worker_address == worker
n = n if n is not None else 5_000_000
rs = np.random.RandomState(123)
data = {
"a": rs.rand(n),
"b": rs.randint(n * 10, size=n),
"c": [f"foo{i}" for i in range(n)],
}
return pd.DataFrame(data)


def _fetch_data(data_key: str, worker: str = None):
# do nothing actually
ctx = get_context()
assert ctx.worker_address == worker
with Timer() as timer:
ctx.get_chunks_result([data_key], fetch_only=True)
return timer.duration


class TransferPackageSuite:
"""
Benchmark that times performance of storage transfer
"""

def setup(self):
mars.new_session(n_worker=2, n_cpu=8)

def time_1_to_1(self):
return mr.spawn(send_1_to_1).execute().fetch()


if __name__ == "__main__":
suite = TransferPackageSuite()
suite.setup()
print(suite.time_1_to_1())
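Editor's note: the suite above pins data generation to one worker and the fetch to a different one via expect_worker, so the get_chunks_result call in _fetch_data has to move every byte across the wire; what comes back is a pure worker-to-worker transfer-rate matrix. A minimal sketch of driving it outside ASV, assuming a local 2-worker session and that the module is importable as shown:

import mars
import mars.remote as mr
from benchmarks.asv_bench.benchmarks.storage import send_1_to_1

mars.new_session(n_worker=2, n_cpu=8)  # mirrors TransferPackageSuite.setup()
# {(src_worker, dst_worker): (payload size, rate)}, e.g. ("114.44M", "1.20GB/s")
durations = mr.spawn(send_1_to_1, kwargs=dict(n=1_000_000)).execute().fetch()
print(durations)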
File renamed without changes.
81 changes: 39 additions & 42 deletions benchmarks/tpch/run_queries.py
@@ -25,71 +25,55 @@
 queries: Optional[Union[Set[str], List[str]]] = None


-def load_lineitem(data_folder: str) -> md.DataFrame:
+def load_lineitem(data_folder: str, use_arrow_dtype: bool = None) -> md.DataFrame:
     data_path = data_folder + "/lineitem.pq"
-    df = md.read_parquet(
-        data_path,
-    )
+    df = md.read_parquet(data_path, use_arrow_dtype=use_arrow_dtype)
     df["L_SHIPDATE"] = md.to_datetime(df.L_SHIPDATE, format="%Y-%m-%d")
     df["L_RECEIPTDATE"] = md.to_datetime(df.L_RECEIPTDATE, format="%Y-%m-%d")
     df["L_COMMITDATE"] = md.to_datetime(df.L_COMMITDATE, format="%Y-%m-%d")
     return df


-def load_part(data_folder: str) -> md.DataFrame:
+def load_part(data_folder: str, use_arrow_dtype: bool = None) -> md.DataFrame:
     data_path = data_folder + "/part.pq"
-    df = md.read_parquet(
-        data_path,
-    )
+    df = md.read_parquet(data_path, use_arrow_dtype=use_arrow_dtype)
     return df


-def load_orders(data_folder: str) -> md.DataFrame:
+def load_orders(data_folder: str, use_arrow_dtype: bool = None) -> md.DataFrame:
     data_path = data_folder + "/orders.pq"
-    df = md.read_parquet(
-        data_path,
-    )
+    df = md.read_parquet(data_path, use_arrow_dtype=use_arrow_dtype)
     df["O_ORDERDATE"] = md.to_datetime(df.O_ORDERDATE, format="%Y-%m-%d")
     return df


-def load_customer(data_folder: str) -> md.DataFrame:
+def load_customer(data_folder: str, use_arrow_dtype: bool = None) -> md.DataFrame:
     data_path = data_folder + "/customer.pq"
-    df = md.read_parquet(
-        data_path,
-    )
+    df = md.read_parquet(data_path, use_arrow_dtype=use_arrow_dtype)
     return df


-def load_nation(data_folder: str) -> md.DataFrame:
+def load_nation(data_folder: str, use_arrow_dtype: bool = None) -> md.DataFrame:
     data_path = data_folder + "/nation.pq"
-    df = md.read_parquet(
-        data_path,
-    )
+    df = md.read_parquet(data_path, use_arrow_dtype=use_arrow_dtype)
     return df


-def load_region(data_folder: str) -> md.DataFrame:
+def load_region(data_folder: str, use_arrow_dtype: bool = None) -> md.DataFrame:
     data_path = data_folder + "/region.pq"
-    df = md.read_parquet(
-        data_path,
-    )
+    df = md.read_parquet(data_path, use_arrow_dtype=use_arrow_dtype)
     return df


-def load_supplier(data_folder: str) -> md.DataFrame:
+def load_supplier(data_folder: str, use_arrow_dtype: bool = None) -> md.DataFrame:
     data_path = data_folder + "/supplier.pq"
-    df = md.read_parquet(
-        data_path,
-    )
+    df = md.read_parquet(data_path, use_arrow_dtype=use_arrow_dtype)
     return df


-def load_partsupp(data_folder: str) -> md.DataFrame:
+def load_partsupp(data_folder: str, use_arrow_dtype: bool = None) -> md.DataFrame:
     data_path = data_folder + "/partsupp.pq"
-    df = md.read_parquet(
-        data_path,
-    )
+    df = md.read_parquet(data_path, use_arrow_dtype=use_arrow_dtype)
     return df
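Editor's note: every loader change above is the same mechanical edit — collapse the multi-line read_parquet call and thread a use_arrow_dtype flag through to it. A hedged illustration of what the flag buys (file path assumed; the exact dtype repr may vary by Mars version):

import mars.dataframe as md

df = md.read_parquet("/data/tpch/lineitem.pq", use_arrow_dtype=True)
# String columns load as Arrow-backed dtypes (e.g. Arrow[string]) rather than
# pandas object dtype, which is far cheaper to serialize during transfer.
print(df.dtypes)

That cheaper serialization is what connects a TPC-H loading flag to a commit about transfer performance.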


@@ -982,21 +966,23 @@ def q22(customer, orders):
     print(total.execute())


-def run_queries(data_folder: str, select: List[str] = None):
+def run_queries(
+    data_folder: str, select: List[str] = None, use_arrow_dtype: bool = None
+):
     if select:
         global queries
         queries = select

     # Load the data
     t1 = time.time()
-    lineitem = load_lineitem(data_folder)
-    orders = load_orders(data_folder)
-    customer = load_customer(data_folder)
-    nation = load_nation(data_folder)
-    region = load_region(data_folder)
-    supplier = load_supplier(data_folder)
-    part = load_part(data_folder)
-    partsupp = load_partsupp(data_folder)
+    lineitem = load_lineitem(data_folder, use_arrow_dtype=use_arrow_dtype)
+    orders = load_orders(data_folder, use_arrow_dtype=use_arrow_dtype)
+    customer = load_customer(data_folder, use_arrow_dtype=use_arrow_dtype)
+    nation = load_nation(data_folder, use_arrow_dtype=use_arrow_dtype)
+    region = load_region(data_folder, use_arrow_dtype=use_arrow_dtype)
+    supplier = load_supplier(data_folder, use_arrow_dtype=use_arrow_dtype)
+    part = load_part(data_folder, use_arrow_dtype=use_arrow_dtype)
+    partsupp = load_partsupp(data_folder, use_arrow_dtype=use_arrow_dtype)
     mars.execute([lineitem, orders, customer, nation, region, supplier, part, partsupp])
     print("Reading time (s): ", time.time() - t1)

@@ -1048,14 +1034,25 @@ def main():
             "all tests will be executed"
         ),
     )
+    parser.add_argument(
+        "--use-arrow-dtype",
+        type=str,
+        choices=["true", "false"],
+        help=("Use arrow dtype to read parquet"),
+    )
     args = parser.parse_args()
     folder = args.folder
     endpoint = args.endpoint
+    use_arrow_dtype = args.use_arrow_dtype
+    if use_arrow_dtype == "true":
+        use_arrow_dtype = True
+    elif use_arrow_dtype == "false":
+        use_arrow_dtype = False
     queries = (
         set(x.lower().strip() for x in args.query.split(",")) if args.query else None
    )
     mars.new_session(endpoint)
-    run_queries(folder)
+    run_queries(folder, use_arrow_dtype=use_arrow_dtype)


if __name__ == "__main__":
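Editor's note: the new flag is deliberately tri-state — left out it stays None so read_parquet keeps its default, while the strings map to real booleans before being passed down. The same pattern in a self-contained sketch:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use-arrow-dtype", type=str, choices=["true", "false"])
args = parser.parse_args(["--use-arrow-dtype", "true"])  # simulated CLI input
use_arrow_dtype = {"true": True, "false": False}.get(args.use_arrow_dtype)
print(use_arrow_dtype)  # True here; None whenever the flag is omitted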
6 changes: 4 additions & 2 deletions mars/core/context.py
@@ -117,19 +117,21 @@ def get_slots(self) -> int:
         """

     @abstractmethod
-    def get_chunks_result(self, data_keys: List[str]) -> List:
+    def get_chunks_result(self, data_keys: List[str], fetch_only: bool = False) -> List:
         """
         Get result of chunks.

         Parameters
         ----------
         data_keys : list
             Data keys.
+        fetch_only : bool
+            If fetch_only, only fetch data but not return.

         Returns
         -------
         results : list
-            Result of chunks
+            Result of chunks if not fetch_only, else return None
         """

     @abstractmethod
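Editor's note — a usage sketch grounded in the benchmark above: with fetch_only=True a worker pulls the named chunks into its local storage but skips building and returning result objects, so callers can time (or simply trigger) raw data movement. get_context() is only meaningful inside code Mars itself executes:

from mars.core.context import get_context

def prefetch(data_keys):
    ctx = get_context()
    ctx.get_chunks_result(data_keys, fetch_only=True)  # transfers data, returns None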
2 changes: 1 addition & 1 deletion mars/core/operand/base.py
@@ -176,7 +176,7 @@ class Operand(Base, OperatorLogicKeyGeneratorMixin, metaclass=OperandMetaclass):
     tileable_op_key = StringField("tileable_op_key", default=None)
     extra_params = DictField("extra_params", key_type=FieldTypes.string)
     # scheduling hint
-    scheduling_hint = ReferenceField("scheduling_hint", default=None)
+    scheduling_hint = ReferenceField("scheduling_hint", SchedulingHint, default=None)

     _inputs = ListField(
         "inputs", FieldTypes.reference(EntityData), default_factory=list
9 changes: 7 additions & 2 deletions mars/deploy/oscar/session.py
@@ -1162,6 +1162,7 @@ async def fetch(self, *tileables, **kwargs) -> list:

     async def fetch_infos(self, *tileables, fields, **kwargs) -> list:
         available_fields = {
+            "data_key",
             "object_id",
             "object_refs",
             "level",
@@ -1217,6 +1218,8 @@ async def fetch_infos(self, *tileables, fields, **kwargs) -> list:
             bands = chunk_to_bands[fetch_info.chunk]
             # Currently there's only one item in the returned List from storage_api.get_infos()
             data = fetch_info.data[0]
+            if "data_key" in fields:
+                fetched["data_key"].append(fetch_info.chunk.key)
             if "object_id" in fields:
                 fetched["object_id"].append(data.object_id)
             if "level" in fields:
@@ -1245,7 +1248,7 @@ async def _query_meta_service(self, tileables, fields, query_storage):
                 get_chunk_metas.append(
                     self._meta_api.get_chunk_meta.delay(
                         chunk.key,
-                        fields=["bands"] if query_storage else fields,
+                        fields=["bands"] if query_storage else fields - {"data_key"},
                     )
                 )
                 fetch_infos.append(
@@ -1259,7 +1262,9 @@ async def _query_meta_service(self, tileables, fields, query_storage):
         for fetch_infos in fetch_infos_list:
             fetched = defaultdict(list)
             for fetch_info in fetch_infos:
-                for field in fields:
+                if "data_key" in fields:
+                    fetched["data_key"].append(fetch_info.chunk.key)
+                for field in fields - {"data_key"}:
                     fetched[field].append(chunk_to_meta[fetch_info.chunk][field])
             result.append(fetched)
         return {}, fetch_infos_list, result
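Editor's note — net effect of the session changes: "data_key" becomes a queryable fetch_infos field, answered directly from fetch_info.chunk.key on both the storage path and the meta path, and explicitly subtracted (fields - {"data_key"}) from what the meta service is asked for. This is exactly what the benchmark's _fetch_infos(fields=["data_key", "store_size"]) call depends on. A hedged sketch (the tensor, chunk_size, and the private _fetch_infos call mirror the benchmark's usage, not a documented public API):

import mars
import mars.tensor as mt

mars.new_session()
t = mt.ones((100, 100), chunk_size=50).execute()
infos = t._fetch_infos(fields=["data_key", "store_size"])
print(infos["data_key"])    # one storage key per chunk (four here)
print(infos["store_size"])  # matching in-store sizes in bytes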
