<a href="https://colab.research.google.com/github/kiitaamuuraa/Asobiba/blob/main/MyFirstPyTorchDDP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [15]:
import os

import torch
import torch.nn as nn
from torchvision import models

# 分散データ並列
import torch.distributed as dist
from torch.multiprocessing import Process
from torch.nn.parallel import DistributedDataParallel as DDP

In [13]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
input = torch.randn([16, 3, 224, 224])
net = models.resnet152(pretrained=True)
#net = nn.DataParallel(net)
#net = net.to(device)
out = net(input)

In [6]:
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29500'

## 分散の初期化を以下を参考に実施
* ランクとサイズを取得する参考コード  
https://www.programcreek.com/python/example/112916/torch.distributed.init_process_group
* mp4i のインストール  
https://mpi4py.readthedocs.io/en/stable/index.html
* PyTorchの公式ドキュメント(DDP)  
 * サンプルコード: https://pytorch.org/docs/master/notes/ddp.html  
 * torch.distributedの解説: https://pytorch.org/docs/stable/distributed.html  
 `torch.distributed.init_process_group`はよく読んだ方がいい

In [None]:
!pip install mpi4py

In [3]:
from mpi4py import MPI
mpi_rank = MPI.COMM_WORLD.Get_rank()
mpi_size = MPI.COMM_WORLD.Get_size()
mpi_rank, mpi_size

(0, 1)

In [10]:
# プロセスグループの初期化
dist.init_process_group("nccl", rank=mpi_rank, world_size=mpi_size)

In [9]:
# プロセスグループの終了
dist.destroy_process_group()

In [16]:
# create local model
net = net.to(mpi_rank)
# DDPでラップ
ddp_net = DDP(net, device_ids=[mpi_rank])

In [19]:
type(ddp_net), ddp_net.device

(torch.nn.parallel.distributed.DistributedDataParallel,
 device(type='cuda', index=0))

In [None]:
# エポック開始時にGPUの同期を待ち
dist.barrier()  # let all processes sync up before starting with a new epoch of training

In [29]:
# https://tmyoda.hatenablog.com/entry/20210314/1615712115
dist.all_reduce(running_loss, op=dist.ReduceOp.SUM)
dist.all_reduce(running_corrects, op=dist.ReduceOp.SUM)

In [32]:
# 現在のバックエンドやランク、ワールドサイズを取得するAPI
torch.distributed.get_backend()
torch.distributed.get_rank()
torch.distributed.get_world_size(group=None)
torch.distributed.is_nccl_available()

0

True