Skip to content

Commit

Permalink
[Train] Improve PyTorch Fashion MNIST Example (ray-project#38237)
Browse files Browse the repository at this point in the history
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Signed-off-by: Yunxuan Xiao <xiaoyunxuan1998@gmail.com>
Co-authored-by: matthewdeng <matthew.j.deng@gmail.com>
Signed-off-by: harborn <gangsheng.wu@intel.com>
  • Loading branch information
2 people authored and harborn committed Aug 17, 2023
1 parent d39bb3b commit 35508a3
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 115 deletions.
6 changes: 3 additions & 3 deletions python/ray/train/examples/mlflow_fashion_mnist_example.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import argparse

from ray.train import RunConfig, ScalingConfig
from ray.train.examples.pytorch.torch_fashion_mnist_example import train_func
from ray.train.examples.pytorch.torch_fashion_mnist_example import train_func_per_worker
from ray.train.torch import TorchTrainer
from ray.air.integrations.mlflow import MLflowLoggerCallback


def main(num_workers=2, use_gpu=False):
trainer = TorchTrainer(
train_func,
train_loop_config={"lr": 1e-3, "batch_size": 64, "epochs": 4},
train_func_per_worker,
train_loop_config={"lr": 1e-3, "batch_size_per_worker": 32, "epochs": 4},
scaling_config=ScalingConfig(num_workers=num_workers, use_gpu=use_gpu),
run_config=RunConfig(
callbacks=[MLflowLoggerCallback(experiment_name="train_fashion_mnist")]
Expand Down
199 changes: 93 additions & 106 deletions python/ray/train/examples/pytorch/torch_fashion_mnist_example.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,56 @@
import argparse
from typing import Dict

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor, Normalize
from tqdm import tqdm

import ray.train as train
import ray.train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

# Download training data from open datasets.
training_data = datasets.FashionMNIST(
root="~/data",
train=True,
download=True,
transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
root="~/data",
train=False,
download=True,
transform=ToTensor(),
)
def get_dataloaders(batch_size):
# Transform to normalize the input images
transform = transforms.Compose([ToTensor(), Normalize((0.5,), (0.5,))])

# Download training data from open datasets.
training_data = datasets.FashionMNIST(
root="~/data",
train=True,
download=True,
transform=transform,
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
root="~/data",
train=False,
download=True,
transform=transform,
)

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

return train_dataloader, test_dataloader

# Define model

# Model Definition
class NeuralNetwork(nn.Module):
def __init__(self):
super(NeuralNetwork, self).__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28 * 28, 512),
nn.ReLU(),
nn.Dropout(0.25),
nn.Linear(512, 512),
nn.ReLU(),
nn.Dropout(0.25),
nn.Linear(512, 10),
nn.ReLU(),
)
Expand All @@ -48,111 +61,85 @@ def forward(self, x):
return logits


def train_epoch(dataloader, model, loss_fn, optimizer):
size = len(dataloader.dataset) // train.get_context().get_world_size()
model.train()
for batch, (X, y) in enumerate(dataloader):
# Compute prediction error
pred = model(X)
loss = loss_fn(pred, y)
def train_func_per_worker(config: Dict):
lr = config["lr"]
epochs = config["epochs"]
batch_size = config["batch_size_per_worker"]

# Get dataloaders inside worker training function
train_dataloader, test_dataloader = get_dataloaders(batch_size=batch_size)

# Backpropagation
optimizer.zero_grad()
loss.backward()
optimizer.step()
# [1] Prepare Dataloader for distributed training
# Shard the datasets among workers and move batches to the correct device
# =======================================================================
train_dataloader = ray.train.torch.prepare_data_loader(train_dataloader)
test_dataloader = ray.train.torch.prepare_data_loader(test_dataloader)

if batch % 100 == 0:
loss, current = loss.item(), batch * len(X)
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
model = NeuralNetwork()

# [2] Prepare and wrap your model with DistributedDataParallel
# Move the model the correct GPU/CPU device
# ============================================================
model = ray.train.torch.prepare_model(model)

def validate_epoch(dataloader, model, loss_fn):
size = len(dataloader.dataset) // train.get_context().get_world_size()
num_batches = len(dataloader)
model.eval()
test_loss, correct = 0, 0
with torch.no_grad():
for X, y in dataloader:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

# Model training loop
for epoch in range(epochs):
model.train()
for X, y in tqdm(train_dataloader, desc=f"Train Epoch {epoch}"):
pred = model(X)
test_loss += loss_fn(pred, y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= num_batches
correct /= size
print(
f"Test Error: \n "
f"Accuracy: {(100 * correct):>0.1f}%, "
f"Avg loss: {test_loss:>8f} \n"
)
return test_loss
loss = loss_fn(pred, y)

optimizer.zero_grad()
loss.backward()
optimizer.step()

def train_func(config: Dict):
batch_size = config["batch_size"]
lr = config["lr"]
epochs = config["epochs"]
model.eval()
test_loss, num_correct, num_total = 0, 0, 0
with torch.no_grad():
for X, y in tqdm(test_dataloader, desc=f"Test Epoch {epoch}"):
pred = model(X)
loss = loss_fn(pred, y)

worker_batch_size = batch_size // train.get_context().get_world_size()
test_loss += loss.item()
num_total += y.shape[0]
num_correct += (pred.argmax(1) == y).sum().item()

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=worker_batch_size)
test_dataloader = DataLoader(test_data, batch_size=worker_batch_size)
test_loss /= len(test_dataloader)
accuracy = num_correct / num_total

train_dataloader = train.torch.prepare_data_loader(train_dataloader)
test_dataloader = train.torch.prepare_data_loader(test_dataloader)
# [3] Report metrics to Ray Train
# ===============================
ray.train.report(metrics={"loss": test_loss, "accuracy": accuracy})

# Create model.
model = NeuralNetwork()
model = train.torch.prepare_model(model)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
def train_fashion_mnist(num_workers=2, use_gpu=False):
global_batch_size = 32

for _ in range(epochs):
train_epoch(train_dataloader, model, loss_fn, optimizer)
loss = validate_epoch(test_dataloader, model, loss_fn)
train.report(dict(loss=loss))
train_config = {
"lr": 1e-3,
"epochs": 10,
"batch_size_per_worker": global_batch_size // num_workers,
}

# Configure computation resources
scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)

def train_fashion_mnist(num_workers=2, use_gpu=False):
# Initialize a Ray TorchTrainer
trainer = TorchTrainer(
train_loop_per_worker=train_func,
train_loop_config={"lr": 1e-3, "batch_size": 64, "epochs": 4},
scaling_config=ScalingConfig(num_workers=num_workers, use_gpu=use_gpu),
train_loop_per_worker=train_func_per_worker,
train_loop_config=train_config,
scaling_config=scaling_config,
)

# [4] Start Distributed Training
# Run `train_func_per_worker` on all workers
# =============================================
result = trainer.fit()
print(f"Last result: {result.metrics}")
print(f"Training result: {result}")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--address", required=False, type=str, help="the address to use for Ray"
)
parser.add_argument(
"--num-workers",
"-n",
type=int,
default=2,
help="Sets number of workers for training.",
)
parser.add_argument(
"--use-gpu", action="store_true", default=False, help="Enables GPU training"
)
parser.add_argument(
"--smoke-test",
action="store_true",
default=False,
help="Finish quickly for testing.",
)

args, _ = parser.parse_known_args()

import ray

if args.smoke_test:
# 2 workers + 1 for trainer.
ray.init(num_cpus=3)
train_fashion_mnist()
else:
ray.init(address=args.address)
train_fashion_mnist(num_workers=args.num_workers, use_gpu=args.use_gpu)
train_fashion_mnist(num_workers=4, use_gpu=True)
4 changes: 2 additions & 2 deletions python/ray/train/tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
train_func as torch_quick_start_train_func,
)
from ray.train.examples.pytorch.torch_fashion_mnist_example import (
train_func as fashion_mnist_train_func,
train_func_per_worker as fashion_mnist_train_func,
)
from ray.train.examples.pytorch.torch_linear_example import (
train_func as linear_train_func,
Expand Down Expand Up @@ -82,7 +82,7 @@ def test_torch_fashion_mnist(ray_start_4_cpus):
num_workers = 2
epochs = 3

config = {"lr": 1e-3, "batch_size": 64, "epochs": epochs}
config = {"lr": 1e-3, "batch_size_per_worker": 32, "epochs": epochs}
trainer = TorchTrainer(
fashion_mnist_train_func,
train_loop_config=config,
Expand Down
4 changes: 2 additions & 2 deletions python/ray/train/tests/test_gpu_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
train_func as tensorflow_mnist_train_func,
)
from ray.train.examples.pytorch.torch_fashion_mnist_example import (
train_func as fashion_mnist_train_func,
train_func_per_worker as fashion_mnist_train_func,
)
from ray.train.horovod.horovod_trainer import HorovodTrainer
from ray.train.tests.test_tune import (
Expand Down Expand Up @@ -43,7 +43,7 @@ def test_torch_fashion_mnist_gpu(ray_start_4_cpus_2_gpus):
num_workers = 2
epochs = 3

config = {"lr": 1e-3, "batch_size": 64, "epochs": epochs}
config = {"lr": 1e-3, "batch_size_per_worker": 32, "epochs": epochs}
trainer = TorchTrainer(
fashion_mnist_train_func,
train_loop_config=config,
Expand Down
4 changes: 2 additions & 2 deletions python/ray/train/tests/test_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
train_func as tensorflow_mnist_train_func,
)
from ray.train.examples.pytorch.torch_fashion_mnist_example import (
train_func as fashion_mnist_train_func,
train_func_per_worker as fashion_mnist_train_func,
)
from ray.train.tensorflow.tensorflow_trainer import TensorflowTrainer
from ray.train.torch.torch_trainer import TorchTrainer
Expand Down Expand Up @@ -63,7 +63,7 @@ def torch_fashion_mnist(num_workers, use_gpu, num_samples):
param_space={
"train_loop_config": {
"lr": tune.loguniform(1e-4, 1e-1),
"batch_size": tune.choice([32, 64, 128]),
"batch_size_per_worker": tune.choice([32, 64, 128]),
"epochs": 2,
}
},
Expand Down

0 comments on commit 35508a3

Please sign in to comment.