Added distributed CIFAR-10 example (#23)
* Added distributed CIFAR-10 example
* Updated path in test script
* Updated paths in main README
* Removed irrelevant mnist files, updated READMEs, and improved cifar model
1 parent 7f89894 · commit e9d13a3 · 14 changed files with 285 additions and 282 deletions
Main README:

````diff
@@ -1,11 +1,44 @@
-## PyTorch distributed example
+## PyTorch distributed examples
 
-Here is an example of something we would aim to setup with the operator. It is just a simple example of distributed pytorch on kubernetes.
-```
-kubectl apply -f multinode/
-```
-
-The configmap used in the example was created using the distributed training script found in this directory:
-```
-kubectl create configmap dist-train --from-file=dist_train.py
-```
+Here are examples of jobs that use the operator.
+
+1. An example of distributed CIFAR10 with pytorch on kubernetes:
+```
+kubectl apply -f cifar10/
+```
+
+For faster execution, pre-download the dataset to each of your cluster nodes and edit the
+cifar10/pytorchjob_cifar.yaml file to include the "predownload" entries shown below in the spec containers.
+The extra entries need to be present for both MASTER and WORKER replica types.
+```
+spec:
+  containers:
+  - image: pytorch/pytorch:latest
+    imagePullPolicy: IfNotPresent
+    name: pytorch
+    volumeMounts:
+    - name: training-result
+      mountPath: /tmp/result
+    - name: entrypoint
+      mountPath: /tmp/entrypoint
+    - name: predownload                             # <- Add this line
+      mountPath: /data                              # <- Add this line
+    command: [/tmp/entrypoint/dist_train_cifar.py]
+  restartPolicy: OnFailure
+  volumes:
+  - name: training-result
+    emptyDir: {}
+  - name: entrypoint
+    configMap:
+      name: dist-train-cifar
+      defaultMode: 0755
+  - name: predownload                               # <- Add this line
+    hostPath:                                       # <- Add this line
+      path: [absolute_path_to_predownloaded_data]   # <- Add this line and path
+```
+
+2. A simple example of distributed MNIST with pytorch on kubernetes:
+```
+kubectl apply -f mnist/
+```
````
CIFAR-10 training script, shipped as a ConfigMap (new file, 184 lines added):

```yaml
apiVersion: v1
kind: ConfigMap
metadata:
  name: dist-train-cifar
data:
  dist_train_cifar.py: |
    #!/usr/bin/env python
    import os
    import torch
    import torch.distributed as dist
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    from math import ceil
    from random import Random
    from torch.autograd import Variable
    from torchvision import datasets, transforms

    # Set hyperparameters
    BATCH_SIZE = 64
    NUM_EPOCHS = 10
    LEARNING_RATE = 0.01
    MOMENTUM = 0.9

    class Partition(object):
        """ Dataset-like object, but only access a subset of it. """

        def __init__(self, data, index):
            self.data = data
            self.index = index

        def __len__(self):
            return len(self.index)

        def __getitem__(self, index):
            data_idx = self.index[index]
            return self.data[data_idx]

    class DataPartitioner(object):
        """ Partitions a dataset into different chunks. """

        def __init__(self, data, sizes=[0.7, 0.2, 0.1], seed=1234):
            self.data = data
            self.partitions = []
            rng = Random()
            rng.seed(seed)  # same seed on every rank -> identical shuffles
            data_len = len(data)
            indexes = [x for x in range(0, data_len)]
            rng.shuffle(indexes)
            for frac in sizes:
                part_len = int(frac * data_len)
                self.partitions.append(indexes[0:part_len])
                indexes = indexes[part_len:]

        def use(self, partition):
            return Partition(self.data, self.partitions[partition])

    class Net(nn.Module):
        """ Small CNN for 3x32x32 CIFAR-10 images. """

        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(3, 8, 5)
            self.bn1 = nn.BatchNorm2d(8)
            self.pool = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(8, 16, 5)
            self.bn2 = nn.BatchNorm2d(16)
            self.fc1 = nn.Linear(16 * 5 * 5, 256)
            self.bn3 = nn.BatchNorm1d(256)  # 1d: fc outputs are (N, C), not images
            self.fc2 = nn.Linear(256, 128)
            self.bn4 = nn.BatchNorm1d(128)
            self.fc3 = nn.Linear(128, 64)
            self.fc4 = nn.Linear(64, 10)

        def forward(self, x):
            x = self.pool(F.relu(self.bn1(self.conv1(x))))
            x = self.pool(F.relu(self.bn2(self.conv2(x))))
            # Two conv/pool stages shrink 32x32 inputs to 16 maps of 5x5
            x = x.view(-1, 16 * 5 * 5)
            x = F.relu(self.bn3(self.fc1(x)))
            x = F.relu(self.bn4(self.fc2(x)))
            x = F.relu(self.fc3(x))
            x = self.fc4(x)
            return x

    def partition_dataset():
        """ Partitions CIFAR10 so each rank trains on a disjoint shard. """
        train_set_full = datasets.CIFAR10(
            '/data',
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                     std=[0.2023, 0.1994, 0.2010]),
            ]))
        test_set_full = datasets.CIFAR10(
            '/data',
            train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                     std=[0.2023, 0.1994, 0.2010]),
            ]))
        size = dist.get_world_size()
        bsz = int(BATCH_SIZE / float(size))  # int: DataLoader needs an integer batch size
        partition_sizes = [1.0 / size for _ in range(size)]
        partition = DataPartitioner(train_set_full, partition_sizes)
        partition = partition.use(dist.get_rank())
        train_set = torch.utils.data.DataLoader(
            partition, batch_size=bsz, shuffle=True)
        test_set = torch.utils.data.DataLoader(
            test_set_full, batch_size=BATCH_SIZE, shuffle=False)
        return train_set, test_set, bsz

    def average_gradients(model):
        """ Gradient averaging across all ranks after each backward pass. """
        size = float(dist.get_world_size())
        for param in model.parameters():
            # 0.3-era API: reduce_op and integer group ids (0 == WORLD)
            dist.all_reduce(param.grad.data, op=dist.reduce_op.SUM, group=0)
            param.grad.data /= size

    def run():
        """ Distributed synchronous SGD example. """
        rank = dist.get_rank()
        torch.manual_seed(1234)
        train_set, test_set, bsz = partition_dataset()
        model = Net()
        criterion = nn.CrossEntropyLoss()
        num_batches = ceil(len(train_set.dataset) / float(bsz))
        optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
        print('\nTraining for %i epochs with learning rate %.3f and momentum %.3f:' %
              (NUM_EPOCHS, LEARNING_RATE, MOMENTUM))
        for epoch in range(NUM_EPOCHS):
            running_loss = 0.0
            model.train()
            for i, data in enumerate(train_set, 0):
                inputs, labels = data
                inputs, labels = Variable(inputs), Variable(labels)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                average_gradients(model)
                optimizer.step()
                # Print metrics
                running_loss += loss.data[0]
                if i % 100 == 99:  # print every 100 mini-batches
                    print('[Rank %i, Epoch %d, Batch %d] training loss: %.3f' %
                          (rank, epoch + 1, i + 1, running_loss / 100))
                    running_loss = 0.0
            # Only rank 0 evaluates on the held-out test set
            if dist.get_rank() == 0:
                correct = 0
                total = 0
                model.eval()
                for data in test_set:
                    images, labels = data
                    outputs = model(Variable(images))
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum()
                print('Validation set accuracy after epoch %i: %d %%' % (
                    epoch + 1, 100 * correct / total))

    def init_processes(fn, backend='tcp'):
        """ Initialize the distributed environment. """
        dist.init_process_group(backend)
        fn()

    if __name__ == "__main__":
        init_processes(run)
```
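The `DataPartitioner` above relies on every rank seeding the same RNG, so each process reconstructs the identical shuffled index list and then keeps only its own disjoint slice. A standalone sketch of that invariant, without torch (world size and dataset length are illustrative):

```python
# Mirrors DataPartitioner's scheme: a shared seed gives identical shuffles on
# every rank, so rank r can safely keep slice r of the shuffled index list.
from random import Random

world_size = 4
indexes = list(range(1000))
Random(1234).shuffle(indexes)          # same result on every rank
part_len = len(indexes) // world_size
shards = [indexes[r * part_len:(r + 1) * part_len] for r in range(world_size)]

assert sum(len(s) for s in shards) == len(indexes)  # nothing dropped
assert not set(shards[0]) & set(shards[1])          # shards are disjoint
```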
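`average_gradients()` is the heart of the synchronous-SGD scheme: after each backward pass every rank sums all ranks' gradients with an all-reduce and divides by the world size, so every replica steps with the same averaged gradient. The script targets the 0.3-era API (`dist.reduce_op`, integer group ids); a sketch of the same step against the current API, runnable as a single local process with the gloo backend (address and port are arbitrary local values):

```python
# Single-process sanity check of the all-reduce averaging step. With
# world_size=1 the average is a no-op, but the code path is the same one
# average_gradients() exercises per parameter.
import os
import torch
import torch.distributed as dist

os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group('gloo', rank=0, world_size=1)

grad = torch.tensor([1.0, 2.0, 3.0])
dist.all_reduce(grad, op=dist.ReduceOp.SUM)  # sum gradients across ranks...
grad /= dist.get_world_size()                # ...then divide to average
print(grad)  # tensor([1., 2., 3.]) -- unchanged with one process

dist.destroy_process_group()
```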
`cifar10/pytorchjob_cifar.yaml` (new file, 54 lines added):

```yaml
apiVersion: "kubeflow.org/v1alpha1"
kind: "PyTorchJob"
metadata:
  name: "cifar-job"
spec:
  backend: "tcp"
  masterPort: "23456"
  replicaSpecs:
  - replicas: 1
    replicaType: MASTER
    template:
      spec:
        containers:
        - image: pytorch/pytorch:latest
          imagePullPolicy: IfNotPresent
          name: pytorch
          volumeMounts:
          - name: training-result
            mountPath: /tmp/result
          - name: entrypoint
            mountPath: /tmp/entrypoint
          command: [/tmp/entrypoint/dist_train_cifar.py]
        restartPolicy: OnFailure
        volumes:
        - name: training-result
          emptyDir: {}
        - name: entrypoint
          configMap:
            name: dist-train-cifar
            defaultMode: 0755
    restartPolicy: OnFailure
  - replicas: 3
    replicaType: WORKER
    template:
      spec:
        containers:
        - image: pytorch/pytorch:latest
          imagePullPolicy: IfNotPresent
          name: pytorch
          volumeMounts:
          - name: training-result
            mountPath: /tmp/result
          - name: entrypoint
            mountPath: /tmp/entrypoint
          command: [/tmp/entrypoint/dist_train_cifar.py]
        restartPolicy: OnFailure
        volumes:
        - name: training-result
          emptyDir: {}
        - name: entrypoint
          configMap:
            name: dist-train-cifar
            defaultMode: 0755
    restartPolicy: OnFailure
```
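The job spec only names the backend and the master port; the actual rendezvous happens through environment variables that the operator is expected to inject into each replica, which the script's `dist.init_process_group('tcp')` call reads via the default `env://` init method. A quick way to inspect that contract from inside any replica (assumption: the operator injects these four variables):

```python
# Print the rendezvous variables the training script's init_process_group
# call depends on (hypothetical check; variable names follow env:// init).
import os

for var in ('MASTER_ADDR', 'MASTER_PORT', 'RANK', 'WORLD_SIZE'):
    print('%s=%s' % (var, os.environ.get(var, '<unset>')))
```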