basic raytrainer api stub
IK committed Apr 1, 2020
1 parent 7de51f7 commit 7eac3d8
Showing 11 changed files with 279 additions and 13 deletions.
83 changes: 83 additions & 0 deletions dev_scripts/local.yaml
@@ -0,0 +1,83 @@
# A unique identifier for the head node and workers of this cluster.
cluster_name: local

## NOTE: Typically for local clusters, min_workers == initial_workers == max_workers.

# The minimum number of worker nodes to launch in addition to the head
# node. This number should be >= 0.
# Typically, min_workers == initial_workers == max_workers.
min_workers: 0
# The initial number of worker nodes to launch in addition to the head node.
# Typically, min_workers == initial_workers == max_workers.
initial_workers: 0

# The maximum number of worker nodes to launch in addition to the head node.
# This takes precedence over min_workers.
# Typically, min_workers == initial_workers == max_workers.
max_workers: 0

# Autoscaling parameters.
# Ignore this if min_workers == initial_workers == max_workers.
autoscaling_mode: default
target_utilization_fraction: 0.8
idle_timeout_minutes: 5

# This executes all commands on all nodes in the docker container,
# and opens all the necessary ports to support the Ray cluster.
# Empty string means disabled. Assumes Docker is installed.
docker:
image: "" # e.g., tensorflow/tensorflow:1.5.0-py3
container_name: "" # e.g. ray_docker
# If true, pulls latest version of image. Otherwise, `docker run` will only pull the image
# if no cached version is present.
pull_before_run: True
run_options: [] # Extra options to pass into "docker run"

# Local specific configuration.
provider:
    type: local
    head_ip: 127.0.0.1
    worker_ips: []

# How Ray will authenticate with newly launched nodes.
auth:
    ssh_user: nada
    ssh_private_key: ~/.ssh/id_rsa

# Leave this empty.
head_node: {}

# Leave this empty.
worker_nodes: {}

# Files or directories to copy to the head and worker nodes. The format is a
# dictionary from REMOTE_PATH: LOCAL_PATH, e.g.
file_mounts: {
# "/path1/on/remote/machine": "/path1/on/local/machine",
# "/path2/on/remote/machine": "/path2/on/local/machine",
}

# List of commands that will be run before `setup_commands`. If docker is
# enabled, these commands will run outside the container and before docker
# is set up.
initialization_commands: []

# List of shell commands to run to set up each node.
setup_commands:
- pip install -U ray

# Custom commands that will be run on the head node after common setup.
head_setup_commands: []

# Custom commands that will be run on worker nodes after common setup.
worker_setup_commands: []

# Command to start ray on the head node. You don't need to change this.
head_start_ray_commands:
- ray stop
- ulimit -c unlimited && ray start --head --redis-port=6379 --autoscaling-config=~/ray_bootstrap_config.yaml

# Command to start ray on worker nodes. You don't need to change this.
worker_start_ray_commands:
- ray stop
- ray start --address=$RAY_HEAD_IP:6379
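For reference, a minimal driver sketch (not part of this commit) that connects to a cluster brought up from this config; the address and port are assumptions that must match head_ip above and the --redis-port used in head_start_ray_commands:

import ray

# Assumed to run after `ray up` has started the head node from this file.
ray.init(address="127.0.0.1:6379")
print(ray.nodes())  # sanity check: lists the nodes that have joined the cluster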
46 changes: 46 additions & 0 deletions dev_scripts/ray_test.py
@@ -0,0 +1,46 @@
import ray
import torch as pt
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

from pytorch_lightning import LightningModule
from pytorch_lightning.utilities.ray import RayTrainer, RayRemoteTrainer


class Foo(LightningModule):
    """Minimal model used to smoke-test the Ray trainer stub."""

    def __init__(self, batch_size=5):
        super().__init__()
        self.batch_size = batch_size
        self.linear = pt.nn.Linear(3, 10)
        self.linear2 = pt.nn.Linear(10, 3)
        self.loss = pt.nn.CrossEntropyLoss()

    def prepare_data(self) -> None:
        # random train/val/test datasets
        self.ts = TensorDataset(pt.randn(100, 3), pt.randint(3, [100]))
        self.vs = TensorDataset(pt.randn(100, 3), pt.randint(3, [100]))
        self.tts = TensorDataset(pt.randn(100, 3), pt.randint(3, [100]))

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.ts, batch_size=self.batch_size, shuffle=True)

    def configure_optimizers(self):
        return pt.optim.Adam(self.parameters())

    def forward(self, x):
        x = self.linear(x)
        x = F.relu(x)
        x = self.linear2(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        yhat = self.forward(x)
        print(yhat.shape)
        loss = self.loss(yhat, y)
        ret = {"loss": loss}
        print(ret)
        return ret


@ray.remote(num_gpus=None)
class RayRemote(RayRemoteTrainer):
    pass


trainer = RayTrainer(RayRemote, num_nodes=1, distributed_backend="ddp",
                     _ray_random_seed=5, gpus=None)
trainer.make_model(lambda: Foo())
sd = trainer.fit()
print(sd)

11 changes: 9 additions & 2 deletions pytorch_lightning/core/lightning.py
@@ -72,6 +72,7 @@ def __init__(self, *args, **kwargs):
self.use_amp = False

self.hparams = None
self.slurm = False

def print(self, *args, **kwargs) -> None:
r"""
@@ -843,7 +844,10 @@ def init_ddp_connection(self):
# guarantees unique ports across jobs from same grid search
try:
# use the last 4 numbers in the job id as the id
default_port = os.environ['SLURM_JOB_ID']
if self.slurm:
default_port = os.environ['SLURM_JOB_ID']
else:
default_port = os.environ['LIGHTNING_JOB_ID']
default_port = default_port[-4:]

# all ports should be in the 10k+ range
@@ -860,7 +864,10 @@ def init_ddp_connection(self):

# figure out the root node addr
try:
root_node = os.environ['SLURM_NODELIST'].split(' ')[0]
if self.slurm:
root_node = os.environ['SLURM_NODELIST'].split(' ')[0]
else:
root_node = os.environ['LIGHTNING_NODELIST'].split(' ')[0]
except Exception:
root_node = '127.0.0.2'
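When self.slurm is false, init_ddp_connection now falls back to LIGHTNING_JOB_ID and LIGHTNING_NODELIST. A minimal sketch of what a launcher would export before building the model (the values are illustrative assumptions, not part of this commit):

import os

os.environ["LIGHTNING_JOB_ID"] = "12345"        # last 4 digits seed the DDP master port
os.environ["LIGHTNING_NODELIST"] = "127.0.0.1"  # first entry is used as the root node address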

2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/__init__.py
@@ -397,7 +397,7 @@ def on_train_end(self):
# default logger used by trainer
logger = TensorBoardLogger(
save_dir=os.getcwd(),
version=self.slurm_job_id,
version=self.job_id,
name='lightning_logs'
)
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/callback_config.py
@@ -18,7 +18,7 @@ class TrainerCallbackConfigMixin(ABC):

@property
@abstractmethod
def slurm_job_id(self) -> int:
def job_id(self) -> int:
"""Warning: this is just empty shell for code implemented in other class."""

@abstractmethod
11 changes: 9 additions & 2 deletions pytorch_lightning/trainer/distrib_data_parallel.py
@@ -145,6 +145,7 @@ class TrainerDDPMixin(ABC):
amp_level: str
use_tpu: bool
default_save_path: str
slurm: bool

@property
@abstractmethod
@@ -217,6 +218,7 @@ def set_distributed_mode(self, distributed_backend, num_gpu_nodes):
log.info(f'GPU available: {torch.cuda.is_available()}, used: {self.on_gpu}')

def configure_slurm_ddp(self, num_gpu_nodes):
assert self.slurm, "Tried to configure SLURM DDP with SLURM disabled; this should not happen"
self.is_slurm_managing_tasks = False

# extract SLURM flag vars
@@ -275,8 +277,12 @@ def ddp_train(self, gpu_idx, model):
# node rank using relative slurm id
# otherwise default to node rank 0
try:
node_id = os.environ['SLURM_NODEID']
self.node_rank = int(node_id)
if self.slurm:
node_id = os.environ['SLURM_NODEID']
self.node_rank = int(node_id)
else:
node_id = os.environ['LIGHTNING_NODE_ID']
self.node_rank = int(node_id)
except Exception:
self.node_rank = 0

@@ -383,3 +389,4 @@ def resolve_root_node_address(self, root_node):
root_node = name + number

return root_node
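Similarly, ddp_train now derives the node rank from LIGHTNING_NODE_ID when SLURM is disabled, falling back to rank 0 if it is unset. A per-node launcher would be expected to export something like the following (the value is an illustrative assumption):

import os

os.environ["LIGHTNING_NODE_ID"] = "0"  # rank of this machine within the Ray-managed cluster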

2 changes: 2 additions & 0 deletions pytorch_lightning/trainer/distrib_parts.py
@@ -384,6 +384,7 @@ class TrainerDPMixin(ABC):
tpu_global_core_rank: int
use_tpu: bool
data_parallel_device_ids: ...
slurm: bool

@abstractmethod
def run_pretrain_routine(self, *args):
@@ -413,6 +414,7 @@ def copy_trainer_model_properties(self, model):
m.use_tpu = self.use_tpu
m.tpu_local_core_rank = self.tpu_local_core_rank
m.tpu_global_core_rank = self.tpu_global_core_rank
m.slurm = self.slurm

def transfer_batch_to_tpu(self, batch):
return self.__transfer_data_to_device(batch, device='tpu')
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/logging.py
@@ -21,15 +21,15 @@ class TrainerLoggingMixin(ABC):
use_dp: bool
use_ddp2: bool
default_save_path: str
slurm_job_id: int
job_id: int
num_gpus: int

def configure_logger(self, logger):
if logger is True:
# default logger
self.logger = TensorBoardLogger(
save_dir=self.default_save_path,
version=self.slurm_job_id,
version=self.job_id,
name='lightning_logs'
)
self.logger.rank = 0
Empty file.
24 changes: 19 additions & 5 deletions pytorch_lightning/trainer/trainer.py
@@ -122,6 +122,7 @@ def __init__(
min_nb_epochs=None, # backward compatible, todo: remove in v0.8.0
use_amp=False, # backward compatible, todo: remove in v0.9.0
nb_sanity_val_steps=None, # backward compatible, todo: remove in v0.8.0
slurm=False,
**kwargs
):
r"""
@@ -394,6 +395,7 @@ def __init__(
self.use_ddp2 = False
self.use_dp = False
self.single_gpu = False
self.slurm = slurm
self.distributed_backend = distributed_backend
self.set_distributed_mode(distributed_backend, self.num_nodes)

@@ -406,7 +408,9 @@ def __init__(
self.proc_rank = 0
self.world_size = 1
self.node_rank = 0
self.configure_slurm_ddp(self.num_nodes)
self.is_slurm_managing_tasks = False
if self.slurm:
self.configure_slurm_ddp(self.num_nodes)

# nvidia setup
self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids)
@@ -446,9 +450,12 @@ def __init__(
self.on_init_end()

@property
def slurm_job_id(self) -> int:
def job_id(self) -> int:
try:
job_id = os.environ['SLURM_JOB_ID']
if self.slurm:
job_id = os.environ['SLURM_JOB_ID']
else:
job_id = os.environ['LIGHTNING_JOB_ID']
job_id = int(job_id)
except Exception:
job_id = None
@@ -659,13 +666,19 @@ def fit(
# route to appropriate start method
# when using multi-node or DDP within a node start each module in a separate process
if self.use_ddp2:
task = int(os.environ['SLURM_LOCALID'])
if self.slurm:
task = int(os.environ['SLURM_LOCALID'])
else:
task = int(os.environ['LIGHTNING_LOCALID'])
self.ddp_train(task, model)

elif self.use_ddp:
if self.is_slurm_managing_tasks:
if self.slurm and self.is_slurm_managing_tasks:
task = int(os.environ['SLURM_LOCALID'])
self.ddp_train(task, model)
elif self.ray:
task = int(os.environ['LIGHTNING_LOCALID'])
self.ddp_train(task, model)
else:
self.__set_random_port()
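In the non-SLURM branches of fit(), the local rank of each spawned process is read from LIGHTNING_LOCALID instead of SLURM_LOCALID. A sketch of the per-process environment a Ray actor would be expected to set before calling fit() (the value is an illustrative assumption):

import os

os.environ["LIGHTNING_LOCALID"] = "0"  # local rank of this process on its node

Each worker process would set a different value, mirroring what SLURM provides via SLURM_LOCALID.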

@@ -987,6 +1000,7 @@ def test(self, model: Optional[LightningModule] = None):
self.testing = False



class _PatchDataLoader(object):
r"""
Callable object for patching dataloaders passed into trainer.fit().