generated from kubernetes/kubernetes-template-project
-
Notifications
You must be signed in to change notification settings - Fork 33
/
mnist.yaml
38 lines (38 loc) · 1.31 KB
/
mnist.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
---
# Distributed training of a traditional CNN model to do image classification
# using the MNIST dataset and PyTorch.
#
# A single replicated Job ("workers") runs 4 pods; each pod launches one
# torchrun process and all of them rendezvous at MASTER_ADDR:MASTER_PORT.
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: pytorch
spec:
  replicatedJobs:
    - name: workers
      template:
        spec:
          # One pod per training node. Keep both values in sync with the
          # --nnodes flag passed to torchrun below.
          parallelism: 4
          completions: 4
          # No Job-level retries for failed pods.
          backoffLimit: 0
          template:
            spec:
              containers:
                - name: pytorch
                  image: gcr.io/k8s-staging-jobset/pytorch-mnist:latest
                  ports:
                    # Must match MASTER_PORT below.
                    - containerPort: 3389
                  env:
                    # NOTE(review): presumably the DNS name of the first
                    # worker pod (<jobset>-<replicatedJob>-<job index>-<pod
                    # index>.<subdomain>) - confirm against the JobSet
                    # network/DNS settings in use.
                    - name: MASTER_ADDR
                      value: "pytorch-workers-0-0.pytorch"
                    - name: MASTER_PORT
                      value: "3389"
                    # Each pod uses its Job completion index as its node
                    # rank.
                    # NOTE(review): this annotation exists only for Indexed
                    # Jobs and completionMode is not set in this manifest -
                    # confirm that it is defaulted to Indexed.
                    - name: RANK
                      valueFrom:
                        fieldRef:
                          fieldPath: metadata.annotations['batch.kubernetes.io/job-completion-index']
                    # Force python to not buffer output and write directly
                    # to stdout, so we can view training logs via
                    # `kubectl logs`. Any non-empty value enables this; "1"
                    # is used so the setting does not read as "disabled".
                    - name: PYTHONUNBUFFERED
                      value: "1"
                  command:
                    - bash
                    - -xc
                    - |
                      torchrun \
                        --rdzv_id=123 \
                        --nnodes=4 \
                        --nproc_per_node=1 \
                        --master_addr="$MASTER_ADDR" \
                        --master_port="$MASTER_PORT" \
                        --node_rank="$RANK" \
                        mnist.py --epochs=1 --log-interval=1