add pytorch lightning ddp elastic example (#1671)
* add pytorch lightning ddp elastic example

Signed-off-by: Niels Bantilan <niels.bantilan@gmail.com>

* copy updates, refactor module

Signed-off-by: Niels Bantilan <niels.bantilan@gmail.com>

* update default arg

Signed-off-by: Niels Bantilan <niels.bantilan@gmail.com>

* fix requirements, update example

Signed-off-by: Niels Bantilan <niels.bantilan@gmail.com>

* fix imagespec

Signed-off-by: Niels Bantilan <niels.bantilan@gmail.com>

* update formatting

Signed-off-by: Niels Bantilan <niels.bantilan@gmail.com>

* update deps

Signed-off-by: Niels Bantilan <niels.bantilan@gmail.com>

* remove custom image name

Signed-off-by: Niels Bantilan <niels.bantilan@gmail.com>

* update image spec

Signed-off-by: Niels Bantilan <niels.bantilan@gmail.com>

* fix formatting

Signed-off-by: Niels Bantilan <niels.bantilan@gmail.com>

* update deps

Signed-off-by: Niels Bantilan <niels.bantilan@gmail.com>

* update imagespec

Signed-off-by: Niels Bantilan <niels.bantilan@gmail.com>

* add back cuda

Signed-off-by: Niels Bantilan <niels.bantilan@gmail.com>

---------

Signed-off-by: Niels Bantilan <niels.bantilan@gmail.com>
cosmicBboy committed May 15, 2024
1 parent 374093a commit 0cf0f83
Showing 5 changed files with 238 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/monodocs_build.yml
@@ -51,4 +51,4 @@ jobs:
           FLYTESNACKS_LOCAL_PATH: ${{ github.workspace }}/flytesnacks
         run: |
           conda activate monodocs-env
-          make -C docs html SPHINXOPTS="-W -vvv"
+          make -C docs html SPHINXOPTS="-W"
14 changes: 11 additions & 3 deletions examples/kfpytorch_plugin/README.md
@@ -21,13 +21,21 @@ To enable the plugin in the backend, follow instructions outlined in the {std:re
 
 ## Run the example on the Flyte cluster
 
-To run the provided example on the Flyte cluster, use the following command:
+To run the provided examples on the Flyte cluster, use the following commands:
+
+Distributed PyTorch training:
+
+```
+pyflyte run --remote pytorch_mnist.py pytorch_training_wf
+```
+
+PyTorch Lightning training:
 
 ```
-pyflyte run --remote pytorch_mnist.py \
-    pytorch_training_wf
+pyflyte run --remote pytorch_lightning_mnist_autoencoder.py train_workflow
 ```
 
 ```{auto-examples-toc}
 pytorch_mnist
+pytorch_lightning_mnist_autoencoder
 ```
210 changes: 210 additions & 0 deletions examples/kfpytorch_plugin/kfpytorch_plugin/pytorch_lightning_mnist_autoencoder.py
@@ -0,0 +1,210 @@
# %% [markdown]
# # Use PyTorch Lightning to Train an MNIST Autoencoder
#
# This example demonstrates how to use PyTorch Lightning with Flyte's `Elastic`
# task config, which is exposed by the `flytekitplugins-kfpytorch` plugin.
#
# First, we import all of the relevant packages.

import os

import lightning as L
from flytekit import ImageSpec, PodTemplate, Resources, task, workflow
from flytekit.extras.accelerators import T4
from flytekit.types.directory import FlyteDirectory
from flytekitplugins.kfpytorch.task import Elastic
from kubernetes.client.models import (
    V1Container,
    V1EmptyDirVolumeSource,
    V1PodSpec,
    V1Volume,
    V1VolumeMount,
)
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

# %% [markdown]
# ## Image and Pod Template Configuration
#
# For this task, we're going to use a custom image that has all of the
# necessary dependencies installed.

custom_image = ImageSpec(
    packages=[
        "adlfs==2024.4.1",
        "gcsfs==2024.3.1",
        "torch==2.2.1",
        "torchvision",
        "flytekitplugins-kfpytorch",
        "kubernetes",
        "lightning==2.2.4",
        "networkx==3.2.1",
        "s3fs==2024.3.1",
    ],
    cuda="12.1.0",
    python_version="3.10",
    registry="ghcr.io/flyteorg",
)

# %% [markdown]
# :::{important}
# Replace `ghcr.io/flyteorg` with a container registry that you have access to publish to.
# To upload the image to the local registry in the demo cluster, indicate the
# registry as `localhost:30000`.
# :::
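#
# For example, to push to the demo cluster's local registry, the same spec can
# point at `localhost:30000` (a sketch; only the `registry` value changes from
# the spec above):
#
# ```python
# custom_image = ImageSpec(
#     packages=[...],
#     registry="localhost:30000",
# )
# ```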
#
# :::{note}
# You can activate GPU support by either using the base image that includes
# the necessary GPU dependencies or by specifying the `cuda` parameter in
# the {py:class}`~flytekit.image_spec.ImageSpec`, for example:
#
# ```python
# custom_image = ImageSpec(
#     packages=[...],
#     cuda="12.1.0",
#     ...
# )
# ```
# :::

# %% [markdown]
# We're also going to define a custom pod template that mounts a shared memory
# volume to `/dev/shm`. This is necessary for distributed data parallel (DDP)
# training so that state can be shared across workers.

container = V1Container(name=custom_image.name, volume_mounts=[V1VolumeMount(mount_path="/dev/shm", name="dshm")])
volume = V1Volume(name="dshm", empty_dir=V1EmptyDirVolumeSource(medium="Memory"))
custom_pod_template = PodTemplate(
    primary_container_name=custom_image.name,
    pod_spec=V1PodSpec(containers=[container], volumes=[volume]),
)
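
# %% [markdown]
# If the training pods share nodes with other workloads, you can optionally
# cap the shared-memory volume; `size_limit` is a standard field on
# `V1EmptyDirVolumeSource` (the `"16Gi"` value below is an assumption, not a
# recommendation):
#
# ```python
# volume = V1Volume(
#     name="dshm",
#     empty_dir=V1EmptyDirVolumeSource(medium="Memory", size_limit="16Gi"),
# )
# ```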

# %% [markdown]
# ## Define a `LightningModule`
#
# Next, we create a PyTorch Lightning module that defines an autoencoder,
# which will learn how to create compressed embeddings of MNIST images.


class MNISTAutoEncoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
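
# %% [markdown]
# As a quick local sanity check of the architecture, you can verify that the
# encoder-decoder round trip preserves the flattened image shape (a sketch;
# the batch contents are arbitrary):
#
# ```python
# import torch
#
# enc = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
# dec = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
# model = MNISTAutoEncoder(enc, dec)
# x = torch.rand(8, 1, 28, 28).view(8, -1)  # fake batch of flattened images
# assert model.decoder(model.encoder(x)).shape == x.shape
# ```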


# %% [markdown]
# ## Define a `LightningDataModule`
#
# Then we define a PyTorch Lightning data module, which specifies how to
# prepare and set up the training data.


class MNISTDataModule(L.LightningDataModule):
    def __init__(self, root_dir, batch_size=64, dataloader_num_workers=0):
        super().__init__()
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.dataloader_num_workers = dataloader_num_workers

    def prepare_data(self):
        MNIST(self.root_dir, train=True, download=True)

    def setup(self, stage=None):
        self.dataset = MNIST(
            self.root_dir,
            train=True,
            download=False,
            transform=ToTensor(),
        )

    def train_dataloader(self):
        persistent_workers = self.dataloader_num_workers > 0
        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            num_workers=self.dataloader_num_workers,
            persistent_workers=persistent_workers,
            pin_memory=True,
            shuffle=True,
        )
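
# %% [markdown]
# Outside of a `Trainer`, the data module's hooks can be exercised directly,
# which is handy for debugging (a sketch; the root directory is arbitrary, and
# `prepare_data` will download MNIST into it):
#
# ```python
# dm = MNISTDataModule(root_dir="/tmp/mnist")
# dm.prepare_data()  # downloads the dataset once
# dm.setup()  # builds the dataset backing the dataloader
# images, labels = next(iter(dm.train_dataloader()))
# print(images.shape)  # torch.Size([64, 1, 28, 28])
# ```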


# %% [markdown]
# ## Creating the PyTorch `Elastic` task
#
# With the model architecture defined, we now create a Flyte task that assumes
# a world size of 16: 2 nodes with 8 devices each. We also set `max_restarts`
# to `3` so that the task can be retried up to 3 times in case it fails for
# any reason, and we set `rdzv_configs` to a generous timeout so that the
# head and worker nodes have enough time to connect to each other.
#
# This task will output a {ref}`FlyteDirectory <folder>`, which will contain
# the model checkpoint produced by training.

NUM_NODES = 2
NUM_DEVICES = 8


@task(
    container_image=custom_image,
    task_config=Elastic(
        nnodes=NUM_NODES,
        nproc_per_node=NUM_DEVICES,
        rdzv_configs={"timeout": 36000, "join_timeout": 36000},
        max_restarts=3,
    ),
    accelerator=T4,
    requests=Resources(mem="32Gi", cpu="48", gpu="8", ephemeral_storage="100Gi"),
    pod_template=custom_pod_template,
)
def train_model(dataloader_num_workers: int) -> FlyteDirectory:
    """Train an autoencoder model on the MNIST dataset."""

    encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
    decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
    autoencoder = MNISTAutoEncoder(encoder, decoder)

    root_dir = os.getcwd()
    data = MNISTDataModule(
        root_dir,
        batch_size=4,
        dataloader_num_workers=dataloader_num_workers,
    )

    model_dir = os.path.join(root_dir, "model")
    trainer = L.Trainer(
        default_root_dir=model_dir,
        max_epochs=3,
        num_nodes=NUM_NODES,
        devices=NUM_DEVICES,
        accelerator="gpu",
        strategy="ddp",
        precision="16-mixed",
    )
    trainer.fit(model=autoencoder, datamodule=data)
    return FlyteDirectory(path=str(model_dir))


# %% [markdown]
# Finally, we wrap it all up in a workflow.


@workflow
def train_workflow(dataloader_num_workers: int = 1) -> FlyteDirectory:
    return train_model(dataloader_num_workers=dataloader_num_workers)
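
# %% [markdown]
# If you prefer launching runs from Python instead of the `pyflyte run` CLI, a
# `FlyteRemote` sketch along these lines should work (the project, domain, and
# `o0` output key are assumptions based on flytekit defaults):
#
# ```python
# from flytekit.configuration import Config
# from flytekit.remote import FlyteRemote
#
# remote = FlyteRemote(
#     Config.auto(),
#     default_project="flytesnacks",
#     default_domain="development",
# )
# execution = remote.execute(
#     train_workflow, inputs={"dataloader_num_workers": 2}, wait=True
# )
# model_dir = execution.outputs["o0"]  # FlyteDirectory of model checkpoints
# ```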
16 changes: 13 additions & 3 deletions examples/kfpytorch_plugin/kfpytorch_plugin/pytorch_mnist.py
@@ -45,9 +45,19 @@
 from torchvision import datasets, transforms
 
 # %% [markdown]
-# You can activate GPU support by either using the base image that includes the necessary GPU dependencies
-# or by initializing the [CUDA parameters](https://github.com/flyteorg/flytekit/blob/master/flytekit/image_spec/image_spec.py#L34-L35)
-# within the `ImageSpec`.
+# :::{note}
+# You can activate GPU support by either using the base image that includes
+# the necessary GPU dependencies or by specifying the `cuda` parameter in
+# the {py:class}`~flytekit.image_spec.ImageSpec`, for example:
+#
+# ```python
+# custom_image = ImageSpec(
+#     packages=[...],
+#     cuda="12.1.0",
+#     ...
+# )
+# ```
+# :::
 #
 # Adjust memory, GPU usage and storage settings based on whether you are
 # registering against the demo cluster or not.
3 changes: 3 additions & 0 deletions examples/kfpytorch_plugin/requirements.in
@@ -1,6 +1,9 @@
 flytekit
 flytekitplugins-kfpytorch
+kubernetes
+lightning
 matplotlib
 torch
 tensorboardX
 torchvision
+lightning
