Commit 9a9a924

Support running PyTorch in Mars cluster via run_pytorch_script (#861)

Xuye (Chris) Qin authored and wjsi committed Dec 13, 2019
1 parent 81f63cd commit 9a9a924
Showing 16 changed files with 428 additions and 24 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
@@ -77,6 +77,7 @@ jobs:
   fi
   if [[ $UNAME == "linux" ]] && [[ ! "$PYTHON" =~ "2.7" ]] && [[ ! "$PYTHON" =~ "3.8" ]]; then
     pip install tensorflow
+    pip install torch torchvision
   fi
   virtualenv testenv && source testenv/bin/activate
   pip install -r requirements.txt && pip install pytest pytest-timeout
7 changes: 3 additions & 4 deletions .github/workflows/install-hdfs.sh
@@ -1,15 +1,14 @@
 #!/bin/bash
 set -e

-sudo apt-get remove -y yarn || true
-sudo apt-get update
+sudo apt-get remove -yq yarn || true

 # Installing CDH 5 with YARN on a Single Linux Node in Pseudo-distributed mode.
 curl -fsSL https://archive.cloudera.com/cdh5/ubuntu/xenial/amd64/cdh/archive.key | sudo apt-key add -
 echo 'deb [arch=amd64] http://archive.cloudera.com/cdh5/ubuntu/xenial/amd64/cdh xenial-cdh5 contrib' | sudo tee /etc/apt/sources.list.d/cloudera.list
 echo 'deb-src http://archive.cloudera.com/cdh5/ubuntu/xenial/amd64/cdh xenial-cdh5 contrib' | sudo tee -a /etc/apt/sources.list.d/cloudera.list
-sudo apt-get update
-sudo apt-get -y install hadoop-conf-pseudo libhdfs0
+sudo apt-get -q update || true
+sudo apt-get -yq install hadoop-conf-pseudo libhdfs0

 # start a pseudo-distributed Hadoop.
 sudo -u hdfs hdfs namenode -format
4 changes: 2 additions & 2 deletions .github/workflows/install-minikube.sh
@@ -5,8 +5,8 @@ export CHANGE_MINIKUBE_NONE_USER=true
 sudo apt-get remove -y docker.io || true
 curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo apt-key add -
 sudo add-apt-repository "deb [arch=amd64] https://download.docker.com/linux/ubuntu $(lsb_release -cs) stable"
-sudo apt-get update
-sudo apt-get install -y docker-ce
+sudo apt-get -q update || true
+sudo apt-get install -yq docker-ce

 K8S_VERSION=$(curl -s https://storage.googleapis.com/kubernetes-release/release/stable.txt)

2 changes: 2 additions & 0 deletions mars/learn/__init__.py
@@ -19,5 +19,7 @@
 del shuffle
 from .contrib.tensorflow import register_op
 register_op()
+from .contrib.pytorch import register_op
+register_op()
 del register_op

20 changes: 20 additions & 0 deletions mars/learn/contrib/pytorch/__init__.py
@@ -0,0 +1,20 @@
# Copyright 1999-2020 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .run_script import run_pytorch_script


def register_op():
    # importing RunPyTorch registers the operand with Mars as a side
    # effect of class creation, so the name can be dropped right away
    from .run_script import RunPyTorch
    del RunPyTorch
169 changes: 169 additions & 0 deletions mars/learn/contrib/pytorch/run_script.py
@@ -0,0 +1,169 @@
# Copyright 1999-2020 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
import os
import subprocess
import sys

import numpy as np

from .... import opcodes as OperandDef
from ....serialize import BytesField, Int32Field, ListField, StringField
from ....context import get_context, RunningMode
from ....utils import to_binary
from ...operands import LearnMergeDictOperand, OutputType
from ..utils import pick_workers


class RunPyTorch(LearnMergeDictOperand):
    _op_type_ = OperandDef.RUN_PYTORCH

    _code = BytesField('code')
    _command_args = ListField('command_args')
    _world_size = Int32Field('world_size')
    # used for chunk op
    _master_port = Int32Field('master_port')
    _master_addr = StringField('master_addr')
    _rank = Int32Field('rank')
    _init_method = StringField('init_method')

    def __init__(self, code=None, command_args=None, world_size=None,
                 master_port=None, master_addr=None, rank=None, init_method=None,
                 merge=None, output_types=None, gpu=None, **kw):
        super(RunPyTorch, self).__init__(
            _code=code, _command_args=command_args, _world_size=world_size,
            _master_port=master_port, _master_addr=master_addr,
            _rank=rank, _init_method=init_method, _merge=merge,
            _output_types=output_types, _gpu=gpu, **kw)
        if self._output_types is None:
            self._output_types = [OutputType.object]

    @property
    def code(self):
        return self._code

    @property
    def command_args(self):
        return self._command_args or []

    @property
    def world_size(self):
        return self._world_size

    @property
    def master_port(self):
        return self._master_port

    @property
    def master_addr(self):
        return self._master_addr

    @property
    def rank(self):
        return self._rank

    @property
    def init_method(self):
        return self._init_method

    def __call__(self):
        return self.new_tileable(None)

    @classmethod
    def tile(cls, op):
        ctx = get_context()

        if ctx.running_mode != RunningMode.distributed:
            # local mode: every rank runs on localhost
            workers = ['127.0.0.1'] * op.world_size
        else:
            workers = pick_workers(ctx.get_worker_addresses(), op.world_size)

        out_chunks = []
        for i in range(op.world_size):
            chunk_op = op.copy().reset_key()
            if ctx.running_mode == RunningMode.distributed:
                chunk_op._expect_worker = workers[i]
            if op.init_method is None:
                chunk_op._master_port = op.master_port
                chunk_op._master_addr = workers[0].split(':', 1)[0]
            chunk_op._rank = i
            chunk_op._init_method = op.init_method
            out_chunks.append(chunk_op.new_chunk(None, index=(i,)))

        new_op = op.copy()
        return new_op.new_tileables(op.inputs, chunks=out_chunks,
                                    nsplits=(tuple(np.nan for _ in range(len(out_chunks))),))

    @classmethod
    def execute(cls, ctx, op):
        if op.merge:
            return super(RunPyTorch, cls).execute(ctx, op)

        assert ctx.get_local_address() == op.expect_worker

        # write source code into a temp file
        fd, filename = tempfile.mkstemp('.py')
        with os.fdopen(fd, 'wb') as f:
            f.write(op.code)

        try:
            # pass rendezvous info to the script via PyTorch's standard
            # environment variables, consumed by the env:// init method
            env = {}
            if op.master_port is not None:
                env['MASTER_PORT'] = str(op.master_port)
            if op.master_addr is not None:
                env['MASTER_ADDR'] = str(op.master_addr)
            env['RANK'] = str(op.rank)
            env['WORLD_SIZE'] = str(op.world_size)
            # exec pytorch code in a new process
            process = subprocess.Popen(
                [sys.executable, filename] + op.command_args, env=env)
            process.wait()
            if process.returncode != 0:
                raise RuntimeError('Run PyTorch script failed')

            if op.rank == 0:
                ctx[op.outputs[0].key] = {'status': 'ok'}
            else:
                ctx[op.outputs[0].key] = {}
        finally:
            os.remove(filename)


def run_pytorch_script(script, n_workers, gpu=None, command_argv=None,
                       session=None, run_kwargs=None, port=None):
    """
    Run a PyTorch script in a Mars cluster.

    :param script: script to run
    :type script: str or file-like object
    :param n_workers: number of PyTorch workers
    :param gpu: whether to run the PyTorch script on GPU
    :param command_argv: extra command-line arguments for the script
    :param session: Mars session; the default one is used when not provided
    :param run_kwargs: extra kwargs for session.run
    :param port: port of the PyTorch worker or ps; increased automatically
                 for processes placed on the same worker
    :return: {'status': 'ok'} if the script succeeded, otherwise an error is raised
    """
    if int(n_workers) <= 0:
        raise ValueError('n_workers should be at least 1')
    if hasattr(script, 'read'):
        code = script.read()
    else:
        with open(os.path.abspath(script), 'rb') as f:
            code = f.read()

    port = 29500 if port is None else port
    op = RunPyTorch(code=to_binary(code), world_size=int(n_workers),
                    gpu=gpu, master_port=port, command_args=command_argv)
    return op().execute(session=session, **(run_kwargs or {}))
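
For reference, here is a minimal client-side sketch (not part of the commit) of how the new API is invoked; it mirrors the integrated test added later in this commit. The script name train.py and its arguments are hypothetical:

# Hypothetical driver: submit a PyTorch script to a Mars cluster.
from mars.learn.contrib.pytorch import run_pytorch_script

result = run_pytorch_script(
    'train.py',                      # hypothetical path; its bytes are shipped to every worker
    n_workers=2,                     # one OS process per rank
    command_argv=['--epochs', '2'],  # forwarded to the script as sys.argv[1:]
)
# rank 0's chunk stores {'status': 'ok'}, so success can be checked directly:
assert result['status'] == 'ok'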
13 changes: 13 additions & 0 deletions mars/learn/contrib/pytorch/tests/__init__.py
@@ -0,0 +1,13 @@
# Copyright 1999-2020 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
13 changes: 13 additions & 0 deletions mars/learn/contrib/pytorch/tests/integrated/__init__.py
@@ -0,0 +1,13 @@
# Copyright 1999-2020 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,40 @@
# Copyright 1999-2020 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import os

from mars.learn.tests.integrated.base import LearnIntegrationTestBase
from mars.learn.contrib.pytorch import run_pytorch_script
from mars.session import new_session

try:
    import torch
except ImportError:
    torch = None


@unittest.skipIf(torch is None, 'pytorch not installed')
class Test(LearnIntegrationTestBase):
    def testDistributedRunPyTorchScript(self):
        service_ep = 'http://127.0.0.1:' + self.web_port
        timeout = 120 if 'CI' in os.environ else -1
        with new_session(service_ep) as sess:
            path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
                                'pytorch_sample.py')
            run_kwargs = {'timeout': timeout}
            self.assertEqual(run_pytorch_script(
                path, n_workers=2, command_argv=['multiple'],
                port=9945, session=sess, run_kwargs=run_kwargs
            )['status'], 'ok')
69 changes: 69 additions & 0 deletions mars/learn/contrib/pytorch/tests/pytorch_sample.py
@@ -0,0 +1,69 @@
# Copyright 1999-2020 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

import torch
import torch.nn as nn
import torch.distributed as dist
import torch.optim as optim
import torch.utils.data


def get_model():
    return nn.Sequential(
        nn.Linear(32, 64),
        nn.ReLU(),
        nn.Linear(64, 64),
        nn.ReLU(),
        nn.Linear(64, 10),
        nn.Softmax(),
    )


assert len(sys.argv) == 2
assert sys.argv[1] == 'multiple'


def main():
    dist.init_process_group(backend='gloo')
    torch.manual_seed(42)

    data = torch.rand((1000, 32), dtype=torch.float32)
    labels = torch.randint(1, (1000, 10), dtype=torch.float32)

    train_dataset = torch.utils.data.TensorDataset(data, labels)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=32,
                                               shuffle=False,
                                               sampler=train_sampler)

    model = nn.parallel.DistributedDataParallel(get_model())
    optimizer = optim.SGD(model.parameters(),
                          lr=0.01, momentum=0.5)
    criterion = nn.BCELoss()

    for _ in range(2):
        # 2 epochs
        for _, (batch_data, batch_labels) in enumerate(train_loader):
            outputs = model(batch_data)
            loss = criterion(outputs.squeeze(), batch_labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


if __name__ == "__main__":
    main()
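
A note on how the sample meshes with RunPyTorch.execute above: since the script passes no explicit init_method, dist.init_process_group(backend='gloo') falls back to the env:// rendezvous and reads the MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE variables that the operand exports. The following sketch (not part of the commit) launches the sample by hand under the same contract; the host and port values are illustrative:

# Launch pytorch_sample.py for two local ranks, reproducing the
# environment-variable contract that RunPyTorch.execute sets up.
import os
import subprocess
import sys

procs = []
for rank in range(2):
    env = os.environ.copy()
    env['MASTER_ADDR'] = '127.0.0.1'  # rank 0's host, as RunPyTorch.tile derives it
    env['MASTER_PORT'] = '29500'      # default master port in run_pytorch_script
    env['RANK'] = str(rank)
    env['WORLD_SIZE'] = '2'
    procs.append(subprocess.Popen(
        [sys.executable, 'pytorch_sample.py', 'multiple'], env=env))

# mirror the operand's error handling: fail if any rank exits non-zero
for p in procs:
    p.wait()
    if p.returncode != 0:
        raise RuntimeError('Run PyTorch script failed')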
